comparison contrib/python-zstandard/tests/test_decompressor.py @ 44147:5e84a96d865b

python-zstandard: blacken at 80 characters I made this change upstream and it will make it into the next release of python-zstandard. I figured I'd send it Mercurial's way because it will allow us to drop this directory from the black exclusion list. # skip-blame blackening Differential Revision: https://phab.mercurial-scm.org/D7937
author Gregory Szorc <gregory.szorc@gmail.com>
date Wed, 22 Jan 2020 22:23:04 -0800
parents de7838053207
children 493034cc3265
comparison
equal deleted inserted replaced
44146:45ec64d93b3a 44147:5e84a96d865b
168 zstd.ZstdError, "decompression error: did not decompress full frame" 168 zstd.ZstdError, "decompression error: did not decompress full frame"
169 ): 169 ):
170 dctx.decompress(compressed, max_output_size=len(source) - 1) 170 dctx.decompress(compressed, max_output_size=len(source) - 1)
171 171
172 # Input size + 1 works 172 # Input size + 1 works
173 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) 173 decompressed = dctx.decompress(
174 compressed, max_output_size=len(source) + 1
175 )
174 self.assertEqual(decompressed, source) 176 self.assertEqual(decompressed, source)
175 177
176 # A much larger buffer works. 178 # A much larger buffer works.
177 decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64) 179 decompressed = dctx.decompress(
180 compressed, max_output_size=len(source) * 64
181 )
178 self.assertEqual(decompressed, source) 182 self.assertEqual(decompressed, source)
179 183
180 def test_stupidly_large_output_buffer(self): 184 def test_stupidly_large_output_buffer(self):
181 cctx = zstd.ZstdCompressor(write_content_size=False) 185 cctx = zstd.ZstdCompressor(write_content_size=False)
182 compressed = cctx.compress(b"foobar" * 256) 186 compressed = cctx.compress(b"foobar" * 256)
235 frame = cctx.compress(source) 239 frame = cctx.compress(source)
236 240
237 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) 241 dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)
238 242
239 with self.assertRaisesRegex( 243 with self.assertRaisesRegex(
240 zstd.ZstdError, "decompression error: Frame requires too much memory" 244 zstd.ZstdError,
245 "decompression error: Frame requires too much memory",
241 ): 246 ):
242 dctx.decompress(frame, max_output_size=len(source)) 247 dctx.decompress(frame, max_output_size=len(source))
243 248
244 249
245 @make_cffi 250 @make_cffi
289 294
290 self.assertEqual(r, len(compressed.getvalue())) 295 self.assertEqual(r, len(compressed.getvalue()))
291 self.assertEqual(w, len(source.getvalue())) 296 self.assertEqual(w, len(source.getvalue()))
292 297
293 def test_read_write_size(self): 298 def test_read_write_size(self):
294 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) 299 source = OpCountingBytesIO(
300 zstd.ZstdCompressor().compress(b"foobarfoobar")
301 )
295 302
296 dest = OpCountingBytesIO() 303 dest = OpCountingBytesIO()
297 dctx = zstd.ZstdDecompressor() 304 dctx = zstd.ZstdDecompressor()
298 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) 305 r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
299 306
307 class TestDecompressor_stream_reader(TestCase): 314 class TestDecompressor_stream_reader(TestCase):
308 def test_context_manager(self): 315 def test_context_manager(self):
309 dctx = zstd.ZstdDecompressor() 316 dctx = zstd.ZstdDecompressor()
310 317
311 with dctx.stream_reader(b"foo") as reader: 318 with dctx.stream_reader(b"foo") as reader:
312 with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): 319 with self.assertRaisesRegex(
320 ValueError, "cannot __enter__ multiple times"
321 ):
313 with reader as reader2: 322 with reader as reader2:
314 pass 323 pass
315 324
316 def test_not_implemented(self): 325 def test_not_implemented(self):
317 dctx = zstd.ZstdDecompressor() 326 dctx = zstd.ZstdDecompressor()
472 frame = cctx.compress(b"foo" * 60) 481 frame = cctx.compress(b"foo" * 60)
473 482
474 dctx = zstd.ZstdDecompressor() 483 dctx = zstd.ZstdDecompressor()
475 484
476 with dctx.stream_reader(frame) as reader: 485 with dctx.stream_reader(frame) as reader:
477 with self.assertRaisesRegex(ValueError, "cannot seek to negative position"): 486 with self.assertRaisesRegex(
487 ValueError, "cannot seek to negative position"
488 ):
478 reader.seek(-1, os.SEEK_SET) 489 reader.seek(-1, os.SEEK_SET)
479 490
480 reader.read(1) 491 reader.read(1)
481 492
482 with self.assertRaisesRegex( 493 with self.assertRaisesRegex(
488 ValueError, "cannot seek zstd decompression stream backwards" 499 ValueError, "cannot seek zstd decompression stream backwards"
489 ): 500 ):
490 reader.seek(-1, os.SEEK_CUR) 501 reader.seek(-1, os.SEEK_CUR)
491 502
492 with self.assertRaisesRegex( 503 with self.assertRaisesRegex(
493 ValueError, "zstd decompression streams cannot be seeked with SEEK_END" 504 ValueError,
505 "zstd decompression streams cannot be seeked with SEEK_END",
494 ): 506 ):
495 reader.seek(0, os.SEEK_END) 507 reader.seek(0, os.SEEK_END)
496 508
497 reader.close() 509 reader.close()
498 510
741 self.assertEqual(reader.read1(1), b"") 753 self.assertEqual(reader.read1(1), b"")
742 self.assertEqual(b._read_count, 2) 754 self.assertEqual(b._read_count, 2)
743 755
744 def test_read_lines(self): 756 def test_read_lines(self):
745 cctx = zstd.ZstdCompressor() 757 cctx = zstd.ZstdCompressor()
746 source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024)) 758 source = b"\n".join(
759 ("line %d" % i).encode("ascii") for i in range(1024)
760 )
747 761
748 frame = cctx.compress(source) 762 frame = cctx.compress(source)
749 763
750 dctx = zstd.ZstdDecompressor() 764 dctx = zstd.ZstdDecompressor()
751 reader = dctx.stream_reader(frame) 765 reader = dctx.stream_reader(frame)
819 833
820 dctx = zstd.ZstdDecompressor() 834 dctx = zstd.ZstdDecompressor()
821 dobj = dctx.decompressobj() 835 dobj = dctx.decompressobj()
822 dobj.decompress(data) 836 dobj.decompress(data)
823 837
824 with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"): 838 with self.assertRaisesRegex(
839 zstd.ZstdError, "cannot use a decompressobj"
840 ):
825 dobj.decompress(data) 841 dobj.decompress(data)
826 self.assertIsNone(dobj.flush()) 842 self.assertIsNone(dobj.flush())
827 843
828 def test_bad_write_size(self): 844 def test_bad_write_size(self):
829 dctx = zstd.ZstdDecompressor() 845 dctx = zstd.ZstdDecompressor()
1122 dctx.read_to_iter(io.BytesIO()) 1138 dctx.read_to_iter(io.BytesIO())
1123 1139
1124 # Buffer protocol works. 1140 # Buffer protocol works.
1125 dctx.read_to_iter(b"foobar") 1141 dctx.read_to_iter(b"foobar")
1126 1142
1127 with self.assertRaisesRegex(ValueError, "must pass an object with a read"): 1143 with self.assertRaisesRegex(
1144 ValueError, "must pass an object with a read"
1145 ):
1128 b"".join(dctx.read_to_iter(True)) 1146 b"".join(dctx.read_to_iter(True))
1129 1147
1130 def test_empty_input(self): 1148 def test_empty_input(self):
1131 dctx = zstd.ZstdDecompressor() 1149 dctx = zstd.ZstdDecompressor()
1132 1150
1224 next(it) 1242 next(it)
1225 1243
1226 decompressed = b"".join(chunks) 1244 decompressed = b"".join(chunks)
1227 self.assertEqual(decompressed, source.getvalue()) 1245 self.assertEqual(decompressed, source.getvalue())
1228 1246
1229 @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") 1247 @unittest.skipUnless(
1248 "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set"
1249 )
1230 def test_large_input(self): 1250 def test_large_input(self):
1231 bytes = list(struct.Struct(">B").pack(i) for i in range(256)) 1251 bytes = list(struct.Struct(">B").pack(i) for i in range(256))
1232 compressed = NonClosingBytesIO() 1252 compressed = NonClosingBytesIO()
1233 input_size = 0 1253 input_size = 0
1234 cctx = zstd.ZstdCompressor(level=1) 1254 cctx = zstd.ZstdCompressor(level=1)
1239 1259
1240 have_compressed = ( 1260 have_compressed = (
1241 len(compressed.getvalue()) 1261 len(compressed.getvalue())
1242 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE 1262 > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
1243 ) 1263 )
1244 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 1264 have_raw = (
1265 input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
1266 )
1245 if have_compressed and have_raw: 1267 if have_compressed and have_raw:
1246 break 1268 break
1247 1269
1248 compressed = io.BytesIO(compressed.getvalue()) 1270 compressed = io.BytesIO(compressed.getvalue())
1249 self.assertGreater( 1271 self.assertGreater(
1250 len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE 1272 len(compressed.getvalue()),
1273 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
1251 ) 1274 )
1252 1275
1253 dctx = zstd.ZstdDecompressor() 1276 dctx = zstd.ZstdDecompressor()
1254 it = dctx.read_to_iter(compressed) 1277 it = dctx.read_to_iter(compressed)
1255 1278
1301 compressed = io.BytesIO(compressed.getvalue()) 1324 compressed = io.BytesIO(compressed.getvalue())
1302 streamed = b"".join(dctx.read_to_iter(compressed)) 1325 streamed = b"".join(dctx.read_to_iter(compressed))
1303 self.assertEqual(streamed, source.getvalue()) 1326 self.assertEqual(streamed, source.getvalue())
1304 1327
1305 def test_read_write_size(self): 1328 def test_read_write_size(self):
1306 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) 1329 source = OpCountingBytesIO(
1330 zstd.ZstdCompressor().compress(b"foobarfoobar")
1331 )
1307 dctx = zstd.ZstdDecompressor() 1332 dctx = zstd.ZstdDecompressor()
1308 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): 1333 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1):
1309 self.assertEqual(len(chunk), 1) 1334 self.assertEqual(len(chunk), 1)
1310 1335
1311 self.assertEqual(source._read_count, len(source.getvalue())) 1336 self.assertEqual(source._read_count, len(source.getvalue()))
1353 with self.assertRaisesRegex( 1378 with self.assertRaisesRegex(
1354 ValueError, "chunk 0 is too small to contain a zstd frame" 1379 ValueError, "chunk 0 is too small to contain a zstd frame"
1355 ): 1380 ):
1356 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) 1381 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
1357 1382
1358 with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"): 1383 with self.assertRaisesRegex(
1384 ValueError, "chunk 0 is not a valid zstd frame"
1385 ):
1359 dctx.decompress_content_dict_chain([b"foo" * 8]) 1386 dctx.decompress_content_dict_chain([b"foo" * 8])
1360 1387
1361 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) 1388 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1389 b"foo" * 64
1390 )
1362 1391
1363 with self.assertRaisesRegex( 1392 with self.assertRaisesRegex(
1364 ValueError, "chunk 0 missing content size in frame" 1393 ValueError, "chunk 0 missing content size in frame"
1365 ): 1394 ):
1366 dctx.decompress_content_dict_chain([no_size]) 1395 dctx.decompress_content_dict_chain([no_size])
1387 with self.assertRaisesRegex( 1416 with self.assertRaisesRegex(
1388 ValueError, "chunk 1 is too small to contain a zstd frame" 1417 ValueError, "chunk 1 is too small to contain a zstd frame"
1389 ): 1418 ):
1390 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) 1419 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
1391 1420
1392 with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"): 1421 with self.assertRaisesRegex(
1422 ValueError, "chunk 1 is not a valid zstd frame"
1423 ):
1393 dctx.decompress_content_dict_chain([initial, b"foo" * 8]) 1424 dctx.decompress_content_dict_chain([initial, b"foo" * 8])
1394 1425
1395 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) 1426 no_size = zstd.ZstdCompressor(write_content_size=False).compress(
1427 b"foo" * 64
1428 )
1396 1429
1397 with self.assertRaisesRegex( 1430 with self.assertRaisesRegex(
1398 ValueError, "chunk 1 missing content size in frame" 1431 ValueError, "chunk 1 missing content size in frame"
1399 ): 1432 ):
1400 dctx.decompress_content_dict_chain([initial, no_size]) 1433 dctx.decompress_content_dict_chain([initial, no_size])
1401 1434
1402 # Corrupt second frame. 1435 # Corrupt second frame.
1403 cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64)) 1436 cctx = zstd.ZstdCompressor(
1437 dict_data=zstd.ZstdCompressionDict(b"foo" * 64)
1438 )
1404 frame = cctx.compress(b"bar" * 64) 1439 frame = cctx.compress(b"bar" * 64)
1405 frame = frame[0:12] + frame[15:] 1440 frame = frame[0:12] + frame[15:]
1406 1441
1407 with self.assertRaisesRegex( 1442 with self.assertRaisesRegex(
1408 zstd.ZstdError, "chunk 1 did not decompress full frame" 1443 zstd.ZstdError, "chunk 1 did not decompress full frame"
1445 dctx.multi_decompress_to_buffer(True) 1480 dctx.multi_decompress_to_buffer(True)
1446 1481
1447 with self.assertRaises(TypeError): 1482 with self.assertRaises(TypeError):
1448 dctx.multi_decompress_to_buffer((1, 2)) 1483 dctx.multi_decompress_to_buffer((1, 2))
1449 1484
1450 with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): 1485 with self.assertRaisesRegex(
1486 TypeError, "item 0 not a bytes like object"
1487 ):
1451 dctx.multi_decompress_to_buffer([u"foo"]) 1488 dctx.multi_decompress_to_buffer([u"foo"])
1452 1489
1453 with self.assertRaisesRegex( 1490 with self.assertRaisesRegex(
1454 ValueError, "could not determine decompressed size of item 0" 1491 ValueError, "could not determine decompressed size of item 0"
1455 ): 1492 ):
1489 dctx = zstd.ZstdDecompressor() 1526 dctx = zstd.ZstdDecompressor()
1490 1527
1491 if not hasattr(dctx, "multi_decompress_to_buffer"): 1528 if not hasattr(dctx, "multi_decompress_to_buffer"):
1492 self.skipTest("multi_decompress_to_buffer not available") 1529 self.skipTest("multi_decompress_to_buffer not available")
1493 1530
1494 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) 1531 result = dctx.multi_decompress_to_buffer(
1532 frames, decompressed_sizes=sizes
1533 )
1495 1534
1496 self.assertEqual(len(result), len(frames)) 1535 self.assertEqual(len(result), len(frames))
1497 self.assertEqual(result.size(), sum(map(len, original))) 1536 self.assertEqual(result.size(), sum(map(len, original)))
1498 1537
1499 for i, data in enumerate(original): 1538 for i, data in enumerate(original):
1580 self.assertEqual(data, decompressed[i].tobytes()) 1619 self.assertEqual(data, decompressed[i].tobytes())
1581 1620
1582 # And a manual mode. 1621 # And a manual mode.
1583 b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) 1622 b = b"".join([frames[0].tobytes(), frames[1].tobytes()])
1584 b1 = zstd.BufferWithSegments( 1623 b1 = zstd.BufferWithSegments(
1585 b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])) 1624 b,
1625 struct.pack(
1626 "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])
1627 ),
1586 ) 1628 )
1587 1629
1588 b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]) 1630 b = b"".join(
1631 [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]
1632 )
1589 b2 = zstd.BufferWithSegments( 1633 b2 = zstd.BufferWithSegments(
1590 b, 1634 b,
1591 struct.pack( 1635 struct.pack(
1592 "=QQQQQQ", 1636 "=QQQQQQ",
1593 0, 1637 0,