comparison contrib/python-zstandard/tests/test_decompressor.py @ 44147:5e84a96d865b
python-zstandard: blacken at 80 characters
I made this change upstream and it will make it into the next
release of python-zstandard. I figured I'd send it Mercurial's
way because it will allow us to drop this directory from the black
exclusion list.
# skip-blame blackening
Differential Revision: https://phab.mercurial-scm.org/D7937
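For context: black's default line length is 88 columns, while this change reformats the directory at Mercurial's 80-column limit, which is why so many calls in this file get re-wrapped. The snippet below is a minimal sketch of that behavior using black's Python API (`black.format_str` with an explicit `line_length`); the real change was produced by running black over the directory upstream, and the exact command and configuration are not part of this page.

```python
# Minimal sketch: show how black re-wraps an over-long call at 80 columns.
# Illustrative only; the actual change was made by running black over
# contrib/python-zstandard/ upstream, not via this API call.
import black

SRC = (
    "class T:\n"
    "    def check(self):\n"
    "        decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)\n"
)

# At black's default of 88 columns the line fits and is left alone;
# at 80 columns it is exploded across three lines, matching this diff.
print(black.format_str(SRC, mode=black.FileMode(line_length=88)))
print(black.format_str(SRC, mode=black.FileMode(line_length=80)))
```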
author | Gregory Szorc <gregory.szorc@gmail.com> |
date | Wed, 22 Jan 2020 22:23:04 -0800 |
parents | de7838053207 |
children | 493034cc3265 |
44146:45ec64d93b3a (before) | 44147:5e84a96d865b (after) |
---|---|
168 zstd.ZstdError, "decompression error: did not decompress full frame" | 168 zstd.ZstdError, "decompression error: did not decompress full frame" |
169 ): | 169 ): |
170 dctx.decompress(compressed, max_output_size=len(source) - 1) | 170 dctx.decompress(compressed, max_output_size=len(source) - 1) |
171 | 171 |
172 # Input size + 1 works | 172 # Input size + 1 works |
173 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) | 173 decompressed = dctx.decompress( |
| 174 compressed, max_output_size=len(source) + 1 |
| 175 ) |
174 self.assertEqual(decompressed, source) | 176 self.assertEqual(decompressed, source) |
175 | 177 |
176 # A much larger buffer works. | 178 # A much larger buffer works. |
177 decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64) | 179 decompressed = dctx.decompress( |
| 180 compressed, max_output_size=len(source) * 64 |
| 181 ) |
178 self.assertEqual(decompressed, source) | 182 self.assertEqual(decompressed, source) |
179 | 183 |
180 def test_stupidly_large_output_buffer(self): | 184 def test_stupidly_large_output_buffer(self): |
181 cctx = zstd.ZstdCompressor(write_content_size=False) | 185 cctx = zstd.ZstdCompressor(write_content_size=False) |
182 compressed = cctx.compress(b"foobar" * 256) | 186 compressed = cctx.compress(b"foobar" * 256) |
235 frame = cctx.compress(source) | 239 frame = cctx.compress(source) |
236 | 240 |
237 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | 241 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) |
238 | 242 |
239 with self.assertRaisesRegex( | 243 with self.assertRaisesRegex( |
240 zstd.ZstdError, "decompression error: Frame requires too much memory" | 244 zstd.ZstdError, |
245 "decompression error: Frame requires too much memory", | |
241 ): | 246 ): |
242 dctx.decompress(frame, max_output_size=len(source)) | 247 dctx.decompress(frame, max_output_size=len(source)) |
243 | 248 |
244 | 249 |
245 @make_cffi | 250 @make_cffi |
289 | 294 |
290 self.assertEqual(r, len(compressed.getvalue())) | 295 self.assertEqual(r, len(compressed.getvalue())) |
291 self.assertEqual(w, len(source.getvalue())) | 296 self.assertEqual(w, len(source.getvalue())) |
292 | 297 |
293 def test_read_write_size(self): | 298 def test_read_write_size(self): |
294 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | 299 source = OpCountingBytesIO( |
| 300 zstd.ZstdCompressor().compress(b"foobarfoobar") |
| 301 ) |
295 | 302 |
296 dest = OpCountingBytesIO() | 303 dest = OpCountingBytesIO() |
297 dctx = zstd.ZstdDecompressor() | 304 dctx = zstd.ZstdDecompressor() |
298 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | 305 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) |
299 | 306 |
307 class TestDecompressor_stream_reader(TestCase): | 314 class TestDecompressor_stream_reader(TestCase): |
308 def test_context_manager(self): | 315 def test_context_manager(self): |
309 dctx = zstd.ZstdDecompressor() | 316 dctx = zstd.ZstdDecompressor() |
310 | 317 |
311 with dctx.stream_reader(b"foo") as reader: | 318 with dctx.stream_reader(b"foo") as reader: |
312 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): | 319 with self.assertRaisesRegex( |
320 ValueError, "cannot __enter__ multiple times" | |
321 ): | |
313 with reader as reader2: | 322 with reader as reader2: |
314 pass | 323 pass |
315 | 324 |
316 def test_not_implemented(self): | 325 def test_not_implemented(self): |
317 dctx = zstd.ZstdDecompressor() | 326 dctx = zstd.ZstdDecompressor() |
472 frame = cctx.compress(b"foo" * 60) | 481 frame = cctx.compress(b"foo" * 60) |
473 | 482 |
474 dctx = zstd.ZstdDecompressor() | 483 dctx = zstd.ZstdDecompressor() |
475 | 484 |
476 with dctx.stream_reader(frame) as reader: | 485 with dctx.stream_reader(frame) as reader: |
477 with self.assertRaisesRegex(ValueError, "cannot seek to negative position"): | 486 with self.assertRaisesRegex( |
487 ValueError, "cannot seek to negative position" | |
488 ): | |
478 reader.seek(-1, os.SEEK_SET) | 489 reader.seek(-1, os.SEEK_SET) |
479 | 490 |
480 reader.read(1) | 491 reader.read(1) |
481 | 492 |
482 with self.assertRaisesRegex( | 493 with self.assertRaisesRegex( |
488 ValueError, "cannot seek zstd decompression stream backwards" | 499 ValueError, "cannot seek zstd decompression stream backwards" |
489 ): | 500 ): |
490 reader.seek(-1, os.SEEK_CUR) | 501 reader.seek(-1, os.SEEK_CUR) |
491 | 502 |
492 with self.assertRaisesRegex( | 503 with self.assertRaisesRegex( |
493 ValueError, "zstd decompression streams cannot be seeked with SEEK_END" | 504 ValueError, |
505 "zstd decompression streams cannot be seeked with SEEK_END", | |
494 ): | 506 ): |
495 reader.seek(0, os.SEEK_END) | 507 reader.seek(0, os.SEEK_END) |
496 | 508 |
497 reader.close() | 509 reader.close() |
498 | 510 |
741 self.assertEqual(reader.read1(1), b"") | 753 self.assertEqual(reader.read1(1), b"") |
742 self.assertEqual(b._read_count, 2) | 754 self.assertEqual(b._read_count, 2) |
743 | 755 |
744 def test_read_lines(self): | 756 def test_read_lines(self): |
745 cctx = zstd.ZstdCompressor() | 757 cctx = zstd.ZstdCompressor() |
746 source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024)) | 758 source = b"\n".join( |
759 ("line %d" % i).encode("ascii") for i in range(1024) | |
760 ) | |
747 | 761 |
748 frame = cctx.compress(source) | 762 frame = cctx.compress(source) |
749 | 763 |
750 dctx = zstd.ZstdDecompressor() | 764 dctx = zstd.ZstdDecompressor() |
751 reader = dctx.stream_reader(frame) | 765 reader = dctx.stream_reader(frame) |
819 | 833 |
820 dctx = zstd.ZstdDecompressor() | 834 dctx = zstd.ZstdDecompressor() |
821 dobj = dctx.decompressobj() | 835 dobj = dctx.decompressobj() |
822 dobj.decompress(data) | 836 dobj.decompress(data) |
823 | 837 |
824 with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"): | 838 with self.assertRaisesRegex( |
839 zstd.ZstdError, "cannot use a decompressobj" | |
840 ): | |
825 dobj.decompress(data) | 841 dobj.decompress(data) |
826 self.assertIsNone(dobj.flush()) | 842 self.assertIsNone(dobj.flush()) |
827 | 843 |
828 def test_bad_write_size(self): | 844 def test_bad_write_size(self): |
829 dctx = zstd.ZstdDecompressor() | 845 dctx = zstd.ZstdDecompressor() |
1122 dctx.read_to_iter(io.BytesIO()) | 1138 dctx.read_to_iter(io.BytesIO()) |
1123 | 1139 |
1124 # Buffer protocol works. | 1140 # Buffer protocol works. |
1125 dctx.read_to_iter(b"foobar") | 1141 dctx.read_to_iter(b"foobar") |
1126 | 1142 |
1127 with self.assertRaisesRegex(ValueError, "must pass an object with a read"): | 1143 with self.assertRaisesRegex( |
1144 ValueError, "must pass an object with a read" | |
1145 ): | |
1128 b"".join(dctx.read_to_iter(True)) | 1146 b"".join(dctx.read_to_iter(True)) |
1129 | 1147 |
1130 def test_empty_input(self): | 1148 def test_empty_input(self): |
1131 dctx = zstd.ZstdDecompressor() | 1149 dctx = zstd.ZstdDecompressor() |
1132 | 1150 |
1224 next(it) | 1242 next(it) |
1225 | 1243 |
1226 decompressed = b"".join(chunks) | 1244 decompressed = b"".join(chunks) |
1227 self.assertEqual(decompressed, source.getvalue()) | 1245 self.assertEqual(decompressed, source.getvalue()) |
1228 | 1246 |
1229 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | 1247 @unittest.skipUnless( |
1248 "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" | |
1249 ) | |
1230 def test_large_input(self): | 1250 def test_large_input(self): |
1231 bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | 1251 bytes = list(struct.Struct(">B").pack(i) for i in range(256)) |
1232 compressed = NonClosingBytesIO() | 1252 compressed = NonClosingBytesIO() |
1233 input_size = 0 | 1253 input_size = 0 |
1234 cctx = zstd.ZstdCompressor(level=1) | 1254 cctx = zstd.ZstdCompressor(level=1) |
1239 | 1259 |
1240 have_compressed = ( | 1260 have_compressed = ( |
1241 len(compressed.getvalue()) | 1261 len(compressed.getvalue()) |
1242 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | 1262 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
1243 ) | 1263 ) |
1244 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | 1264 have_raw = ( |
| 1265 input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 |
| 1266 ) |
1245 if have_compressed and have_raw: | 1267 if have_compressed and have_raw: |
1246 break | 1268 break |
1247 | 1269 |
1248 compressed = io.BytesIO(compressed.getvalue()) | 1270 compressed = io.BytesIO(compressed.getvalue()) |
1249 self.assertGreater( | 1271 self.assertGreater( |
1250 len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | 1272 len(compressed.getvalue()), |
| 1273 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, |
1251 ) | 1274 ) |
1252 | 1275 |
1253 dctx = zstd.ZstdDecompressor() | 1276 dctx = zstd.ZstdDecompressor() |
1254 it = dctx.read_to_iter(compressed) | 1277 it = dctx.read_to_iter(compressed) |
1255 | 1278 |
1301 compressed = io.BytesIO(compressed.getvalue()) | 1324 compressed = io.BytesIO(compressed.getvalue()) |
1302 streamed = b"".join(dctx.read_to_iter(compressed)) | 1325 streamed = b"".join(dctx.read_to_iter(compressed)) |
1303 self.assertEqual(streamed, source.getvalue()) | 1326 self.assertEqual(streamed, source.getvalue()) |
1304 | 1327 |
1305 def test_read_write_size(self): | 1328 def test_read_write_size(self): |
1306 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | 1329 source = OpCountingBytesIO( |
| 1330 zstd.ZstdCompressor().compress(b"foobarfoobar") |
| 1331 ) |
1307 dctx = zstd.ZstdDecompressor() | 1332 dctx = zstd.ZstdDecompressor() |
1308 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | 1333 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): |
1309 self.assertEqual(len(chunk), 1) | 1334 self.assertEqual(len(chunk), 1) |
1310 | 1335 |
1311 self.assertEqual(source._read_count, len(source.getvalue())) | 1336 self.assertEqual(source._read_count, len(source.getvalue())) |
1353 with self.assertRaisesRegex( | 1378 with self.assertRaisesRegex( |
1354 ValueError, "chunk 0 is too small to contain a zstd frame" | 1379 ValueError, "chunk 0 is too small to contain a zstd frame" |
1355 ): | 1380 ): |
1356 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | 1381 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) |
1357 | 1382 |
1358 with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"): | 1383 with self.assertRaisesRegex( |
1384 ValueError, "chunk 0 is not a valid zstd frame" | |
1385 ): | |
1359 dctx.decompress_content_dict_chain([b"foo" * 8]) | 1386 dctx.decompress_content_dict_chain([b"foo" * 8]) |
1360 | 1387 |
1361 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | 1388 no_size = zstd.ZstdCompressor(write_content_size=False).compress( |
1389 b"foo" * 64 | |
1390 ) | |
1362 | 1391 |
1363 with self.assertRaisesRegex( | 1392 with self.assertRaisesRegex( |
1364 ValueError, "chunk 0 missing content size in frame" | 1393 ValueError, "chunk 0 missing content size in frame" |
1365 ): | 1394 ): |
1366 dctx.decompress_content_dict_chain([no_size]) | 1395 dctx.decompress_content_dict_chain([no_size]) |
1387 with self.assertRaisesRegex( | 1416 with self.assertRaisesRegex( |
1388 ValueError, "chunk 1 is too small to contain a zstd frame" | 1417 ValueError, "chunk 1 is too small to contain a zstd frame" |
1389 ): | 1418 ): |
1390 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | 1419 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) |
1391 | 1420 |
1392 with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"): | 1421 with self.assertRaisesRegex( |
1422 ValueError, "chunk 1 is not a valid zstd frame" | |
1423 ): | |
1393 dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | 1424 dctx.decompress_content_dict_chain([initial, b"foo" * 8]) |
1394 | 1425 |
1395 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | 1426 no_size = zstd.ZstdCompressor(write_content_size=False).compress( |
1427 b"foo" * 64 | |
1428 ) | |
1396 | 1429 |
1397 with self.assertRaisesRegex( | 1430 with self.assertRaisesRegex( |
1398 ValueError, "chunk 1 missing content size in frame" | 1431 ValueError, "chunk 1 missing content size in frame" |
1399 ): | 1432 ): |
1400 dctx.decompress_content_dict_chain([initial, no_size]) | 1433 dctx.decompress_content_dict_chain([initial, no_size]) |
1401 | 1434 |
1402 # Corrupt second frame. | 1435 # Corrupt second frame. |
1403 cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64)) | 1436 cctx = zstd.ZstdCompressor( |
| 1437 dict_data=zstd.ZstdCompressionDict(b"foo" * 64) |
| 1438 ) |
1404 frame = cctx.compress(b"bar" * 64) | 1439 frame = cctx.compress(b"bar" * 64) |
1405 frame = frame[0:12] + frame[15:] | 1440 frame = frame[0:12] + frame[15:] |
1406 | 1441 |
1407 with self.assertRaisesRegex( | 1442 with self.assertRaisesRegex( |
1408 zstd.ZstdError, "chunk 1 did not decompress full frame" | 1443 zstd.ZstdError, "chunk 1 did not decompress full frame" |
1445 dctx.multi_decompress_to_buffer(True) | 1480 dctx.multi_decompress_to_buffer(True) |
1446 | 1481 |
1447 with self.assertRaises(TypeError): | 1482 with self.assertRaises(TypeError): |
1448 dctx.multi_decompress_to_buffer((1, 2)) | 1483 dctx.multi_decompress_to_buffer((1, 2)) |
1449 | 1484 |
1450 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): | 1485 with self.assertRaisesRegex( |
1486 TypeError, "item 0 not a bytes like object" | |
1487 ): | |
1451 dctx.multi_decompress_to_buffer([u"foo"]) | 1488 dctx.multi_decompress_to_buffer([u"foo"]) |
1452 | 1489 |
1453 with self.assertRaisesRegex( | 1490 with self.assertRaisesRegex( |
1454 ValueError, "could not determine decompressed size of item 0" | 1491 ValueError, "could not determine decompressed size of item 0" |
1455 ): | 1492 ): |
1489 dctx = zstd.ZstdDecompressor() | 1526 dctx = zstd.ZstdDecompressor() |
1490 | 1527 |
1491 if not hasattr(dctx, "multi_decompress_to_buffer"): | 1528 if not hasattr(dctx, "multi_decompress_to_buffer"): |
1492 self.skipTest("multi_decompress_to_buffer not available") | 1529 self.skipTest("multi_decompress_to_buffer not available") |
1493 | 1530 |
1494 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) | 1531 result = dctx.multi_decompress_to_buffer( |
| 1532 frames, decompressed_sizes=sizes |
| 1533 ) |
1495 | 1534 |
1496 self.assertEqual(len(result), len(frames)) | 1535 self.assertEqual(len(result), len(frames)) |
1497 self.assertEqual(result.size(), sum(map(len, original))) | 1536 self.assertEqual(result.size(), sum(map(len, original))) |
1498 | 1537 |
1499 for i, data in enumerate(original): | 1538 for i, data in enumerate(original): |
1580 self.assertEqual(data, decompressed[i].tobytes()) | 1619 self.assertEqual(data, decompressed[i].tobytes()) |
1581 | 1620 |
1582 # And a manual mode. | 1621 # And a manual mode. |
1583 b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | 1622 b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) |
1584 b1 = zstd.BufferWithSegments( | 1623 b1 = zstd.BufferWithSegments( |
1585 b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])) | 1624 b, |
| 1625 struct.pack( |
| 1626 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) |
| 1627 ), |
1586 ) | 1628 ) |
1587 | 1629 |
1588 b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]) | 1630 b = b"".join( |
| 1631 [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] |
| 1632 ) |
1589 b2 = zstd.BufferWithSegments( | 1633 b2 = zstd.BufferWithSegments( |
1590 b, | 1634 b, |
1591 struct.pack( | 1635 struct.pack( |
1592 "=QQQQQQ", | 1636 "=QQQQQQ", |
1593 0, | 1637 0, |