diff -r 2e484bdea8c4 -r b86a448a2965 contrib/python-zstandard/tests/test_decompressor.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/contrib/python-zstandard/tests/test_decompressor.py Thu Nov 10 22:15:58 2016 -0800 @@ -0,0 +1,478 @@ +import io +import random +import struct +import sys + +try: + import unittest2 as unittest +except ImportError: + import unittest + +import zstd + +from .common import OpCountingBytesIO + + +if sys.version_info[0] >= 3: + next = lambda it: it.__next__() +else: + next = lambda it: it.next() + + +class TestDecompressor_decompress(unittest.TestCase): + def test_empty_input(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): + dctx.decompress(b'') + + def test_invalid_input(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): + dctx.decompress(b'foobar') + + def test_no_content_size_in_frame(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b'foobar') + + dctx = zstd.ZstdDecompressor() + with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): + dctx.decompress(compressed) + + def test_content_size_present(self): + cctx = zstd.ZstdCompressor(write_content_size=True) + compressed = cctx.compress(b'foobar') + + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + self.assertEqual(decompressed, b'foobar') + + def test_max_output_size(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + source = b'foobar' * 256 + compressed = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + # Will fit into buffer exactly the size of input. + decompressed = dctx.decompress(compressed, max_output_size=len(source)) + self.assertEqual(decompressed, source) + + # Input size - 1 fails + with self.assertRaisesRegexp(zstd.ZstdError, 'Destination buffer is too small'): + dctx.decompress(compressed, max_output_size=len(source) - 1) + + # Input size + 1 works + decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) + self.assertEqual(decompressed, source) + + # A much larger buffer works. + decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64) + self.assertEqual(decompressed, source) + + def test_stupidly_large_output_buffer(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b'foobar' * 256) + dctx = zstd.ZstdDecompressor() + + # Will get OverflowError on some Python distributions that can't + # handle really large integers. 
+ with self.assertRaises((MemoryError, OverflowError)): + dctx.decompress(compressed, max_output_size=2**62) + + def test_dictionary(self): + samples = [] + for i in range(128): + samples.append(b'foo' * 64) + samples.append(b'bar' * 64) + samples.append(b'foobar' * 64) + + d = zstd.train_dictionary(8192, samples) + + orig = b'foobar' * 16384 + cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True) + compressed = cctx.compress(orig) + + dctx = zstd.ZstdDecompressor(dict_data=d) + decompressed = dctx.decompress(compressed) + + self.assertEqual(decompressed, orig) + + def test_dictionary_multiple(self): + samples = [] + for i in range(128): + samples.append(b'foo' * 64) + samples.append(b'bar' * 64) + samples.append(b'foobar' * 64) + + d = zstd.train_dictionary(8192, samples) + + sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192) + compressed = [] + cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True) + for source in sources: + compressed.append(cctx.compress(source)) + + dctx = zstd.ZstdDecompressor(dict_data=d) + for i in range(len(sources)): + decompressed = dctx.decompress(compressed[i]) + self.assertEqual(decompressed, sources[i]) + + +class TestDecompressor_copy_stream(unittest.TestCase): + def test_no_read(self): + source = object() + dest = io.BytesIO() + + dctx = zstd.ZstdDecompressor() + with self.assertRaises(ValueError): + dctx.copy_stream(source, dest) + + def test_no_write(self): + source = io.BytesIO() + dest = object() + + dctx = zstd.ZstdDecompressor() + with self.assertRaises(ValueError): + dctx.copy_stream(source, dest) + + def test_empty(self): + source = io.BytesIO() + dest = io.BytesIO() + + dctx = zstd.ZstdDecompressor() + # TODO should this raise an error? + r, w = dctx.copy_stream(source, dest) + + self.assertEqual(r, 0) + self.assertEqual(w, 0) + self.assertEqual(dest.getvalue(), b'') + + def test_large_data(self): + source = io.BytesIO() + for i in range(255): + source.write(struct.Struct('>B').pack(i) * 16384) + source.seek(0) + + compressed = io.BytesIO() + cctx = zstd.ZstdCompressor() + cctx.copy_stream(source, compressed) + + compressed.seek(0) + dest = io.BytesIO() + dctx = zstd.ZstdDecompressor() + r, w = dctx.copy_stream(compressed, dest) + + self.assertEqual(r, len(compressed.getvalue())) + self.assertEqual(w, len(source.getvalue())) + + def test_read_write_size(self): + source = OpCountingBytesIO(zstd.ZstdCompressor().compress( + b'foobarfoobar')) + + dest = OpCountingBytesIO() + dctx = zstd.ZstdDecompressor() + r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) + + self.assertEqual(r, len(source.getvalue())) + self.assertEqual(w, len(b'foobarfoobar')) + self.assertEqual(source._read_count, len(source.getvalue()) + 1) + self.assertEqual(dest._write_count, len(dest.getvalue())) + + +class TestDecompressor_decompressobj(unittest.TestCase): + def test_simple(self): + data = zstd.ZstdCompressor(level=1).compress(b'foobar') + + dctx = zstd.ZstdDecompressor() + dobj = dctx.decompressobj() + self.assertEqual(dobj.decompress(data), b'foobar') + + def test_reuse(self): + data = zstd.ZstdCompressor(level=1).compress(b'foobar') + + dctx = zstd.ZstdDecompressor() + dobj = dctx.decompressobj() + dobj.decompress(data) + + with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): + dobj.decompress(data) + + +def decompress_via_writer(data): + buffer = io.BytesIO() + dctx = zstd.ZstdDecompressor() + with dctx.write_to(buffer) as decompressor: + decompressor.write(data) + return 
buffer.getvalue() + + +class TestDecompressor_write_to(unittest.TestCase): + def test_empty_roundtrip(self): + cctx = zstd.ZstdCompressor() + empty = cctx.compress(b'') + self.assertEqual(decompress_via_writer(empty), b'') + + def test_large_roundtrip(self): + chunks = [] + for i in range(255): + chunks.append(struct.Struct('>B').pack(i) * 16384) + orig = b''.join(chunks) + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(orig) + + self.assertEqual(decompress_via_writer(compressed), orig) + + def test_multiple_calls(self): + chunks = [] + for i in range(255): + for j in range(255): + chunks.append(struct.Struct('>B').pack(j) * i) + + orig = b''.join(chunks) + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(orig) + + buffer = io.BytesIO() + dctx = zstd.ZstdDecompressor() + with dctx.write_to(buffer) as decompressor: + pos = 0 + while pos < len(compressed): + pos2 = pos + 8192 + decompressor.write(compressed[pos:pos2]) + pos += 8192 + self.assertEqual(buffer.getvalue(), orig) + + def test_dictionary(self): + samples = [] + for i in range(128): + samples.append(b'foo' * 64) + samples.append(b'bar' * 64) + samples.append(b'foobar' * 64) + + d = zstd.train_dictionary(8192, samples) + + orig = b'foobar' * 16384 + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(dict_data=d) + with cctx.write_to(buffer) as compressor: + compressor.write(orig) + + compressed = buffer.getvalue() + buffer = io.BytesIO() + + dctx = zstd.ZstdDecompressor(dict_data=d) + with dctx.write_to(buffer) as decompressor: + decompressor.write(compressed) + + self.assertEqual(buffer.getvalue(), orig) + + def test_memory_size(self): + dctx = zstd.ZstdDecompressor() + buffer = io.BytesIO() + with dctx.write_to(buffer) as decompressor: + size = decompressor.memory_size() + + self.assertGreater(size, 100000) + + def test_write_size(self): + source = zstd.ZstdCompressor().compress(b'foobarfoobar') + dest = OpCountingBytesIO() + dctx = zstd.ZstdDecompressor() + with dctx.write_to(dest, write_size=1) as decompressor: + s = struct.Struct('>B') + for c in source: + if not isinstance(c, str): + c = s.pack(c) + decompressor.write(c) + + + self.assertEqual(dest.getvalue(), b'foobarfoobar') + self.assertEqual(dest._write_count, len(dest.getvalue())) + + +class TestDecompressor_read_from(unittest.TestCase): + def test_type_validation(self): + dctx = zstd.ZstdDecompressor() + + # Object with read() works. + dctx.read_from(io.BytesIO()) + + # Buffer protocol works. + dctx.read_from(b'foobar') + + with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'): + dctx.read_from(True) + + def test_empty_input(self): + dctx = zstd.ZstdDecompressor() + + source = io.BytesIO() + it = dctx.read_from(source) + # TODO this is arguably wrong. Should get an error about missing frame foo. 
+ with self.assertRaises(StopIteration): + next(it) + + it = dctx.read_from(b'') + with self.assertRaises(StopIteration): + next(it) + + def test_invalid_input(self): + dctx = zstd.ZstdDecompressor() + + source = io.BytesIO(b'foobar') + it = dctx.read_from(source) + with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): + next(it) + + it = dctx.read_from(b'foobar') + with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): + next(it) + + def test_empty_roundtrip(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + empty = cctx.compress(b'') + + source = io.BytesIO(empty) + source.seek(0) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_from(source) + + # No chunks should be emitted since there is no data. + with self.assertRaises(StopIteration): + next(it) + + # Again for good measure. + with self.assertRaises(StopIteration): + next(it) + + def test_skip_bytes_too_large(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'): + dctx.read_from(b'', skip_bytes=1, read_size=1) + + with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'): + b''.join(dctx.read_from(b'foobar', skip_bytes=10)) + + def test_skip_bytes(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b'foobar') + + dctx = zstd.ZstdDecompressor() + output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3)) + self.assertEqual(output, b'foobar') + + def test_large_output(self): + source = io.BytesIO() + source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) + source.write(b'o') + source.seek(0) + + cctx = zstd.ZstdCompressor(level=1) + compressed = io.BytesIO(cctx.compress(source.getvalue())) + compressed.seek(0) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_from(compressed) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b''.join(chunks) + self.assertEqual(decompressed, source.getvalue()) + + # And again with buffer protocol. + it = dctx.read_from(compressed.getvalue()) + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b''.join(chunks) + self.assertEqual(decompressed, source.getvalue()) + + def test_large_input(self): + bytes = list(struct.Struct('>B').pack(i) for i in range(256)) + compressed = io.BytesIO() + input_size = 0 + cctx = zstd.ZstdCompressor(level=1) + with cctx.write_to(compressed) as compressor: + while True: + compressor.write(random.choice(bytes)) + input_size += 1 + + have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE + have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 + if have_compressed and have_raw: + break + + compressed.seek(0) + self.assertGreater(len(compressed.getvalue()), + zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_from(compressed) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b''.join(chunks) + self.assertEqual(len(decompressed), input_size) + + # And again with buffer protocol. 
+ it = dctx.read_from(compressed.getvalue()) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b''.join(chunks) + self.assertEqual(len(decompressed), input_size) + + def test_interesting(self): + # Found this edge case via fuzzing. + cctx = zstd.ZstdCompressor(level=1) + + source = io.BytesIO() + + compressed = io.BytesIO() + with cctx.write_to(compressed) as compressor: + for i in range(256): + chunk = b'\0' * 1024 + compressor.write(chunk) + source.write(chunk) + + dctx = zstd.ZstdDecompressor() + + simple = dctx.decompress(compressed.getvalue(), + max_output_size=len(source.getvalue())) + self.assertEqual(simple, source.getvalue()) + + compressed.seek(0) + streamed = b''.join(dctx.read_from(compressed)) + self.assertEqual(streamed, source.getvalue()) + + def test_read_write_size(self): + source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) + dctx = zstd.ZstdDecompressor() + for chunk in dctx.read_from(source, read_size=1, write_size=1): + self.assertEqual(len(chunk), 1) + + self.assertEqual(source._read_count, len(source.getvalue()))
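
For reference, a minimal usage sketch of the decompressor APIs these tests exercise (illustrative only, not part of the changeset; it assumes the zstd bindings built from this series are importable):

    import io
    import zstd

    cctx = zstd.ZstdCompressor(level=1, write_content_size=True)
    compressed = cctx.compress(b'foobar' * 256)

    dctx = zstd.ZstdDecompressor()

    # One-shot decompression; no max_output_size is needed because the
    # frame above was written with its content size embedded.
    assert dctx.decompress(compressed) == b'foobar' * 256

    # Streaming decompression: read_from() yields decompressed chunks from
    # an object with a read() method (or anything supporting the buffer
    # protocol).
    assert b''.join(dctx.read_from(io.BytesIO(compressed))) == b'foobar' * 256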