diff -r 5b60464efbde -r c32454d69b85 contrib/python-zstandard/tests/test_decompressor.py --- a/contrib/python-zstandard/tests/test_decompressor.py Thu Feb 09 21:44:32 2017 -0500 +++ b/contrib/python-zstandard/tests/test_decompressor.py Tue Feb 07 23:24:47 2017 -0800 @@ -10,7 +10,10 @@ import zstd -from .common import OpCountingBytesIO +from .common import ( + make_cffi, + OpCountingBytesIO, +) if sys.version_info[0] >= 3: @@ -19,6 +22,7 @@ next = lambda it: it.next() +@make_cffi class TestDecompressor_decompress(unittest.TestCase): def test_empty_input(self): dctx = zstd.ZstdDecompressor() @@ -119,6 +123,7 @@ self.assertEqual(decompressed, sources[i]) +@make_cffi class TestDecompressor_copy_stream(unittest.TestCase): def test_no_read(self): source = object() @@ -180,6 +185,7 @@ self.assertEqual(dest._write_count, len(dest.getvalue())) +@make_cffi class TestDecompressor_decompressobj(unittest.TestCase): def test_simple(self): data = zstd.ZstdCompressor(level=1).compress(b'foobar') @@ -207,6 +213,7 @@ return buffer.getvalue() +@make_cffi class TestDecompressor_write_to(unittest.TestCase): def test_empty_roundtrip(self): cctx = zstd.ZstdCompressor() @@ -256,14 +263,14 @@ buffer = io.BytesIO() cctx = zstd.ZstdCompressor(dict_data=d) with cctx.write_to(buffer) as compressor: - compressor.write(orig) + self.assertEqual(compressor.write(orig), 1544) compressed = buffer.getvalue() buffer = io.BytesIO() dctx = zstd.ZstdDecompressor(dict_data=d) with dctx.write_to(buffer) as decompressor: - decompressor.write(compressed) + self.assertEqual(decompressor.write(compressed), len(orig)) self.assertEqual(buffer.getvalue(), orig) @@ -291,6 +298,7 @@ self.assertEqual(dest._write_count, len(dest.getvalue())) +@make_cffi class TestDecompressor_read_from(unittest.TestCase): def test_type_validation(self): dctx = zstd.ZstdDecompressor() @@ -302,7 +310,7 @@ dctx.read_from(b'foobar') with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'): - dctx.read_from(True) + b''.join(dctx.read_from(True)) def test_empty_input(self): dctx = zstd.ZstdDecompressor() @@ -351,7 +359,7 @@ dctx = zstd.ZstdDecompressor() with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'): - dctx.read_from(b'', skip_bytes=1, read_size=1) + b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1)) with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'): b''.join(dctx.read_from(b'foobar', skip_bytes=10)) @@ -476,3 +484,94 @@ self.assertEqual(len(chunk), 1) self.assertEqual(source._read_count, len(source.getvalue())) + + +@make_cffi +class TestDecompressor_content_dict_chain(unittest.TestCase): + def test_bad_inputs_simple(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaises(TypeError): + dctx.decompress_content_dict_chain(b'foo') + + with self.assertRaises(TypeError): + dctx.decompress_content_dict_chain((b'foo', b'bar')) + + with self.assertRaisesRegexp(ValueError, 'empty input chain'): + dctx.decompress_content_dict_chain([]) + + with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'): + dctx.decompress_content_dict_chain([u'foo']) + + with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'): + dctx.decompress_content_dict_chain([True]) + + with self.assertRaisesRegexp(ValueError, 'chunk 0 is too small to contain a zstd frame'): + dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) + + with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'): + dctx.decompress_content_dict_chain([b'foo' * 8]) + + no_size = zstd.ZstdCompressor().compress(b'foo' * 64) + + with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'): + dctx.decompress_content_dict_chain([no_size]) + + # Corrupt first frame. + frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64) + frame = frame[0:12] + frame[15:] + with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'): + dctx.decompress_content_dict_chain([frame]) + + def test_bad_subsequent_input(self): + initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64) + + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'): + dctx.decompress_content_dict_chain([initial, u'foo']) + + with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'): + dctx.decompress_content_dict_chain([initial, None]) + + with self.assertRaisesRegexp(ValueError, 'chunk 1 is too small to contain a zstd frame'): + dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) + + with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'): + dctx.decompress_content_dict_chain([initial, b'foo' * 8]) + + no_size = zstd.ZstdCompressor().compress(b'foo' * 64) + + with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'): + dctx.decompress_content_dict_chain([initial, no_size]) + + # Corrupt second frame. + cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64)) + frame = cctx.compress(b'bar' * 64) + frame = frame[0:12] + frame[15:] + + with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'): + dctx.decompress_content_dict_chain([initial, frame]) + + def test_simple(self): + original = [ + b'foo' * 64, + b'foobar' * 64, + b'baz' * 64, + b'foobaz' * 64, + b'foobarbaz' * 64, + ] + + chunks = [] + chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0])) + for i, chunk in enumerate(original[1:]): + d = zstd.ZstdCompressionDict(original[i]) + cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True) + chunks.append(cctx.compress(chunk)) + + for i in range(1, len(original)): + chain = chunks[0:i] + expected = original[i - 1] + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress_content_dict_chain(chain) + self.assertEqual(decompressed, expected)