contrib/python-zstandard/tests/test_compressor.py
changeset 31796 e0dc40530c5a
parent 30895 c32454d69b85
child 37495 b1fb341d8a61
equal deleted inserted replaced
31795:2b130e26c3a4 31796:e0dc40530c5a
    18 
    18 
    19 if sys.version_info[0] >= 3:
    19 if sys.version_info[0] >= 3:
    20     next = lambda it: it.__next__()
    20     next = lambda it: it.__next__()
    21 else:
    21 else:
    22     next = lambda it: it.next()
    22     next = lambda it: it.next()
       
    23 
       
    24 
       
    25 def multithreaded_chunk_size(level, source_size=0):
       
    26     params = zstd.get_compression_parameters(level, source_size)
       
    27 
       
    28     return 1 << (params.window_log + 2)
    23 
    29 
    24 
    30 
    25 @make_cffi
    31 @make_cffi
    26 class TestCompressor(unittest.TestCase):
    32 class TestCompressor(unittest.TestCase):
    27     def test_level_bounds(self):
    33     def test_level_bounds(self):
    32             zstd.ZstdCompressor(level=23)
    38             zstd.ZstdCompressor(level=23)
    33 
    39 
    34 
    40 
    35 @make_cffi
    41 @make_cffi
    36 class TestCompressor_compress(unittest.TestCase):
    42 class TestCompressor_compress(unittest.TestCase):
       
    43     def test_multithreaded_unsupported(self):
       
    44         samples = []
       
    45         for i in range(128):
       
    46             samples.append(b'foo' * 64)
       
    47             samples.append(b'bar' * 64)
       
    48 
       
    49         d = zstd.train_dictionary(8192, samples)
       
    50 
       
    51         cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
       
    52 
       
    53         with self.assertRaisesRegexp(zstd.ZstdError, 'compress\(\) cannot be used with both dictionaries and multi-threaded compression'):
       
    54             cctx.compress(b'foo')
       
    55 
       
    56         params = zstd.get_compression_parameters(3)
       
    57         cctx = zstd.ZstdCompressor(compression_params=params, threads=2)
       
    58         with self.assertRaisesRegexp(zstd.ZstdError, 'compress\(\) cannot be used with both compression parameters and multi-threaded compression'):
       
    59             cctx.compress(b'foo')
       
    60 
    37     def test_compress_empty(self):
    61     def test_compress_empty(self):
    38         cctx = zstd.ZstdCompressor(level=1)
    62         cctx = zstd.ZstdCompressor(level=1)
    39         result = cctx.compress(b'')
    63         result = cctx.compress(b'')
    40         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
    64         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
    41         params = zstd.get_frame_parameters(result)
    65         params = zstd.get_frame_parameters(result)
   129 
   153 
   130         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   154         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   131 
   155 
   132         for i in range(32):
   156         for i in range(32):
   133             cctx.compress(b'foo bar foobar foo bar foobar')
   157             cctx.compress(b'foo bar foobar foo bar foobar')
       
   158 
       
   159     def test_multithreaded(self):
       
   160         chunk_size = multithreaded_chunk_size(1)
       
   161         source = b''.join([b'x' * chunk_size, b'y' * chunk_size])
       
   162 
       
   163         cctx = zstd.ZstdCompressor(level=1, threads=2)
       
   164         compressed = cctx.compress(source)
       
   165 
       
   166         params = zstd.get_frame_parameters(compressed)
       
   167         self.assertEqual(params.content_size, chunk_size * 2)
       
   168         self.assertEqual(params.dict_id, 0)
       
   169         self.assertFalse(params.has_checksum)
       
   170 
       
   171         dctx = zstd.ZstdDecompressor()
       
   172         self.assertEqual(dctx.decompress(compressed), source)
   134 
   173 
   135 
   174 
   136 @make_cffi
   175 @make_cffi
   137 class TestCompressor_compressobj(unittest.TestCase):
   176 class TestCompressor_compressobj(unittest.TestCase):
   138     def test_compressobj_empty(self):
   177     def test_compressobj_empty(self):
   235         # 3 bytes block header + 4 bytes frame checksum
   274         # 3 bytes block header + 4 bytes frame checksum
   236         self.assertEqual(len(trailing), 7)
   275         self.assertEqual(len(trailing), 7)
   237         header = trailing[0:3]
   276         header = trailing[0:3]
   238         self.assertEqual(header, b'\x01\x00\x00')
   277         self.assertEqual(header, b'\x01\x00\x00')
   239 
   278 
       
   279     def test_multithreaded(self):
       
   280         source = io.BytesIO()
       
   281         source.write(b'a' * 1048576)
       
   282         source.write(b'b' * 1048576)
       
   283         source.write(b'c' * 1048576)
       
   284         source.seek(0)
       
   285 
       
   286         cctx = zstd.ZstdCompressor(level=1, threads=2)
       
   287         cobj = cctx.compressobj()
       
   288 
       
   289         chunks = []
       
   290         while True:
       
   291             d = source.read(8192)
       
   292             if not d:
       
   293                 break
       
   294 
       
   295             chunks.append(cobj.compress(d))
       
   296 
       
   297         chunks.append(cobj.flush())
       
   298 
       
   299         compressed = b''.join(chunks)
       
   300 
       
   301         self.assertEqual(len(compressed), 295)
       
   302 
   240 
   303 
   241 @make_cffi
   304 @make_cffi
   242 class TestCompressor_copy_stream(unittest.TestCase):
   305 class TestCompressor_copy_stream(unittest.TestCase):
   243     def test_no_read(self):
   306     def test_no_read(self):
   244         source = object()
   307         source = object()
   352 
   415 
   353         self.assertEqual(r, len(source.getvalue()))
   416         self.assertEqual(r, len(source.getvalue()))
   354         self.assertEqual(w, 21)
   417         self.assertEqual(w, 21)
   355         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   418         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   356         self.assertEqual(dest._write_count, len(dest.getvalue()))
   419         self.assertEqual(dest._write_count, len(dest.getvalue()))
       
   420 
       
   421     def test_multithreaded(self):
       
   422         source = io.BytesIO()
       
   423         source.write(b'a' * 1048576)
       
   424         source.write(b'b' * 1048576)
       
   425         source.write(b'c' * 1048576)
       
   426         source.seek(0)
       
   427 
       
   428         dest = io.BytesIO()
       
   429         cctx = zstd.ZstdCompressor(threads=2)
       
   430         r, w = cctx.copy_stream(source, dest)
       
   431         self.assertEqual(r, 3145728)
       
   432         self.assertEqual(w, 295)
       
   433 
       
   434         params = zstd.get_frame_parameters(dest.getvalue())
       
   435         self.assertEqual(params.content_size, 0)
       
   436         self.assertEqual(params.dict_id, 0)
       
   437         self.assertFalse(params.has_checksum)
       
   438 
       
   439         # Writing content size and checksum works.
       
   440         cctx = zstd.ZstdCompressor(threads=2, write_content_size=True,
       
   441                                    write_checksum=True)
       
   442         dest = io.BytesIO()
       
   443         source.seek(0)
       
   444         cctx.copy_stream(source, dest, size=len(source.getvalue()))
       
   445 
       
   446         params = zstd.get_frame_parameters(dest.getvalue())
       
   447         self.assertEqual(params.content_size, 3145728)
       
   448         self.assertEqual(params.dict_id, 0)
       
   449         self.assertTrue(params.has_checksum)
   357 
   450 
   358 
   451 
   359 def compress(data, level):
   452 def compress(data, level):
   360     buffer = io.BytesIO()
   453     buffer = io.BytesIO()
   361     cctx = zstd.ZstdCompressor(level=level)
   454     cctx = zstd.ZstdCompressor(level=level)
   582         self.assertEqual(len(trailing), 7)
   675         self.assertEqual(len(trailing), 7)
   583 
   676 
   584         header = trailing[0:3]
   677         header = trailing[0:3]
   585         self.assertEqual(header, b'\x01\x00\x00')
   678         self.assertEqual(header, b'\x01\x00\x00')
   586 
   679 
       
   680     def test_multithreaded(self):
       
   681         dest = io.BytesIO()
       
   682         cctx = zstd.ZstdCompressor(threads=2)
       
   683         with cctx.write_to(dest) as compressor:
       
   684             compressor.write(b'a' * 1048576)
       
   685             compressor.write(b'b' * 1048576)
       
   686             compressor.write(b'c' * 1048576)
       
   687 
       
   688         self.assertEqual(len(dest.getvalue()), 295)
       
   689 
   587 
   690 
   588 @make_cffi
   691 @make_cffi
   589 class TestCompressor_read_from(unittest.TestCase):
   692 class TestCompressor_read_from(unittest.TestCase):
   590     def test_type_validation(self):
   693     def test_type_validation(self):
   591         cctx = zstd.ZstdCompressor()
   694         cctx = zstd.ZstdCompressor()
   671         cctx = zstd.ZstdCompressor(level=3)
   774         cctx = zstd.ZstdCompressor(level=3)
   672         for chunk in cctx.read_from(source, read_size=1, write_size=1):
   775         for chunk in cctx.read_from(source, read_size=1, write_size=1):
   673             self.assertEqual(len(chunk), 1)
   776             self.assertEqual(len(chunk), 1)
   674 
   777 
   675         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   778         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
       
   779 
       
   780     def test_multithreaded(self):
       
   781         source = io.BytesIO()
       
   782         source.write(b'a' * 1048576)
       
   783         source.write(b'b' * 1048576)
       
   784         source.write(b'c' * 1048576)
       
   785         source.seek(0)
       
   786 
       
   787         cctx = zstd.ZstdCompressor(threads=2)
       
   788 
       
   789         compressed = b''.join(cctx.read_from(source))
       
   790         self.assertEqual(len(compressed), 295)
       
   791 
       
   792 
       
   793 class TestCompressor_multi_compress_to_buffer(unittest.TestCase):
       
   794     def test_multithreaded_unsupported(self):
       
   795         cctx = zstd.ZstdCompressor(threads=2)
       
   796 
       
   797         with self.assertRaisesRegexp(zstd.ZstdError, 'function cannot be called on ZstdCompressor configured for multi-threaded compression'):
       
   798             cctx.multi_compress_to_buffer([b'foo'])
       
   799 
       
   800     def test_invalid_inputs(self):
       
   801         cctx = zstd.ZstdCompressor()
       
   802 
       
   803         with self.assertRaises(TypeError):
       
   804             cctx.multi_compress_to_buffer(True)
       
   805 
       
   806         with self.assertRaises(TypeError):
       
   807             cctx.multi_compress_to_buffer((1, 2))
       
   808 
       
   809         with self.assertRaisesRegexp(TypeError, 'item 0 not a bytes like object'):
       
   810             cctx.multi_compress_to_buffer([u'foo'])
       
   811 
       
   812     def test_empty_input(self):
       
   813         cctx = zstd.ZstdCompressor()
       
   814 
       
   815         with self.assertRaisesRegexp(ValueError, 'no source elements found'):
       
   816             cctx.multi_compress_to_buffer([])
       
   817 
       
   818         with self.assertRaisesRegexp(ValueError, 'source elements are empty'):
       
   819             cctx.multi_compress_to_buffer([b'', b'', b''])
       
   820 
       
   821     def test_list_input(self):
       
   822         cctx = zstd.ZstdCompressor(write_content_size=True, write_checksum=True)
       
   823 
       
   824         original = [b'foo' * 12, b'bar' * 6]
       
   825         frames = [cctx.compress(c) for c in original]
       
   826         b = cctx.multi_compress_to_buffer(original)
       
   827 
       
   828         self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
       
   829 
       
   830         self.assertEqual(len(b), 2)
       
   831         self.assertEqual(b.size(), 44)
       
   832 
       
   833         self.assertEqual(b[0].tobytes(), frames[0])
       
   834         self.assertEqual(b[1].tobytes(), frames[1])
       
   835 
       
   836     def test_buffer_with_segments_input(self):
       
   837         cctx = zstd.ZstdCompressor(write_content_size=True, write_checksum=True)
       
   838 
       
   839         original = [b'foo' * 4, b'bar' * 6]
       
   840         frames = [cctx.compress(c) for c in original]
       
   841 
       
   842         offsets = struct.pack('=QQQQ', 0, len(original[0]),
       
   843                                        len(original[0]), len(original[1]))
       
   844         segments = zstd.BufferWithSegments(b''.join(original), offsets)
       
   845 
       
   846         result = cctx.multi_compress_to_buffer(segments)
       
   847 
       
   848         self.assertEqual(len(result), 2)
       
   849         self.assertEqual(result.size(), 47)
       
   850 
       
   851         self.assertEqual(result[0].tobytes(), frames[0])
       
   852         self.assertEqual(result[1].tobytes(), frames[1])
       
   853 
       
   854     def test_buffer_with_segments_collection_input(self):
       
   855         cctx = zstd.ZstdCompressor(write_content_size=True, write_checksum=True)
       
   856 
       
   857         original = [
       
   858             b'foo1',
       
   859             b'foo2' * 2,
       
   860             b'foo3' * 3,
       
   861             b'foo4' * 4,
       
   862             b'foo5' * 5,
       
   863         ]
       
   864 
       
   865         frames = [cctx.compress(c) for c in original]
       
   866 
       
   867         b = b''.join([original[0], original[1]])
       
   868         b1 = zstd.BufferWithSegments(b, struct.pack('=QQQQ',
       
   869                                                     0, len(original[0]),
       
   870                                                     len(original[0]), len(original[1])))
       
   871         b = b''.join([original[2], original[3], original[4]])
       
   872         b2 = zstd.BufferWithSegments(b, struct.pack('=QQQQQQ',
       
   873                                                     0, len(original[2]),
       
   874                                                     len(original[2]), len(original[3]),
       
   875                                                     len(original[2]) + len(original[3]), len(original[4])))
       
   876 
       
   877         c = zstd.BufferWithSegmentsCollection(b1, b2)
       
   878 
       
   879         result = cctx.multi_compress_to_buffer(c)
       
   880 
       
   881         self.assertEqual(len(result), len(frames))
       
   882 
       
   883         for i, frame in enumerate(frames):
       
   884             self.assertEqual(result[i].tobytes(), frame)
       
   885 
       
   886     def test_multiple_threads(self):
       
   887         # threads argument will cause multi-threaded ZSTD APIs to be used, which will
       
   888         # make output different.
       
   889         refcctx = zstd.ZstdCompressor(write_content_size=True, write_checksum=True)
       
   890         reference = [refcctx.compress(b'x' * 64), refcctx.compress(b'y' * 64)]
       
   891 
       
   892         cctx = zstd.ZstdCompressor(write_content_size=True, write_checksum=True)
       
   893 
       
   894         frames = []
       
   895         frames.extend(b'x' * 64 for i in range(256))
       
   896         frames.extend(b'y' * 64 for i in range(256))
       
   897 
       
   898         result = cctx.multi_compress_to_buffer(frames, threads=-1)
       
   899 
       
   900         self.assertEqual(len(result), 512)
       
   901         for i in range(512):
       
   902             if i < 256:
       
   903                 self.assertEqual(result[i].tobytes(), reference[0])
       
   904             else:
       
   905                 self.assertEqual(result[i].tobytes(), reference[1])