Mercurial > public > mercurial-scm > hg
comparison contrib/python-zstandard/tests/test_decompressor.py @ 42070:675775c33ab6
zstandard: vendor python-zstandard 0.11
The upstream source distribution from PyPI was extracted. Unwanted
files were removed.
The clang-format ignore list was updated to reflect the new set
of source files.
The project contains a vendored copy of zstandard 1.3.8. The old
version was 1.3.6. This should result in some minor performance wins.
test-check-py3-compat.t was updated to reflect now-passing tests on
Python 3.8.
Some HTTP tests were updated to reflect new zstd compression output.
# no-check-commit because 3rd party code has different style guidelines
Differential Revision: https://phab.mercurial-scm.org/D6199
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Thu, 04 Apr 2019 17:34:43 -0700 |
parents | 73fef626dae3 |
children | de7838053207 |
Comparison legend: equal, deleted, inserted, replaced
42069:668eff08387f | 42070:675775c33ab6 |
---|---|
1 import io | 1 import io |
2 import os | 2 import os |
3 import random | 3 import random |
4 import struct | 4 import struct |
5 import sys | 5 import sys |
6 import tempfile | |
6 import unittest | 7 import unittest |
7 | 8 |
8 import zstandard as zstd | 9 import zstandard as zstd |
9 | 10 |
10 from .common import ( | 11 from .common import ( |
11 generate_samples, | 12 generate_samples, |
12 make_cffi, | 13 make_cffi, |
14 NonClosingBytesIO, | |
13 OpCountingBytesIO, | 15 OpCountingBytesIO, |
14 ) | 16 ) |
15 | 17 |
16 | 18 |
17 if sys.version_info[0] >= 3: | 19 if sys.version_info[0] >= 3: |
217 # If we write a content size, the decompressor engages single pass | 219 # If we write a content size, the decompressor engages single pass |
218 # mode and the window size doesn't come into play. | 220 # mode and the window size doesn't come into play. |
219 cctx = zstd.ZstdCompressor(write_content_size=False) | 221 cctx = zstd.ZstdCompressor(write_content_size=False) |
220 frame = cctx.compress(source) | 222 frame = cctx.compress(source) |
221 | 223 |
222 dctx = zstd.ZstdDecompressor(max_window_size=1) | 224 dctx = zstd.ZstdDecompressor(max_window_size=2**zstd.WINDOWLOG_MIN) |
223 | 225 |
224 with self.assertRaisesRegexp( | 226 with self.assertRaisesRegexp( |
225 zstd.ZstdError, 'decompression error: Frame requires too much memory'): | 227 zstd.ZstdError, 'decompression error: Frame requires too much memory'): |
226 dctx.decompress(frame, max_output_size=len(source)) | 228 dctx.decompress(frame, max_output_size=len(source)) |
227 | 229 |
300 | 302 |
301 def test_not_implemented(self): | 303 def test_not_implemented(self): |
302 dctx = zstd.ZstdDecompressor() | 304 dctx = zstd.ZstdDecompressor() |
303 | 305 |
304 with dctx.stream_reader(b'foo') as reader: | 306 with dctx.stream_reader(b'foo') as reader: |
305 with self.assertRaises(NotImplementedError): | 307 with self.assertRaises(io.UnsupportedOperation): |
306 reader.readline() | 308 reader.readline() |
307 | 309 |
308 with self.assertRaises(NotImplementedError): | 310 with self.assertRaises(io.UnsupportedOperation): |
309 reader.readlines() | 311 reader.readlines() |
310 | 312 |
311 with self.assertRaises(NotImplementedError): | 313 with self.assertRaises(io.UnsupportedOperation): |
312 reader.readall() | |
313 | |
314 with self.assertRaises(NotImplementedError): | |
315 iter(reader) | 314 iter(reader) |
316 | 315 |
317 with self.assertRaises(NotImplementedError): | 316 with self.assertRaises(io.UnsupportedOperation): |
318 next(reader) | 317 next(reader) |
319 | 318 |
320 with self.assertRaises(io.UnsupportedOperation): | 319 with self.assertRaises(io.UnsupportedOperation): |
321 reader.write(b'foo') | 320 reader.write(b'foo') |
322 | 321 |
345 reader.close() | 344 reader.close() |
346 self.assertTrue(reader.closed) | 345 self.assertTrue(reader.closed) |
347 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | 346 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
348 reader.read(1) | 347 reader.read(1) |
349 | 348 |
350 def test_bad_read_size(self): | 349 def test_read_sizes(self): |
351 dctx = zstd.ZstdDecompressor() | 350 cctx = zstd.ZstdCompressor() |
352 | 351 foo = cctx.compress(b'foo') |
353 with dctx.stream_reader(b'foo') as reader: | 352 |
354 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): | 353 dctx = zstd.ZstdDecompressor() |
355 reader.read(-1) | 354 |
356 | 355 with dctx.stream_reader(foo) as reader: |
357 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): | 356 with self.assertRaisesRegexp(ValueError, 'cannot read negative amounts less than -1'): |
358 reader.read(0) | 357 reader.read(-2) |
358 | |
359 self.assertEqual(reader.read(0), b'') | |
360 self.assertEqual(reader.read(), b'foo') | |
359 | 361 |
360 def test_read_buffer(self): | 362 def test_read_buffer(self): |
361 cctx = zstd.ZstdCompressor() | 363 cctx = zstd.ZstdCompressor() |
362 | 364 |
363 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) | 365 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) |
522 dctx = zstd.ZstdDecompressor() | 524 dctx = zstd.ZstdDecompressor() |
523 | 525 |
524 reader = dctx.stream_reader(source) | 526 reader = dctx.stream_reader(source) |
525 | 527 |
526 with reader: | 528 with reader: |
527 with self.assertRaises(TypeError): | 529 reader.read(0) |
528 reader.read() | |
529 | 530 |
530 with reader: | 531 with reader: |
531 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | 532 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
532 reader.read(100) | 533 reader.read(100) |
534 | |
535 def test_partial_read(self): | |
536 # Inspired by https://github.com/indygreg/python-zstandard/issues/71. | |
537 buffer = io.BytesIO() | |
538 cctx = zstd.ZstdCompressor() | |
539 writer = cctx.stream_writer(buffer) | |
540 writer.write(bytearray(os.urandom(1000000))) | |
541 writer.flush(zstd.FLUSH_FRAME) | |
542 buffer.seek(0) | |
543 | |
544 dctx = zstd.ZstdDecompressor() | |
545 reader = dctx.stream_reader(buffer) | |
546 | |
547 while True: | |
548 chunk = reader.read(8192) | |
549 if not chunk: | |
550 break | |
551 | |
552 def test_read_multiple_frames(self): | |
553 cctx = zstd.ZstdCompressor() | |
554 source = io.BytesIO() | |
555 writer = cctx.stream_writer(source) | |
556 writer.write(b'foo') | |
557 writer.flush(zstd.FLUSH_FRAME) | |
558 writer.write(b'bar') | |
559 writer.flush(zstd.FLUSH_FRAME) | |
560 | |
561 dctx = zstd.ZstdDecompressor() | |
562 | |
563 reader = dctx.stream_reader(source.getvalue()) | |
564 self.assertEqual(reader.read(2), b'fo') | |
565 self.assertEqual(reader.read(2), b'o') | |
566 self.assertEqual(reader.read(2), b'ba') | |
567 self.assertEqual(reader.read(2), b'r') | |
568 | |
569 source.seek(0) | |
570 reader = dctx.stream_reader(source) | |
571 self.assertEqual(reader.read(2), b'fo') | |
572 self.assertEqual(reader.read(2), b'o') | |
573 self.assertEqual(reader.read(2), b'ba') | |
574 self.assertEqual(reader.read(2), b'r') | |
575 | |
576 reader = dctx.stream_reader(source.getvalue()) | |
577 self.assertEqual(reader.read(3), b'foo') | |
578 self.assertEqual(reader.read(3), b'bar') | |
579 | |
580 source.seek(0) | |
581 reader = dctx.stream_reader(source) | |
582 self.assertEqual(reader.read(3), b'foo') | |
583 self.assertEqual(reader.read(3), b'bar') | |
584 | |
585 reader = dctx.stream_reader(source.getvalue()) | |
586 self.assertEqual(reader.read(4), b'foo') | |
587 self.assertEqual(reader.read(4), b'bar') | |
588 | |
589 source.seek(0) | |
590 reader = dctx.stream_reader(source) | |
591 self.assertEqual(reader.read(4), b'foo') | |
592 self.assertEqual(reader.read(4), b'bar') | |
593 | |
594 reader = dctx.stream_reader(source.getvalue()) | |
595 self.assertEqual(reader.read(128), b'foo') | |
596 self.assertEqual(reader.read(128), b'bar') | |
597 | |
598 source.seek(0) | |
599 reader = dctx.stream_reader(source) | |
600 self.assertEqual(reader.read(128), b'foo') | |
601 self.assertEqual(reader.read(128), b'bar') | |
602 | |
603 # Now tests for reads spanning frames. | |
604 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
605 self.assertEqual(reader.read(3), b'foo') | |
606 self.assertEqual(reader.read(3), b'bar') | |
607 | |
608 source.seek(0) | |
609 reader = dctx.stream_reader(source, read_across_frames=True) | |
610 self.assertEqual(reader.read(3), b'foo') | |
611 self.assertEqual(reader.read(3), b'bar') | |
612 | |
613 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
614 self.assertEqual(reader.read(6), b'foobar') | |
615 | |
616 source.seek(0) | |
617 reader = dctx.stream_reader(source, read_across_frames=True) | |
618 self.assertEqual(reader.read(6), b'foobar') | |
619 | |
620 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
621 self.assertEqual(reader.read(7), b'foobar') | |
622 | |
623 source.seek(0) | |
624 reader = dctx.stream_reader(source, read_across_frames=True) | |
625 self.assertEqual(reader.read(7), b'foobar') | |
626 | |
627 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) | |
628 self.assertEqual(reader.read(128), b'foobar') | |
629 | |
630 source.seek(0) | |
631 reader = dctx.stream_reader(source, read_across_frames=True) | |
632 self.assertEqual(reader.read(128), b'foobar') | |
633 | |
634 def test_readinto(self): | |
635 cctx = zstd.ZstdCompressor() | |
636 foo = cctx.compress(b'foo') | |
637 | |
638 dctx = zstd.ZstdDecompressor() | |
639 | |
640 # Attempting to readinto() a non-writable buffer fails. | |
641 # The exact exception varies based on the backend. | |
642 reader = dctx.stream_reader(foo) | |
643 with self.assertRaises(Exception): | |
644 reader.readinto(b'foobar') | |
645 | |
646 # readinto() with sufficiently large destination. | |
647 b = bytearray(1024) | |
648 reader = dctx.stream_reader(foo) | |
649 self.assertEqual(reader.readinto(b), 3) | |
650 self.assertEqual(b[0:3], b'foo') | |
651 self.assertEqual(reader.readinto(b), 0) | |
652 self.assertEqual(b[0:3], b'foo') | |
653 | |
654 # readinto() with small reads. | |
655 b = bytearray(1024) | |
656 reader = dctx.stream_reader(foo, read_size=1) | |
657 self.assertEqual(reader.readinto(b), 3) | |
658 self.assertEqual(b[0:3], b'foo') | |
659 | |
660 # Too small destination buffer. | |
661 b = bytearray(2) | |
662 reader = dctx.stream_reader(foo) | |
663 self.assertEqual(reader.readinto(b), 2) | |
664 self.assertEqual(b[:], b'fo') | |
665 | |
666 def test_readinto1(self): | |
667 cctx = zstd.ZstdCompressor() | |
668 foo = cctx.compress(b'foo') | |
669 | |
670 dctx = zstd.ZstdDecompressor() | |
671 | |
672 reader = dctx.stream_reader(foo) | |
673 with self.assertRaises(Exception): | |
674 reader.readinto1(b'foobar') | |
675 | |
676 # Sufficiently large destination. | |
677 b = bytearray(1024) | |
678 reader = dctx.stream_reader(foo) | |
679 self.assertEqual(reader.readinto1(b), 3) | |
680 self.assertEqual(b[0:3], b'foo') | |
681 self.assertEqual(reader.readinto1(b), 0) | |
682 self.assertEqual(b[0:3], b'foo') | |
683 | |
684 # readinto() with small reads. | |
685 b = bytearray(1024) | |
686 reader = dctx.stream_reader(foo, read_size=1) | |
687 self.assertEqual(reader.readinto1(b), 3) | |
688 self.assertEqual(b[0:3], b'foo') | |
689 | |
690 # Too small destination buffer. | |
691 b = bytearray(2) | |
692 reader = dctx.stream_reader(foo) | |
693 self.assertEqual(reader.readinto1(b), 2) | |
694 self.assertEqual(b[:], b'fo') | |
695 | |
696 def test_readall(self): | |
697 cctx = zstd.ZstdCompressor() | |
698 foo = cctx.compress(b'foo') | |
699 | |
700 dctx = zstd.ZstdDecompressor() | |
701 reader = dctx.stream_reader(foo) | |
702 | |
703 self.assertEqual(reader.readall(), b'foo') | |
704 | |
705 def test_read1(self): | |
706 cctx = zstd.ZstdCompressor() | |
707 foo = cctx.compress(b'foo') | |
708 | |
709 dctx = zstd.ZstdDecompressor() | |
710 | |
711 b = OpCountingBytesIO(foo) | |
712 reader = dctx.stream_reader(b) | |
713 | |
714 self.assertEqual(reader.read1(), b'foo') | |
715 self.assertEqual(b._read_count, 1) | |
716 | |
717 b = OpCountingBytesIO(foo) | |
718 reader = dctx.stream_reader(b) | |
719 | |
720 self.assertEqual(reader.read1(0), b'') | |
721 self.assertEqual(reader.read1(2), b'fo') | |
722 self.assertEqual(b._read_count, 1) | |
723 self.assertEqual(reader.read1(1), b'o') | |
724 self.assertEqual(b._read_count, 1) | |
725 self.assertEqual(reader.read1(1), b'') | |
726 self.assertEqual(b._read_count, 2) | |
727 | |
728 def test_read_lines(self): | |
729 cctx = zstd.ZstdCompressor() | |
730 source = b'\n'.join(('line %d' % i).encode('ascii') for i in range(1024)) | |
731 | |
732 frame = cctx.compress(source) | |
733 | |
734 dctx = zstd.ZstdDecompressor() | |
735 reader = dctx.stream_reader(frame) | |
736 tr = io.TextIOWrapper(reader, encoding='utf-8') | |
737 | |
738 lines = [] | |
739 for line in tr: | |
740 lines.append(line.encode('utf-8')) | |
741 | |
742 self.assertEqual(len(lines), 1024) | |
743 self.assertEqual(b''.join(lines), source) | |
744 | |
745 reader = dctx.stream_reader(frame) | |
746 tr = io.TextIOWrapper(reader, encoding='utf-8') | |
747 | |
748 lines = tr.readlines() | |
749 self.assertEqual(len(lines), 1024) | |
750 self.assertEqual(''.join(lines).encode('utf-8'), source) | |
751 | |
752 reader = dctx.stream_reader(frame) | |
753 tr = io.TextIOWrapper(reader, encoding='utf-8') | |
754 | |
755 lines = [] | |
756 while True: | |
757 line = tr.readline() | |
758 if not line: | |
759 break | |
760 | |
761 lines.append(line.encode('utf-8')) | |
762 | |
763 self.assertEqual(len(lines), 1024) | |
764 self.assertEqual(b''.join(lines), source) | |
533 | 765 |
534 | 766 |
535 @make_cffi | 767 @make_cffi |
536 class TestDecompressor_decompressobj(unittest.TestCase): | 768 class TestDecompressor_decompressobj(unittest.TestCase): |
537 def test_simple(self): | 769 def test_simple(self): |
538 data = zstd.ZstdCompressor(level=1).compress(b'foobar') | 770 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
539 | 771 |
540 dctx = zstd.ZstdDecompressor() | 772 dctx = zstd.ZstdDecompressor() |
541 dobj = dctx.decompressobj() | 773 dobj = dctx.decompressobj() |
542 self.assertEqual(dobj.decompress(data), b'foobar') | 774 self.assertEqual(dobj.decompress(data), b'foobar') |
775 self.assertIsNone(dobj.flush()) | |
776 self.assertIsNone(dobj.flush(10)) | |
777 self.assertIsNone(dobj.flush(length=100)) | |
543 | 778 |
544 def test_input_types(self): | 779 def test_input_types(self): |
545 compressed = zstd.ZstdCompressor(level=1).compress(b'foo') | 780 compressed = zstd.ZstdCompressor(level=1).compress(b'foo') |
546 | 781 |
547 dctx = zstd.ZstdDecompressor() | 782 dctx = zstd.ZstdDecompressor() |
555 mutable_array, | 790 mutable_array, |
556 ] | 791 ] |
557 | 792 |
558 for source in sources: | 793 for source in sources: |
559 dobj = dctx.decompressobj() | 794 dobj = dctx.decompressobj() |
795 self.assertIsNone(dobj.flush()) | |
796 self.assertIsNone(dobj.flush(10)) | |
797 self.assertIsNone(dobj.flush(length=100)) | |
560 self.assertEqual(dobj.decompress(source), b'foo') | 798 self.assertEqual(dobj.decompress(source), b'foo') |
799 self.assertIsNone(dobj.flush()) | |
561 | 800 |
562 def test_reuse(self): | 801 def test_reuse(self): |
563 data = zstd.ZstdCompressor(level=1).compress(b'foobar') | 802 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
564 | 803 |
565 dctx = zstd.ZstdDecompressor() | 804 dctx = zstd.ZstdDecompressor() |
566 dobj = dctx.decompressobj() | 805 dobj = dctx.decompressobj() |
567 dobj.decompress(data) | 806 dobj.decompress(data) |
568 | 807 |
569 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): | 808 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): |
570 dobj.decompress(data) | 809 dobj.decompress(data) |
810 self.assertIsNone(dobj.flush()) | |
571 | 811 |
572 def test_bad_write_size(self): | 812 def test_bad_write_size(self): |
573 dctx = zstd.ZstdDecompressor() | 813 dctx = zstd.ZstdDecompressor() |
574 | 814 |
575 with self.assertRaisesRegexp(ValueError, 'write_size must be positive'): | 815 with self.assertRaisesRegexp(ValueError, 'write_size must be positive'): |
583 | 823 |
584 for i in range(128): | 824 for i in range(128): |
585 dobj = dctx.decompressobj(write_size=i + 1) | 825 dobj = dctx.decompressobj(write_size=i + 1) |
586 self.assertEqual(dobj.decompress(data), source) | 826 self.assertEqual(dobj.decompress(data), source) |
587 | 827 |
828 | |
588 def decompress_via_writer(data): | 829 def decompress_via_writer(data): |
589 buffer = io.BytesIO() | 830 buffer = io.BytesIO() |
590 dctx = zstd.ZstdDecompressor() | 831 dctx = zstd.ZstdDecompressor() |
591 with dctx.stream_writer(buffer) as decompressor: | 832 decompressor = dctx.stream_writer(buffer) |
592 decompressor.write(data) | 833 decompressor.write(data) |
834 | |
593 return buffer.getvalue() | 835 return buffer.getvalue() |
594 | 836 |
595 | 837 |
596 @make_cffi | 838 @make_cffi |
597 class TestDecompressor_stream_writer(unittest.TestCase): | 839 class TestDecompressor_stream_writer(unittest.TestCase): |
840 def test_io_api(self): | |
841 buffer = io.BytesIO() | |
842 dctx = zstd.ZstdDecompressor() | |
843 writer = dctx.stream_writer(buffer) | |
844 | |
845 self.assertFalse(writer.closed) | |
846 self.assertFalse(writer.isatty()) | |
847 self.assertFalse(writer.readable()) | |
848 | |
849 with self.assertRaises(io.UnsupportedOperation): | |
850 writer.readline() | |
851 | |
852 with self.assertRaises(io.UnsupportedOperation): | |
853 writer.readline(42) | |
854 | |
855 with self.assertRaises(io.UnsupportedOperation): | |
856 writer.readline(size=42) | |
857 | |
858 with self.assertRaises(io.UnsupportedOperation): | |
859 writer.readlines() | |
860 | |
861 with self.assertRaises(io.UnsupportedOperation): | |
862 writer.readlines(42) | |
863 | |
864 with self.assertRaises(io.UnsupportedOperation): | |
865 writer.readlines(hint=42) | |
866 | |
867 with self.assertRaises(io.UnsupportedOperation): | |
868 writer.seek(0) | |
869 | |
870 with self.assertRaises(io.UnsupportedOperation): | |
871 writer.seek(10, os.SEEK_SET) | |
872 | |
873 self.assertFalse(writer.seekable()) | |
874 | |
875 with self.assertRaises(io.UnsupportedOperation): | |
876 writer.tell() | |
877 | |
878 with self.assertRaises(io.UnsupportedOperation): | |
879 writer.truncate() | |
880 | |
881 with self.assertRaises(io.UnsupportedOperation): | |
882 writer.truncate(42) | |
883 | |
884 with self.assertRaises(io.UnsupportedOperation): | |
885 writer.truncate(size=42) | |
886 | |
887 self.assertTrue(writer.writable()) | |
888 | |
889 with self.assertRaises(io.UnsupportedOperation): | |
890 writer.writelines([]) | |
891 | |
892 with self.assertRaises(io.UnsupportedOperation): | |
893 writer.read() | |
894 | |
895 with self.assertRaises(io.UnsupportedOperation): | |
896 writer.read(42) | |
897 | |
898 with self.assertRaises(io.UnsupportedOperation): | |
899 writer.read(size=42) | |
900 | |
901 with self.assertRaises(io.UnsupportedOperation): | |
902 writer.readall() | |
903 | |
904 with self.assertRaises(io.UnsupportedOperation): | |
905 writer.readinto(None) | |
906 | |
907 with self.assertRaises(io.UnsupportedOperation): | |
908 writer.fileno() | |
909 | |
910 def test_fileno_file(self): | |
911 with tempfile.TemporaryFile('wb') as tf: | |
912 dctx = zstd.ZstdDecompressor() | |
913 writer = dctx.stream_writer(tf) | |
914 | |
915 self.assertEqual(writer.fileno(), tf.fileno()) | |
916 | |
917 def test_close(self): | |
918 foo = zstd.ZstdCompressor().compress(b'foo') | |
919 | |
920 buffer = NonClosingBytesIO() | |
921 dctx = zstd.ZstdDecompressor() | |
922 writer = dctx.stream_writer(buffer) | |
923 | |
924 writer.write(foo) | |
925 self.assertFalse(writer.closed) | |
926 self.assertFalse(buffer.closed) | |
927 writer.close() | |
928 self.assertTrue(writer.closed) | |
929 self.assertTrue(buffer.closed) | |
930 | |
931 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | |
932 writer.write(b'') | |
933 | |
934 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | |
935 writer.flush() | |
936 | |
937 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | |
938 with writer: | |
939 pass | |
940 | |
941 self.assertEqual(buffer.getvalue(), b'foo') | |
942 | |
943 # Context manager exit should close stream. | |
944 buffer = NonClosingBytesIO() | |
945 writer = dctx.stream_writer(buffer) | |
946 | |
947 with writer: | |
948 writer.write(foo) | |
949 | |
950 self.assertTrue(writer.closed) | |
951 self.assertEqual(buffer.getvalue(), b'foo') | |
952 | |
953 def test_flush(self): | |
954 buffer = OpCountingBytesIO() | |
955 dctx = zstd.ZstdDecompressor() | |
956 writer = dctx.stream_writer(buffer) | |
957 | |
958 writer.flush() | |
959 self.assertEqual(buffer._flush_count, 1) | |
960 writer.flush() | |
961 self.assertEqual(buffer._flush_count, 2) | |
962 | |
598 def test_empty_roundtrip(self): | 963 def test_empty_roundtrip(self): |
599 cctx = zstd.ZstdCompressor() | 964 cctx = zstd.ZstdCompressor() |
600 empty = cctx.compress(b'') | 965 empty = cctx.compress(b'') |
601 self.assertEqual(decompress_via_writer(empty), b'') | 966 self.assertEqual(decompress_via_writer(empty), b'') |
602 | 967 |
614 ] | 979 ] |
615 | 980 |
616 dctx = zstd.ZstdDecompressor() | 981 dctx = zstd.ZstdDecompressor() |
617 for source in sources: | 982 for source in sources: |
618 buffer = io.BytesIO() | 983 buffer = io.BytesIO() |
984 | |
985 decompressor = dctx.stream_writer(buffer) | |
986 decompressor.write(source) | |
987 self.assertEqual(buffer.getvalue(), b'foo') | |
988 | |
989 buffer = NonClosingBytesIO() | |
990 | |
619 with dctx.stream_writer(buffer) as decompressor: | 991 with dctx.stream_writer(buffer) as decompressor: |
620 decompressor.write(source) | 992 self.assertEqual(decompressor.write(source), 3) |
621 | 993 |
994 self.assertEqual(buffer.getvalue(), b'foo') | |
995 | |
996 buffer = io.BytesIO() | |
997 writer = dctx.stream_writer(buffer, write_return_read=True) | |
998 self.assertEqual(writer.write(source), len(source)) | |
622 self.assertEqual(buffer.getvalue(), b'foo') | 999 self.assertEqual(buffer.getvalue(), b'foo') |
623 | 1000 |
624 def test_large_roundtrip(self): | 1001 def test_large_roundtrip(self): |
625 chunks = [] | 1002 chunks = [] |
626 for i in range(255): | 1003 for i in range(255): |
639 | 1016 |
640 orig = b''.join(chunks) | 1017 orig = b''.join(chunks) |
641 cctx = zstd.ZstdCompressor() | 1018 cctx = zstd.ZstdCompressor() |
642 compressed = cctx.compress(orig) | 1019 compressed = cctx.compress(orig) |
643 | 1020 |
644 buffer = io.BytesIO() | 1021 buffer = NonClosingBytesIO() |
645 dctx = zstd.ZstdDecompressor() | 1022 dctx = zstd.ZstdDecompressor() |
646 with dctx.stream_writer(buffer) as decompressor: | 1023 with dctx.stream_writer(buffer) as decompressor: |
647 pos = 0 | 1024 pos = 0 |
648 while pos < len(compressed): | 1025 while pos < len(compressed): |
649 pos2 = pos + 8192 | 1026 pos2 = pos + 8192 |
650 decompressor.write(compressed[pos:pos2]) | 1027 decompressor.write(compressed[pos:pos2]) |
651 pos += 8192 | 1028 pos += 8192 |
652 self.assertEqual(buffer.getvalue(), orig) | 1029 self.assertEqual(buffer.getvalue(), orig) |
653 | 1030 |
1031 # Again with write_return_read=True | |
1032 buffer = io.BytesIO() | |
1033 writer = dctx.stream_writer(buffer, write_return_read=True) | |
1034 pos = 0 | |
1035 while pos < len(compressed): | |
1036 pos2 = pos + 8192 | |
1037 chunk = compressed[pos:pos2] | |
1038 self.assertEqual(writer.write(chunk), len(chunk)) | |
1039 pos += 8192 | |
1040 self.assertEqual(buffer.getvalue(), orig) | |
1041 | |
654 def test_dictionary(self): | 1042 def test_dictionary(self): |
655 samples = [] | 1043 samples = [] |
656 for i in range(128): | 1044 for i in range(128): |
657 samples.append(b'foo' * 64) | 1045 samples.append(b'foo' * 64) |
658 samples.append(b'bar' * 64) | 1046 samples.append(b'bar' * 64) |
659 samples.append(b'foobar' * 64) | 1047 samples.append(b'foobar' * 64) |
660 | 1048 |
661 d = zstd.train_dictionary(8192, samples) | 1049 d = zstd.train_dictionary(8192, samples) |
662 | 1050 |
663 orig = b'foobar' * 16384 | 1051 orig = b'foobar' * 16384 |
664 buffer = io.BytesIO() | 1052 buffer = NonClosingBytesIO() |
665 cctx = zstd.ZstdCompressor(dict_data=d) | 1053 cctx = zstd.ZstdCompressor(dict_data=d) |
666 with cctx.stream_writer(buffer) as compressor: | 1054 with cctx.stream_writer(buffer) as compressor: |
667 self.assertEqual(compressor.write(orig), 0) | 1055 self.assertEqual(compressor.write(orig), 0) |
668 | 1056 |
669 compressed = buffer.getvalue() | 1057 compressed = buffer.getvalue() |
670 buffer = io.BytesIO() | 1058 buffer = io.BytesIO() |
671 | 1059 |
672 dctx = zstd.ZstdDecompressor(dict_data=d) | 1060 dctx = zstd.ZstdDecompressor(dict_data=d) |
1061 decompressor = dctx.stream_writer(buffer) | |
1062 self.assertEqual(decompressor.write(compressed), len(orig)) | |
1063 self.assertEqual(buffer.getvalue(), orig) | |
1064 | |
1065 buffer = NonClosingBytesIO() | |
1066 | |
673 with dctx.stream_writer(buffer) as decompressor: | 1067 with dctx.stream_writer(buffer) as decompressor: |
674 self.assertEqual(decompressor.write(compressed), len(orig)) | 1068 self.assertEqual(decompressor.write(compressed), len(orig)) |
675 | 1069 |
676 self.assertEqual(buffer.getvalue(), orig) | 1070 self.assertEqual(buffer.getvalue(), orig) |
677 | 1071 |
678 def test_memory_size(self): | 1072 def test_memory_size(self): |
679 dctx = zstd.ZstdDecompressor() | 1073 dctx = zstd.ZstdDecompressor() |
680 buffer = io.BytesIO() | 1074 buffer = io.BytesIO() |
1075 | |
1076 decompressor = dctx.stream_writer(buffer) | |
1077 size = decompressor.memory_size() | |
1078 self.assertGreater(size, 100000) | |
1079 | |
681 with dctx.stream_writer(buffer) as decompressor: | 1080 with dctx.stream_writer(buffer) as decompressor: |
682 size = decompressor.memory_size() | 1081 size = decompressor.memory_size() |
683 | 1082 |
684 self.assertGreater(size, 100000) | 1083 self.assertGreater(size, 100000) |
685 | 1084 |
808 self.assertEqual(decompressed, source.getvalue()) | 1207 self.assertEqual(decompressed, source.getvalue()) |
809 | 1208 |
810 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') | 1209 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') |
811 def test_large_input(self): | 1210 def test_large_input(self): |
812 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) | 1211 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) |
813 compressed = io.BytesIO() | 1212 compressed = NonClosingBytesIO() |
814 input_size = 0 | 1213 input_size = 0 |
815 cctx = zstd.ZstdCompressor(level=1) | 1214 cctx = zstd.ZstdCompressor(level=1) |
816 with cctx.stream_writer(compressed) as compressor: | 1215 with cctx.stream_writer(compressed) as compressor: |
817 while True: | 1216 while True: |
818 compressor.write(random.choice(bytes)) | 1217 compressor.write(random.choice(bytes)) |
821 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | 1220 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
822 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | 1221 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 |
823 if have_compressed and have_raw: | 1222 if have_compressed and have_raw: |
824 break | 1223 break |
825 | 1224 |
826 compressed.seek(0) | 1225 compressed = io.BytesIO(compressed.getvalue()) |
827 self.assertGreater(len(compressed.getvalue()), | 1226 self.assertGreater(len(compressed.getvalue()), |
828 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) | 1227 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) |
829 | 1228 |
830 dctx = zstd.ZstdDecompressor() | 1229 dctx = zstd.ZstdDecompressor() |
831 it = dctx.read_to_iter(compressed) | 1230 it = dctx.read_to_iter(compressed) |
859 # Found this edge case via fuzzing. | 1258 # Found this edge case via fuzzing. |
860 cctx = zstd.ZstdCompressor(level=1) | 1259 cctx = zstd.ZstdCompressor(level=1) |
861 | 1260 |
862 source = io.BytesIO() | 1261 source = io.BytesIO() |
863 | 1262 |
864 compressed = io.BytesIO() | 1263 compressed = NonClosingBytesIO() |
865 with cctx.stream_writer(compressed) as compressor: | 1264 with cctx.stream_writer(compressed) as compressor: |
866 for i in range(256): | 1265 for i in range(256): |
867 chunk = b'\0' * 1024 | 1266 chunk = b'\0' * 1024 |
868 compressor.write(chunk) | 1267 compressor.write(chunk) |
869 source.write(chunk) | 1268 source.write(chunk) |
872 | 1271 |
873 simple = dctx.decompress(compressed.getvalue(), | 1272 simple = dctx.decompress(compressed.getvalue(), |
874 max_output_size=len(source.getvalue())) | 1273 max_output_size=len(source.getvalue())) |
875 self.assertEqual(simple, source.getvalue()) | 1274 self.assertEqual(simple, source.getvalue()) |
876 | 1275 |
877 compressed.seek(0) | 1276 compressed = io.BytesIO(compressed.getvalue()) |
878 streamed = b''.join(dctx.read_to_iter(compressed)) | 1277 streamed = b''.join(dctx.read_to_iter(compressed)) |
879 self.assertEqual(streamed, source.getvalue()) | 1278 self.assertEqual(streamed, source.getvalue()) |
880 | 1279 |
881 def test_read_write_size(self): | 1280 def test_read_write_size(self): |
882 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) | 1281 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) |
999 # TODO enable for CFFI | 1398 # TODO enable for CFFI |
1000 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): | 1399 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): |
1001 def test_invalid_inputs(self): | 1400 def test_invalid_inputs(self): |
1002 dctx = zstd.ZstdDecompressor() | 1401 dctx = zstd.ZstdDecompressor() |
1003 | 1402 |
1403 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1404 self.skipTest('multi_decompress_to_buffer not available') | |
1405 | |
1004 with self.assertRaises(TypeError): | 1406 with self.assertRaises(TypeError): |
1005 dctx.multi_decompress_to_buffer(True) | 1407 dctx.multi_decompress_to_buffer(True) |
1006 | 1408 |
1007 with self.assertRaises(TypeError): | 1409 with self.assertRaises(TypeError): |
1008 dctx.multi_decompress_to_buffer((1, 2)) | 1410 dctx.multi_decompress_to_buffer((1, 2)) |
1018 | 1420 |
1019 original = [b'foo' * 4, b'bar' * 6] | 1421 original = [b'foo' * 4, b'bar' * 6] |
1020 frames = [cctx.compress(d) for d in original] | 1422 frames = [cctx.compress(d) for d in original] |
1021 | 1423 |
1022 dctx = zstd.ZstdDecompressor() | 1424 dctx = zstd.ZstdDecompressor() |
1425 | |
1426 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1427 self.skipTest('multi_decompress_to_buffer not available') | |
1428 | |
1023 result = dctx.multi_decompress_to_buffer(frames) | 1429 result = dctx.multi_decompress_to_buffer(frames) |
1024 | 1430 |
1025 self.assertEqual(len(result), len(frames)) | 1431 self.assertEqual(len(result), len(frames)) |
1026 self.assertEqual(result.size(), sum(map(len, original))) | 1432 self.assertEqual(result.size(), sum(map(len, original))) |
1027 | 1433 |
1039 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] | 1445 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1040 frames = [cctx.compress(d) for d in original] | 1446 frames = [cctx.compress(d) for d in original] |
1041 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) | 1447 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1042 | 1448 |
1043 dctx = zstd.ZstdDecompressor() | 1449 dctx = zstd.ZstdDecompressor() |
1450 | |
1451 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1452 self.skipTest('multi_decompress_to_buffer not available') | |
1453 | |
1044 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) | 1454 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) |
1045 | 1455 |
1046 self.assertEqual(len(result), len(frames)) | 1456 self.assertEqual(len(result), len(frames)) |
1047 self.assertEqual(result.size(), sum(map(len, original))) | 1457 self.assertEqual(result.size(), sum(map(len, original))) |
1048 | 1458 |
1054 | 1464 |
1055 original = [b'foo' * 4, b'bar' * 6] | 1465 original = [b'foo' * 4, b'bar' * 6] |
1056 frames = [cctx.compress(d) for d in original] | 1466 frames = [cctx.compress(d) for d in original] |
1057 | 1467 |
1058 dctx = zstd.ZstdDecompressor() | 1468 dctx = zstd.ZstdDecompressor() |
1469 | |
1470 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1471 self.skipTest('multi_decompress_to_buffer not available') | |
1059 | 1472 |
1060 segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1])) | 1473 segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1])) |
1061 b = zstd.BufferWithSegments(b''.join(frames), segments) | 1474 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1062 | 1475 |
1063 result = dctx.multi_decompress_to_buffer(b) | 1476 result = dctx.multi_decompress_to_buffer(b) |
1072 cctx = zstd.ZstdCompressor(write_content_size=False) | 1485 cctx = zstd.ZstdCompressor(write_content_size=False) |
1073 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] | 1486 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1074 frames = [cctx.compress(d) for d in original] | 1487 frames = [cctx.compress(d) for d in original] |
1075 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) | 1488 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1076 | 1489 |
1490 dctx = zstd.ZstdDecompressor() | |
1491 | |
1492 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1493 self.skipTest('multi_decompress_to_buffer not available') | |
1494 | |
1077 segments = struct.pack('=QQQQQQ', 0, len(frames[0]), | 1495 segments = struct.pack('=QQQQQQ', 0, len(frames[0]), |
1078 len(frames[0]), len(frames[1]), | 1496 len(frames[0]), len(frames[1]), |
1079 len(frames[0]) + len(frames[1]), len(frames[2])) | 1497 len(frames[0]) + len(frames[1]), len(frames[2])) |
1080 b = zstd.BufferWithSegments(b''.join(frames), segments) | 1498 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1081 | 1499 |
1082 dctx = zstd.ZstdDecompressor() | |
1083 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) | 1500 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) |
1084 | 1501 |
1085 self.assertEqual(len(result), len(frames)) | 1502 self.assertEqual(len(result), len(frames)) |
1086 self.assertEqual(result.size(), sum(map(len, original))) | 1503 self.assertEqual(result.size(), sum(map(len, original))) |
1087 | 1504 |
1097 b'foo2' * 4, | 1514 b'foo2' * 4, |
1098 b'foo3' * 5, | 1515 b'foo3' * 5, |
1099 b'foo4' * 6, | 1516 b'foo4' * 6, |
1100 ] | 1517 ] |
1101 | 1518 |
1519 if not hasattr(cctx, 'multi_compress_to_buffer'): | |
1520 self.skipTest('multi_compress_to_buffer not available') | |
1521 | |
1102 frames = cctx.multi_compress_to_buffer(original) | 1522 frames = cctx.multi_compress_to_buffer(original) |
1103 | 1523 |
1104 # Check round trip. | 1524 # Check round trip. |
1105 dctx = zstd.ZstdDecompressor() | 1525 dctx = zstd.ZstdDecompressor() |
1526 | |
1106 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) | 1527 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) |
1107 | 1528 |
1108 self.assertEqual(len(decompressed), len(original)) | 1529 self.assertEqual(len(decompressed), len(original)) |
1109 | 1530 |
1110 for i, data in enumerate(original): | 1531 for i, data in enumerate(original): |
1136 | 1557 |
1137 cctx = zstd.ZstdCompressor(dict_data=d, level=1) | 1558 cctx = zstd.ZstdCompressor(dict_data=d, level=1) |
1138 frames = [cctx.compress(s) for s in generate_samples()] | 1559 frames = [cctx.compress(s) for s in generate_samples()] |
1139 | 1560 |
1140 dctx = zstd.ZstdDecompressor(dict_data=d) | 1561 dctx = zstd.ZstdDecompressor(dict_data=d) |
1562 | |
1563 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1564 self.skipTest('multi_decompress_to_buffer not available') | |
1565 | |
1141 result = dctx.multi_decompress_to_buffer(frames) | 1566 result = dctx.multi_decompress_to_buffer(frames) |
1567 | |
1142 self.assertEqual([o.tobytes() for o in result], generate_samples()) | 1568 self.assertEqual([o.tobytes() for o in result], generate_samples()) |
1143 | 1569 |
1144 def test_multiple_threads(self): | 1570 def test_multiple_threads(self): |
1145 cctx = zstd.ZstdCompressor() | 1571 cctx = zstd.ZstdCompressor() |
1146 | 1572 |
1147 frames = [] | 1573 frames = [] |
1148 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) | 1574 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) |
1149 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) | 1575 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) |
1150 | 1576 |
1151 dctx = zstd.ZstdDecompressor() | 1577 dctx = zstd.ZstdDecompressor() |
1578 | |
1579 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1580 self.skipTest('multi_decompress_to_buffer not available') | |
1581 | |
1152 result = dctx.multi_decompress_to_buffer(frames, threads=-1) | 1582 result = dctx.multi_decompress_to_buffer(frames, threads=-1) |
1153 | 1583 |
1154 self.assertEqual(len(result), len(frames)) | 1584 self.assertEqual(len(result), len(frames)) |
1155 self.assertEqual(result.size(), 2 * 64 * 256) | 1585 self.assertEqual(result.size(), 2 * 64 * 256) |
1156 self.assertEqual(result[0].tobytes(), b'x' * 64) | 1586 self.assertEqual(result[0].tobytes(), b'x' * 64) |
1161 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] | 1591 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] |
1162 | 1592 |
1163 frames[1] = frames[1][0:15] + b'extra' + frames[1][15:] | 1593 frames[1] = frames[1][0:15] + b'extra' + frames[1][15:] |
1164 | 1594 |
1165 dctx = zstd.ZstdDecompressor() | 1595 dctx = zstd.ZstdDecompressor() |
1596 | |
1597 if not hasattr(dctx, 'multi_decompress_to_buffer'): | |
1598 self.skipTest('multi_decompress_to_buffer not available') | |
1166 | 1599 |
1167 with self.assertRaisesRegexp(zstd.ZstdError, | 1600 with self.assertRaisesRegexp(zstd.ZstdError, |
1168 'error decompressing item 1: (' | 1601 'error decompressing item 1: (' |
1169 'Corrupted block|' | 1602 'Corrupted block|' |
1170 'Destination buffer is too small)'): | 1603 'Destination buffer is too small)'): |