contrib/python-zstandard/tests/test_decompressor.py
changeset 44147 5e84a96d865b
parent 43994 de7838053207
child 51686 493034cc3265
equal deleted inserted replaced
44146:45ec64d93b3a 44147:5e84a96d865b
   168             zstd.ZstdError, "decompression error: did not decompress full frame"
   168             zstd.ZstdError, "decompression error: did not decompress full frame"
   169         ):
   169         ):
   170             dctx.decompress(compressed, max_output_size=len(source) - 1)
   170             dctx.decompress(compressed, max_output_size=len(source) - 1)
   171 
   171 
   172         # Input size + 1 works
   172         # Input size + 1 works
   173         decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
   173         decompressed = dctx.decompress(
       
   174             compressed, max_output_size=len(source) + 1
       
   175         )
   174         self.assertEqual(decompressed, source)
   176         self.assertEqual(decompressed, source)
   175 
   177 
   176         # A much larger buffer works.
   178         # A much larger buffer works.
   177         decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
   179         decompressed = dctx.decompress(
       
   180             compressed, max_output_size=len(source) * 64
       
   181         )
   178         self.assertEqual(decompressed, source)
   182         self.assertEqual(decompressed, source)
   179 
   183 
   180     def test_stupidly_large_output_buffer(self):
   184     def test_stupidly_large_output_buffer(self):
   181         cctx = zstd.ZstdCompressor(write_content_size=False)
   185         cctx = zstd.ZstdCompressor(write_content_size=False)
   182         compressed = cctx.compress(b"foobar" * 256)
   186         compressed = cctx.compress(b"foobar" * 256)
   235         frame = cctx.compress(source)
   239         frame = cctx.compress(source)
   236 
   240 
   237         dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
   241         dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
   238 
   242 
   239         with self.assertRaisesRegex(
   243         with self.assertRaisesRegex(
   240             zstd.ZstdError, "decompression error: Frame requires too much memory"
   244             zstd.ZstdError,
       
   245             "decompression error: Frame requires too much memory",
   241         ):
   246         ):
   242             dctx.decompress(frame, max_output_size=len(source))
   247             dctx.decompress(frame, max_output_size=len(source))
   243 
   248 
   244 
   249 
   245 @make_cffi
   250 @make_cffi
   289 
   294 
   290         self.assertEqual(r, len(compressed.getvalue()))
   295         self.assertEqual(r, len(compressed.getvalue()))
   291         self.assertEqual(w, len(source.getvalue()))
   296         self.assertEqual(w, len(source.getvalue()))
   292 
   297 
   293     def test_read_write_size(self):
   298     def test_read_write_size(self):
   294         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
   299         source = OpCountingBytesIO(
       
   300             zstd.ZstdCompressor().compress(b"foobarfoobar")
       
   301         )
   295 
   302 
   296         dest = OpCountingBytesIO()
   303         dest = OpCountingBytesIO()
   297         dctx = zstd.ZstdDecompressor()
   304         dctx = zstd.ZstdDecompressor()
   298         r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
   305         r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
   299 
   306 
   307 class TestDecompressor_stream_reader(TestCase):
   314 class TestDecompressor_stream_reader(TestCase):
   308     def test_context_manager(self):
   315     def test_context_manager(self):
   309         dctx = zstd.ZstdDecompressor()
   316         dctx = zstd.ZstdDecompressor()
   310 
   317 
   311         with dctx.stream_reader(b"foo") as reader:
   318         with dctx.stream_reader(b"foo") as reader:
   312             with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
   319             with self.assertRaisesRegex(
       
   320                 ValueError, "cannot __enter__ multiple times"
       
   321             ):
   313                 with reader as reader2:
   322                 with reader as reader2:
   314                     pass
   323                     pass
   315 
   324 
   316     def test_not_implemented(self):
   325     def test_not_implemented(self):
   317         dctx = zstd.ZstdDecompressor()
   326         dctx = zstd.ZstdDecompressor()
   472         frame = cctx.compress(b"foo" * 60)
   481         frame = cctx.compress(b"foo" * 60)
   473 
   482 
   474         dctx = zstd.ZstdDecompressor()
   483         dctx = zstd.ZstdDecompressor()
   475 
   484 
   476         with dctx.stream_reader(frame) as reader:
   485         with dctx.stream_reader(frame) as reader:
   477             with self.assertRaisesRegex(ValueError, "cannot seek to negative position"):
   486             with self.assertRaisesRegex(
       
   487                 ValueError, "cannot seek to negative position"
       
   488             ):
   478                 reader.seek(-1, os.SEEK_SET)
   489                 reader.seek(-1, os.SEEK_SET)
   479 
   490 
   480             reader.read(1)
   491             reader.read(1)
   481 
   492 
   482             with self.assertRaisesRegex(
   493             with self.assertRaisesRegex(
   488                 ValueError, "cannot seek zstd decompression stream backwards"
   499                 ValueError, "cannot seek zstd decompression stream backwards"
   489             ):
   500             ):
   490                 reader.seek(-1, os.SEEK_CUR)
   501                 reader.seek(-1, os.SEEK_CUR)
   491 
   502 
   492             with self.assertRaisesRegex(
   503             with self.assertRaisesRegex(
   493                 ValueError, "zstd decompression streams cannot be seeked with SEEK_END"
   504                 ValueError,
       
   505                 "zstd decompression streams cannot be seeked with SEEK_END",
   494             ):
   506             ):
   495                 reader.seek(0, os.SEEK_END)
   507                 reader.seek(0, os.SEEK_END)
   496 
   508 
   497             reader.close()
   509             reader.close()
   498 
   510 
   741         self.assertEqual(reader.read1(1), b"")
   753         self.assertEqual(reader.read1(1), b"")
   742         self.assertEqual(b._read_count, 2)
   754         self.assertEqual(b._read_count, 2)
   743 
   755 
   744     def test_read_lines(self):
   756     def test_read_lines(self):
   745         cctx = zstd.ZstdCompressor()
   757         cctx = zstd.ZstdCompressor()
   746         source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024))
   758         source = b"\n".join(
       
   759             ("line %d" % i).encode("ascii") for i in range(1024)
       
   760         )
   747 
   761 
   748         frame = cctx.compress(source)
   762         frame = cctx.compress(source)
   749 
   763 
   750         dctx = zstd.ZstdDecompressor()
   764         dctx = zstd.ZstdDecompressor()
   751         reader = dctx.stream_reader(frame)
   765         reader = dctx.stream_reader(frame)
   819 
   833 
   820         dctx = zstd.ZstdDecompressor()
   834         dctx = zstd.ZstdDecompressor()
   821         dobj = dctx.decompressobj()
   835         dobj = dctx.decompressobj()
   822         dobj.decompress(data)
   836         dobj.decompress(data)
   823 
   837 
   824         with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"):
   838         with self.assertRaisesRegex(
       
   839             zstd.ZstdError, "cannot use a decompressobj"
       
   840         ):
   825             dobj.decompress(data)
   841             dobj.decompress(data)
   826             self.assertIsNone(dobj.flush())
   842             self.assertIsNone(dobj.flush())
   827 
   843 
   828     def test_bad_write_size(self):
   844     def test_bad_write_size(self):
   829         dctx = zstd.ZstdDecompressor()
   845         dctx = zstd.ZstdDecompressor()
  1122         dctx.read_to_iter(io.BytesIO())
  1138         dctx.read_to_iter(io.BytesIO())
  1123 
  1139 
  1124         # Buffer protocol works.
  1140         # Buffer protocol works.
  1125         dctx.read_to_iter(b"foobar")
  1141         dctx.read_to_iter(b"foobar")
  1126 
  1142 
  1127         with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
  1143         with self.assertRaisesRegex(
       
  1144             ValueError, "must pass an object with a read"
       
  1145         ):
  1128             b"".join(dctx.read_to_iter(True))
  1146             b"".join(dctx.read_to_iter(True))
  1129 
  1147 
  1130     def test_empty_input(self):
  1148     def test_empty_input(self):
  1131         dctx = zstd.ZstdDecompressor()
  1149         dctx = zstd.ZstdDecompressor()
  1132 
  1150 
  1224             next(it)
  1242             next(it)
  1225 
  1243 
  1226         decompressed = b"".join(chunks)
  1244         decompressed = b"".join(chunks)
  1227         self.assertEqual(decompressed, source.getvalue())
  1245         self.assertEqual(decompressed, source.getvalue())
  1228 
  1246 
  1229     @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
  1247     @unittest.skipUnless(
       
  1248         "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set"
       
  1249     )
  1230     def test_large_input(self):
  1250     def test_large_input(self):
  1231         bytes = list(struct.Struct(">B").pack(i) for i in range(256))
  1251         bytes = list(struct.Struct(">B").pack(i) for i in range(256))
  1232         compressed = NonClosingBytesIO()
  1252         compressed = NonClosingBytesIO()
  1233         input_size = 0
  1253         input_size = 0
  1234         cctx = zstd.ZstdCompressor(level=1)
  1254         cctx = zstd.ZstdCompressor(level=1)
  1239 
  1259 
  1240                 have_compressed = (
  1260                 have_compressed = (
  1241                     len(compressed.getvalue())
  1261                     len(compressed.getvalue())
  1242                     > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
  1262                     > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
  1243                 )
  1263                 )
  1244                 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
  1264                 have_raw = (
       
  1265                     input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
       
  1266                 )
  1245                 if have_compressed and have_raw:
  1267                 if have_compressed and have_raw:
  1246                     break
  1268                     break
  1247 
  1269 
  1248         compressed = io.BytesIO(compressed.getvalue())
  1270         compressed = io.BytesIO(compressed.getvalue())
  1249         self.assertGreater(
  1271         self.assertGreater(
  1250             len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
  1272             len(compressed.getvalue()),
       
  1273             zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
  1251         )
  1274         )
  1252 
  1275 
  1253         dctx = zstd.ZstdDecompressor()
  1276         dctx = zstd.ZstdDecompressor()
  1254         it = dctx.read_to_iter(compressed)
  1277         it = dctx.read_to_iter(compressed)
  1255 
  1278 
  1301         compressed = io.BytesIO(compressed.getvalue())
  1324         compressed = io.BytesIO(compressed.getvalue())
  1302         streamed = b"".join(dctx.read_to_iter(compressed))
  1325         streamed = b"".join(dctx.read_to_iter(compressed))
  1303         self.assertEqual(streamed, source.getvalue())
  1326         self.assertEqual(streamed, source.getvalue())
  1304 
  1327 
  1305     def test_read_write_size(self):
  1328     def test_read_write_size(self):
  1306         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
  1329         source = OpCountingBytesIO(
       
  1330             zstd.ZstdCompressor().compress(b"foobarfoobar")
       
  1331         )
  1307         dctx = zstd.ZstdDecompressor()
  1332         dctx = zstd.ZstdDecompressor()
  1308         for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
  1333         for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
  1309             self.assertEqual(len(chunk), 1)
  1334             self.assertEqual(len(chunk), 1)
  1310 
  1335 
  1311         self.assertEqual(source._read_count, len(source.getvalue()))
  1336         self.assertEqual(source._read_count, len(source.getvalue()))
  1353         with self.assertRaisesRegex(
  1378         with self.assertRaisesRegex(
  1354             ValueError, "chunk 0 is too small to contain a zstd frame"
  1379             ValueError, "chunk 0 is too small to contain a zstd frame"
  1355         ):
  1380         ):
  1356             dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
  1381             dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
  1357 
  1382 
  1358         with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"):
  1383         with self.assertRaisesRegex(
       
  1384             ValueError, "chunk 0 is not a valid zstd frame"
       
  1385         ):
  1359             dctx.decompress_content_dict_chain([b"foo" * 8])
  1386             dctx.decompress_content_dict_chain([b"foo" * 8])
  1360 
  1387 
  1361         no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
  1388         no_size = zstd.ZstdCompressor(write_content_size=False).compress(
       
  1389             b"foo" * 64
       
  1390         )
  1362 
  1391 
  1363         with self.assertRaisesRegex(
  1392         with self.assertRaisesRegex(
  1364             ValueError, "chunk 0 missing content size in frame"
  1393             ValueError, "chunk 0 missing content size in frame"
  1365         ):
  1394         ):
  1366             dctx.decompress_content_dict_chain([no_size])
  1395             dctx.decompress_content_dict_chain([no_size])
  1387         with self.assertRaisesRegex(
  1416         with self.assertRaisesRegex(
  1388             ValueError, "chunk 1 is too small to contain a zstd frame"
  1417             ValueError, "chunk 1 is too small to contain a zstd frame"
  1389         ):
  1418         ):
  1390             dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
  1419             dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
  1391 
  1420 
  1392         with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"):
  1421         with self.assertRaisesRegex(
       
  1422             ValueError, "chunk 1 is not a valid zstd frame"
       
  1423         ):
  1393             dctx.decompress_content_dict_chain([initial, b"foo" * 8])
  1424             dctx.decompress_content_dict_chain([initial, b"foo" * 8])
  1394 
  1425 
  1395         no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
  1426         no_size = zstd.ZstdCompressor(write_content_size=False).compress(
       
  1427             b"foo" * 64
       
  1428         )
  1396 
  1429 
  1397         with self.assertRaisesRegex(
  1430         with self.assertRaisesRegex(
  1398             ValueError, "chunk 1 missing content size in frame"
  1431             ValueError, "chunk 1 missing content size in frame"
  1399         ):
  1432         ):
  1400             dctx.decompress_content_dict_chain([initial, no_size])
  1433             dctx.decompress_content_dict_chain([initial, no_size])
  1401 
  1434 
  1402         # Corrupt second frame.
  1435         # Corrupt second frame.
  1403         cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64))
  1436         cctx = zstd.ZstdCompressor(
       
  1437             dict_data=zstd.ZstdCompressionDict(b"foo" * 64)
       
  1438         )
  1404         frame = cctx.compress(b"bar" * 64)
  1439         frame = cctx.compress(b"bar" * 64)
  1405         frame = frame[0:12] + frame[15:]
  1440         frame = frame[0:12] + frame[15:]
  1406 
  1441 
  1407         with self.assertRaisesRegex(
  1442         with self.assertRaisesRegex(
  1408             zstd.ZstdError, "chunk 1 did not decompress full frame"
  1443             zstd.ZstdError, "chunk 1 did not decompress full frame"
  1445             dctx.multi_decompress_to_buffer(True)
  1480             dctx.multi_decompress_to_buffer(True)
  1446 
  1481 
  1447         with self.assertRaises(TypeError):
  1482         with self.assertRaises(TypeError):
  1448             dctx.multi_decompress_to_buffer((1, 2))
  1483             dctx.multi_decompress_to_buffer((1, 2))
  1449 
  1484 
  1450         with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
  1485         with self.assertRaisesRegex(
       
  1486             TypeError, "item 0 not a bytes like object"
       
  1487         ):
  1451             dctx.multi_decompress_to_buffer([u"foo"])
  1488             dctx.multi_decompress_to_buffer([u"foo"])
  1452 
  1489 
  1453         with self.assertRaisesRegex(
  1490         with self.assertRaisesRegex(
  1454             ValueError, "could not determine decompressed size of item 0"
  1491             ValueError, "could not determine decompressed size of item 0"
  1455         ):
  1492         ):
  1489         dctx = zstd.ZstdDecompressor()
  1526         dctx = zstd.ZstdDecompressor()
  1490 
  1527 
  1491         if not hasattr(dctx, "multi_decompress_to_buffer"):
  1528         if not hasattr(dctx, "multi_decompress_to_buffer"):
  1492             self.skipTest("multi_decompress_to_buffer not available")
  1529             self.skipTest("multi_decompress_to_buffer not available")
  1493 
  1530 
  1494         result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
  1531         result = dctx.multi_decompress_to_buffer(
       
  1532             frames, decompressed_sizes=sizes
       
  1533         )
  1495 
  1534 
  1496         self.assertEqual(len(result), len(frames))
  1535         self.assertEqual(len(result), len(frames))
  1497         self.assertEqual(result.size(), sum(map(len, original)))
  1536         self.assertEqual(result.size(), sum(map(len, original)))
  1498 
  1537 
  1499         for i, data in enumerate(original):
  1538         for i, data in enumerate(original):
  1580             self.assertEqual(data, decompressed[i].tobytes())
  1619             self.assertEqual(data, decompressed[i].tobytes())
  1581 
  1620 
  1582         # And a manual mode.
  1621         # And a manual mode.
  1583         b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
  1622         b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
  1584         b1 = zstd.BufferWithSegments(
  1623         b1 = zstd.BufferWithSegments(
  1585             b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]))
  1624             b,
       
  1625             struct.pack(
       
  1626                 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
       
  1627             ),
  1586         )
  1628         )
  1587 
  1629 
  1588         b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
  1630         b = b"".join(
       
  1631             [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]
       
  1632         )
  1589         b2 = zstd.BufferWithSegments(
  1633         b2 = zstd.BufferWithSegments(
  1590             b,
  1634             b,
  1591             struct.pack(
  1635             struct.pack(
  1592                 "=QQQQQQ",
  1636                 "=QQQQQQ",
  1593                 0,
  1637                 0,