@@ -168,15 +168,19 @@
             zstd.ZstdError, "decompression error: did not decompress full frame"
         ):
             dctx.decompress(compressed, max_output_size=len(source) - 1)
 
         # Input size + 1 works
-        decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
+        decompressed = dctx.decompress(
+            compressed, max_output_size=len(source) + 1
+        )
         self.assertEqual(decompressed, source)
 
         # A much larger buffer works.
-        decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
+        decompressed = dctx.decompress(
+            compressed, max_output_size=len(source) * 64
+        )
         self.assertEqual(decompressed, source)
 
     def test_stupidly_large_output_buffer(self):
         cctx = zstd.ZstdCompressor(write_content_size=False)
         compressed = cctx.compress(b"foobar" * 256)
@@ -235,11 +239,12 @@
         frame = cctx.compress(source)
 
         dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
 
         with self.assertRaisesRegex(
-            zstd.ZstdError, "decompression error: Frame requires too much memory"
+            zstd.ZstdError,
+            "decompression error: Frame requires too much memory",
         ):
             dctx.decompress(frame, max_output_size=len(source))
 
 
 @make_cffi
@@ -289,11 +294,13 @@
 
         self.assertEqual(r, len(compressed.getvalue()))
         self.assertEqual(w, len(source.getvalue()))
 
     def test_read_write_size(self):
-        source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
+        source = OpCountingBytesIO(
+            zstd.ZstdCompressor().compress(b"foobarfoobar")
+        )
 
         dest = OpCountingBytesIO()
         dctx = zstd.ZstdDecompressor()
         r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
 
@@ -307,11 +314,13 @@
 class TestDecompressor_stream_reader(TestCase):
     def test_context_manager(self):
         dctx = zstd.ZstdDecompressor()
 
         with dctx.stream_reader(b"foo") as reader:
-            with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
+            with self.assertRaisesRegex(
+                ValueError, "cannot __enter__ multiple times"
+            ):
                 with reader as reader2:
                     pass
 
     def test_not_implemented(self):
         dctx = zstd.ZstdDecompressor()
@@ -472,11 +481,13 @@
         frame = cctx.compress(b"foo" * 60)
 
         dctx = zstd.ZstdDecompressor()
 
         with dctx.stream_reader(frame) as reader:
-            with self.assertRaisesRegex(ValueError, "cannot seek to negative position"):
+            with self.assertRaisesRegex(
+                ValueError, "cannot seek to negative position"
+            ):
                 reader.seek(-1, os.SEEK_SET)
 
             reader.read(1)
 
             with self.assertRaisesRegex(
@@ -488,11 +499,12 @@
                 ValueError, "cannot seek zstd decompression stream backwards"
             ):
                 reader.seek(-1, os.SEEK_CUR)
 
             with self.assertRaisesRegex(
-                ValueError, "zstd decompression streams cannot be seeked with SEEK_END"
+                ValueError,
+                "zstd decompression streams cannot be seeked with SEEK_END",
             ):
                 reader.seek(0, os.SEEK_END)
 
             reader.close()
 
@@ -741,11 +753,13 @@
         self.assertEqual(reader.read1(1), b"")
         self.assertEqual(b._read_count, 2)
 
     def test_read_lines(self):
         cctx = zstd.ZstdCompressor()
-        source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024))
+        source = b"\n".join(
+            ("line %d" % i).encode("ascii") for i in range(1024)
+        )
 
         frame = cctx.compress(source)
 
         dctx = zstd.ZstdDecompressor()
         reader = dctx.stream_reader(frame)
@@ -819,11 +833,13 @@
 
         dctx = zstd.ZstdDecompressor()
         dobj = dctx.decompressobj()
         dobj.decompress(data)
 
-        with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"):
+        with self.assertRaisesRegex(
+            zstd.ZstdError, "cannot use a decompressobj"
+        ):
             dobj.decompress(data)
             self.assertIsNone(dobj.flush())
 
     def test_bad_write_size(self):
         dctx = zstd.ZstdDecompressor()
@@ -1122,11 +1138,13 @@
         dctx.read_to_iter(io.BytesIO())
 
         # Buffer protocol works.
         dctx.read_to_iter(b"foobar")
 
-        with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
+        with self.assertRaisesRegex(
+            ValueError, "must pass an object with a read"
+        ):
             b"".join(dctx.read_to_iter(True))
 
     def test_empty_input(self):
         dctx = zstd.ZstdDecompressor()
 
@@ -1224,11 +1242,13 @@
             next(it)
 
         decompressed = b"".join(chunks)
         self.assertEqual(decompressed, source.getvalue())
 
-    @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
+    @unittest.skipUnless(
+        "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set"
+    )
     def test_large_input(self):
         bytes = list(struct.Struct(">B").pack(i) for i in range(256))
         compressed = NonClosingBytesIO()
         input_size = 0
         cctx = zstd.ZstdCompressor(level=1)
@@ -1239,17 +1259,20 @@
 
             have_compressed = (
                 len(compressed.getvalue())
                 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
             )
-            have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
+            have_raw = (
+                input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
+            )
             if have_compressed and have_raw:
                 break
 
         compressed = io.BytesIO(compressed.getvalue())
         self.assertGreater(
-            len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
+            len(compressed.getvalue()),
+            zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
         )
 
         dctx = zstd.ZstdDecompressor()
         it = dctx.read_to_iter(compressed)
 
@@ -1301,11 +1324,13 @@
         compressed = io.BytesIO(compressed.getvalue())
         streamed = b"".join(dctx.read_to_iter(compressed))
         self.assertEqual(streamed, source.getvalue())
 
     def test_read_write_size(self):
-        source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar"))
+        source = OpCountingBytesIO(
+            zstd.ZstdCompressor().compress(b"foobarfoobar")
+        )
         dctx = zstd.ZstdDecompressor()
         for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
             self.assertEqual(len(chunk), 1)
 
         self.assertEqual(source._read_count, len(source.getvalue()))
@@ -1353,14 +1378,18 @@
         with self.assertRaisesRegex(
             ValueError, "chunk 0 is too small to contain a zstd frame"
         ):
             dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
 
-        with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"):
+        with self.assertRaisesRegex(
+            ValueError, "chunk 0 is not a valid zstd frame"
+        ):
             dctx.decompress_content_dict_chain([b"foo" * 8])
 
-        no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
+        no_size = zstd.ZstdCompressor(write_content_size=False).compress(
+            b"foo" * 64
+        )
 
         with self.assertRaisesRegex(
             ValueError, "chunk 0 missing content size in frame"
         ):
             dctx.decompress_content_dict_chain([no_size])
@@ -1387,22 +1416,28 @@
         with self.assertRaisesRegex(
             ValueError, "chunk 1 is too small to contain a zstd frame"
         ):
             dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
 
-        with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"):
+        with self.assertRaisesRegex(
+            ValueError, "chunk 1 is not a valid zstd frame"
+        ):
             dctx.decompress_content_dict_chain([initial, b"foo" * 8])
 
-        no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64)
+        no_size = zstd.ZstdCompressor(write_content_size=False).compress(
+            b"foo" * 64
+        )
 
         with self.assertRaisesRegex(
             ValueError, "chunk 1 missing content size in frame"
         ):
             dctx.decompress_content_dict_chain([initial, no_size])
 
         # Corrupt second frame.
-        cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64))
+        cctx = zstd.ZstdCompressor(
+            dict_data=zstd.ZstdCompressionDict(b"foo" * 64)
+        )
         frame = cctx.compress(b"bar" * 64)
         frame = frame[0:12] + frame[15:]
 
         with self.assertRaisesRegex(
             zstd.ZstdError, "chunk 1 did not decompress full frame"
@@ -1445,11 +1480,13 @@
             dctx.multi_decompress_to_buffer(True)
 
         with self.assertRaises(TypeError):
             dctx.multi_decompress_to_buffer((1, 2))
 
-        with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
+        with self.assertRaisesRegex(
+            TypeError, "item 0 not a bytes like object"
+        ):
             dctx.multi_decompress_to_buffer([u"foo"])
 
         with self.assertRaisesRegex(
             ValueError, "could not determine decompressed size of item 0"
         ):
@@ -1489,11 +1526,13 @@
         dctx = zstd.ZstdDecompressor()
 
         if not hasattr(dctx, "multi_decompress_to_buffer"):
            self.skipTest("multi_decompress_to_buffer not available")
 
-        result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
+        result = dctx.multi_decompress_to_buffer(
+            frames, decompressed_sizes=sizes
+        )
 
         self.assertEqual(len(result), len(frames))
         self.assertEqual(result.size(), sum(map(len, original)))
 
         for i, data in enumerate(original):
@@ -1580,14 +1619,19 @@
             self.assertEqual(data, decompressed[i].tobytes())
 
         # And a manual mode.
         b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
         b1 = zstd.BufferWithSegments(
-            b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]))
+            b,
+            struct.pack(
+                "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
+            ),
         )
 
-        b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
+        b = b"".join(
+            [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]
+        )
         b2 = zstd.BufferWithSegments(
             b,
             struct.pack(
                 "=QQQQQQ",
                 0,