changeset 42070:675775c33ab6
parent    40121:73fef626dae3
child     43994:de7838053207

comparison of 42069:668eff08387f with 42070:675775c33ab6
1 import io |
1 import io |
2 import os |
2 import os |
3 import random |
3 import random |
4 import struct |
4 import struct |
5 import sys |
5 import sys |
6 import tempfile |
|
6 import unittest |
7 import unittest |
7 |
8 |
8 import zstandard as zstd |
9 import zstandard as zstd |
9 |
10 |
10 from .common import ( |
11 from .common import ( |
11 generate_samples, |
12 generate_samples, |
12 make_cffi, |
13 make_cffi, |
14 NonClosingBytesIO, |
|
13 OpCountingBytesIO, |
15 OpCountingBytesIO, |
14 ) |
16 ) |
15 |
17 |
16 |
18 |
17 if sys.version_info[0] >= 3: |
19 if sys.version_info[0] >= 3: |
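Note on the import hunk above: NonClosingBytesIO is a new helper pulled in from the shared .common test module alongside tempfile; its implementation is not part of this file. As a rough, hypothetical sketch of what such a helper presumably provides (an io.BytesIO whose contents stay retrievable after close(), useful because the updated stream writers close their target stream), it could look like:

    import io

    class NonClosingBytesIO(io.BytesIO):
        """Hypothetical sketch of the .common helper; the real one may differ.

        Keeps the accumulated bytes retrievable after close(), so tests can
        still call getvalue() once a wrapping stream has closed the buffer.
        """

        def close(self):
            # Snapshot the buffer before the underlying storage is released.
            self._saved_value = self.getvalue()
            return super(NonClosingBytesIO, self).close()

        def getvalue(self):
            if self.closed:
                return self._saved_value
            return super(NonClosingBytesIO, self).getvalue()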
217 # If we write a content size, the decompressor engages single pass |
219 # If we write a content size, the decompressor engages single pass |
218 # mode and the window size doesn't come into play. |
220 # mode and the window size doesn't come into play. |
219 cctx = zstd.ZstdCompressor(write_content_size=False) |
221 cctx = zstd.ZstdCompressor(write_content_size=False) |
220 frame = cctx.compress(source) |
222 frame = cctx.compress(source) |
221 |
223 |
222 dctx = zstd.ZstdDecompressor(max_window_size=1) |
224 dctx = zstd.ZstdDecompressor(max_window_size=2**zstd.WINDOWLOG_MIN) |
223 |
225 |
224 with self.assertRaisesRegexp( |
226 with self.assertRaisesRegexp( |
225 zstd.ZstdError, 'decompression error: Frame requires too much memory'): |
227 zstd.ZstdError, 'decompression error: Frame requires too much memory'): |
226 dctx.decompress(frame, max_output_size=len(source)) |
228 dctx.decompress(frame, max_output_size=len(source)) |
227 |
229 |
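The rewritten max_window_size test above pins the limit to the smallest legal window, 2**zstd.WINDOWLOG_MIN, instead of the previous out-of-range value 1, and keeps write_content_size=False so the decompressor cannot fall back to single-pass mode. A minimal sketch of the behavior being asserted, with the payload size chosen here only for illustration:

    import zstandard as zstd

    source = b'foobar' * 131072  # large enough to require a real window

    # No content size in the frame header, so windowed (streaming)
    # decompression is used and max_window_size stays in effect.
    cctx = zstd.ZstdCompressor(write_content_size=False)
    frame = cctx.compress(source)

    dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN)

    try:
        dctx.decompress(frame, max_output_size=len(source))
    except zstd.ZstdError as exc:
        print('rejected: %s' % exc)  # frame requires too much memory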
300 |
302 |
301 def test_not_implemented(self): |
303 def test_not_implemented(self): |
302 dctx = zstd.ZstdDecompressor() |
304 dctx = zstd.ZstdDecompressor() |
303 |
305 |
304 with dctx.stream_reader(b'foo') as reader: |
306 with dctx.stream_reader(b'foo') as reader: |
305 with self.assertRaises(NotImplementedError): |
307 with self.assertRaises(io.UnsupportedOperation): |
306 reader.readline() |
308 reader.readline() |
307 |
309 |
308 with self.assertRaises(NotImplementedError): |
310 with self.assertRaises(io.UnsupportedOperation): |
309 reader.readlines() |
311 reader.readlines() |
310 |
312 |
311 with self.assertRaises(NotImplementedError): |
313 with self.assertRaises(io.UnsupportedOperation): |
312 reader.readall() |
|
313 |
|
314 with self.assertRaises(NotImplementedError): |
|
315 iter(reader) |
314 iter(reader) |
316 |
315 |
317 with self.assertRaises(NotImplementedError): |
316 with self.assertRaises(io.UnsupportedOperation): |
318 next(reader) |
317 next(reader) |
319 |
318 |
320 with self.assertRaises(io.UnsupportedOperation): |
319 with self.assertRaises(io.UnsupportedOperation): |
321 reader.write(b'foo') |
320 reader.write(b'foo') |
322 |
321 |
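The hunk above switches the stream_reader error type for unsupported io methods from NotImplementedError to io.UnsupportedOperation, the conventional exception for file-like objects, and iter(reader) is no longer rejected (only next(reader) is). A small sketch of handling the new behavior, assuming a reader produced as in these tests:

    import io
    import zstandard as zstd

    frame = zstd.ZstdCompressor().compress(b'foo')
    dctx = zstd.ZstdDecompressor()

    with dctx.stream_reader(frame) as reader:
        # The raw reader is a binary stream without line support; probing
        # for it now raises the standard io.UnsupportedOperation.
        try:
            reader.readline()
        except io.UnsupportedOperation:
            pass

        data = reader.read()

    print(data)  # b'foo'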
345 reader.close() |
344 reader.close() |
346 self.assertTrue(reader.closed) |
345 self.assertTrue(reader.closed) |
347 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
346 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
348 reader.read(1) |
347 reader.read(1) |
349 |
348 |
350 def test_bad_read_size(self): |
349 def test_read_sizes(self): |
351 dctx = zstd.ZstdDecompressor() |
350 cctx = zstd.ZstdCompressor() |
352 |
351 foo = cctx.compress(b'foo') |
353 with dctx.stream_reader(b'foo') as reader: |
352 |
354 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): |
353 dctx = zstd.ZstdDecompressor() |
355 reader.read(-1) |
354 |
356 |
355 with dctx.stream_reader(foo) as reader: |
357 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): |
356 with self.assertRaisesRegexp(ValueError, 'cannot read negative amounts less than -1'): |
358 reader.read(0) |
357 reader.read(-2) |
358 |
|
359 self.assertEqual(reader.read(0), b'') |
|
360 self.assertEqual(reader.read(), b'foo') |
|
359 |
361 |
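test_bad_read_size becomes test_read_sizes: per the new assertions, only sizes below -1 are rejected, read(0) is a legal no-op returning b'', and read() with no size drains the stream. A short sketch of that contract using the same values as the hunk:

    import zstandard as zstd

    foo = zstd.ZstdCompressor().compress(b'foo')
    dctx = zstd.ZstdDecompressor()

    with dctx.stream_reader(foo) as reader:
        try:
            reader.read(-2)            # anything below -1 is still an error
        except ValueError:
            pass

        assert reader.read(0) == b''   # zero-byte reads are now allowed
        assert reader.read() == b'foo' # no size (or -1) reads to EOF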
360 def test_read_buffer(self): |
362 def test_read_buffer(self): |
361 cctx = zstd.ZstdCompressor() |
363 cctx = zstd.ZstdCompressor() |
362 |
364 |
363 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) |
365 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) |
522 dctx = zstd.ZstdDecompressor() |
524 dctx = zstd.ZstdDecompressor() |
523 |
525 |
524 reader = dctx.stream_reader(source) |
526 reader = dctx.stream_reader(source) |
525 |
527 |
526 with reader: |
528 with reader: |
527 with self.assertRaises(TypeError): |
529 reader.read(0) |
528 reader.read() |
|
529 |
530 |
530 with reader: |
531 with reader: |
531 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
532 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
532 reader.read(100) |
533 reader.read(100) |
534 |
|
535 def test_partial_read(self): |
|
536 # Inspired by https://github.com/indygreg/python-zstandard/issues/71. |
|
537 buffer = io.BytesIO() |
|
538 cctx = zstd.ZstdCompressor() |
|
539 writer = cctx.stream_writer(buffer) |
|
540 writer.write(bytearray(os.urandom(1000000))) |
|
541 writer.flush(zstd.FLUSH_FRAME) |
|
542 buffer.seek(0) |
|
543 |
|
544 dctx = zstd.ZstdDecompressor() |
|
545 reader = dctx.stream_reader(buffer) |
|
546 |
|
547 while True: |
|
548 chunk = reader.read(8192) |
|
549 if not chunk: |
|
550 break |
|
551 |
|
552 def test_read_multiple_frames(self): |
|
553 cctx = zstd.ZstdCompressor() |
|
554 source = io.BytesIO() |
|
555 writer = cctx.stream_writer(source) |
|
556 writer.write(b'foo') |
|
557 writer.flush(zstd.FLUSH_FRAME) |
|
558 writer.write(b'bar') |
|
559 writer.flush(zstd.FLUSH_FRAME) |
|
560 |
|
561 dctx = zstd.ZstdDecompressor() |
|
562 |
|
563 reader = dctx.stream_reader(source.getvalue()) |
|
564 self.assertEqual(reader.read(2), b'fo') |
|
565 self.assertEqual(reader.read(2), b'o') |
|
566 self.assertEqual(reader.read(2), b'ba') |
|
567 self.assertEqual(reader.read(2), b'r') |
|
568 |
|
569 source.seek(0) |
|
570 reader = dctx.stream_reader(source) |
|
571 self.assertEqual(reader.read(2), b'fo') |
|
572 self.assertEqual(reader.read(2), b'o') |
|
573 self.assertEqual(reader.read(2), b'ba') |
|
574 self.assertEqual(reader.read(2), b'r') |
|
575 |
|
576 reader = dctx.stream_reader(source.getvalue()) |
|
577 self.assertEqual(reader.read(3), b'foo') |
|
578 self.assertEqual(reader.read(3), b'bar') |
|
579 |
|
580 source.seek(0) |
|
581 reader = dctx.stream_reader(source) |
|
582 self.assertEqual(reader.read(3), b'foo') |
|
583 self.assertEqual(reader.read(3), b'bar') |
|
584 |
|
585 reader = dctx.stream_reader(source.getvalue()) |
|
586 self.assertEqual(reader.read(4), b'foo') |
|
587 self.assertEqual(reader.read(4), b'bar') |
|
588 |
|
589 source.seek(0) |
|
590 reader = dctx.stream_reader(source) |
|
591 self.assertEqual(reader.read(4), b'foo') |
|
592 self.assertEqual(reader.read(4), b'bar') |
|
593 |
|
594 reader = dctx.stream_reader(source.getvalue()) |
|
595 self.assertEqual(reader.read(128), b'foo') |
|
596 self.assertEqual(reader.read(128), b'bar') |
|
597 |
|
598 source.seek(0) |
|
599 reader = dctx.stream_reader(source) |
|
600 self.assertEqual(reader.read(128), b'foo') |
|
601 self.assertEqual(reader.read(128), b'bar') |
|
602 |
|
603 # Now tests for reads spanning frames. |
|
604 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
605 self.assertEqual(reader.read(3), b'foo') |
|
606 self.assertEqual(reader.read(3), b'bar') |
|
607 |
|
608 source.seek(0) |
|
609 reader = dctx.stream_reader(source, read_across_frames=True) |
|
610 self.assertEqual(reader.read(3), b'foo') |
|
611 self.assertEqual(reader.read(3), b'bar') |
|
612 |
|
613 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
614 self.assertEqual(reader.read(6), b'foobar') |
|
615 |
|
616 source.seek(0) |
|
617 reader = dctx.stream_reader(source, read_across_frames=True) |
|
618 self.assertEqual(reader.read(6), b'foobar') |
|
619 |
|
620 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
621 self.assertEqual(reader.read(7), b'foobar') |
|
622 |
|
623 source.seek(0) |
|
624 reader = dctx.stream_reader(source, read_across_frames=True) |
|
625 self.assertEqual(reader.read(7), b'foobar') |
|
626 |
|
627 reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) |
|
628 self.assertEqual(reader.read(128), b'foobar') |
|
629 |
|
630 source.seek(0) |
|
631 reader = dctx.stream_reader(source, read_across_frames=True) |
|
632 self.assertEqual(reader.read(128), b'foobar') |
|
633 |
|
634 def test_readinto(self): |
|
635 cctx = zstd.ZstdCompressor() |
|
636 foo = cctx.compress(b'foo') |
|
637 |
|
638 dctx = zstd.ZstdDecompressor() |
|
639 |
|
640 # Attempting to readinto() a non-writable buffer fails. |
|
641 # The exact exception varies based on the backend. |
|
642 reader = dctx.stream_reader(foo) |
|
643 with self.assertRaises(Exception): |
|
644 reader.readinto(b'foobar') |
|
645 |
|
646 # readinto() with sufficiently large destination. |
|
647 b = bytearray(1024) |
|
648 reader = dctx.stream_reader(foo) |
|
649 self.assertEqual(reader.readinto(b), 3) |
|
650 self.assertEqual(b[0:3], b'foo') |
|
651 self.assertEqual(reader.readinto(b), 0) |
|
652 self.assertEqual(b[0:3], b'foo') |
|
653 |
|
654 # readinto() with small reads. |
|
655 b = bytearray(1024) |
|
656 reader = dctx.stream_reader(foo, read_size=1) |
|
657 self.assertEqual(reader.readinto(b), 3) |
|
658 self.assertEqual(b[0:3], b'foo') |
|
659 |
|
660 # Too small destination buffer. |
|
661 b = bytearray(2) |
|
662 reader = dctx.stream_reader(foo) |
|
663 self.assertEqual(reader.readinto(b), 2) |
|
664 self.assertEqual(b[:], b'fo') |
|
665 |
|
666 def test_readinto1(self): |
|
667 cctx = zstd.ZstdCompressor() |
|
668 foo = cctx.compress(b'foo') |
|
669 |
|
670 dctx = zstd.ZstdDecompressor() |
|
671 |
|
672 reader = dctx.stream_reader(foo) |
|
673 with self.assertRaises(Exception): |
|
674 reader.readinto1(b'foobar') |
|
675 |
|
676 # Sufficiently large destination. |
|
677 b = bytearray(1024) |
|
678 reader = dctx.stream_reader(foo) |
|
679 self.assertEqual(reader.readinto1(b), 3) |
|
680 self.assertEqual(b[0:3], b'foo') |
|
681 self.assertEqual(reader.readinto1(b), 0) |
|
682 self.assertEqual(b[0:3], b'foo') |
|
683 |
|
684 # readinto() with small reads. |
|
685 b = bytearray(1024) |
|
686 reader = dctx.stream_reader(foo, read_size=1) |
|
687 self.assertEqual(reader.readinto1(b), 3) |
|
688 self.assertEqual(b[0:3], b'foo') |
|
689 |
|
690 # Too small destination buffer. |
|
691 b = bytearray(2) |
|
692 reader = dctx.stream_reader(foo) |
|
693 self.assertEqual(reader.readinto1(b), 2) |
|
694 self.assertEqual(b[:], b'fo') |
|
695 |
|
696 def test_readall(self): |
|
697 cctx = zstd.ZstdCompressor() |
|
698 foo = cctx.compress(b'foo') |
|
699 |
|
700 dctx = zstd.ZstdDecompressor() |
|
701 reader = dctx.stream_reader(foo) |
|
702 |
|
703 self.assertEqual(reader.readall(), b'foo') |
|
704 |
|
705 def test_read1(self): |
|
706 cctx = zstd.ZstdCompressor() |
|
707 foo = cctx.compress(b'foo') |
|
708 |
|
709 dctx = zstd.ZstdDecompressor() |
|
710 |
|
711 b = OpCountingBytesIO(foo) |
|
712 reader = dctx.stream_reader(b) |
|
713 |
|
714 self.assertEqual(reader.read1(), b'foo') |
|
715 self.assertEqual(b._read_count, 1) |
|
716 |
|
717 b = OpCountingBytesIO(foo) |
|
718 reader = dctx.stream_reader(b) |
|
719 |
|
720 self.assertEqual(reader.read1(0), b'') |
|
721 self.assertEqual(reader.read1(2), b'fo') |
|
722 self.assertEqual(b._read_count, 1) |
|
723 self.assertEqual(reader.read1(1), b'o') |
|
724 self.assertEqual(b._read_count, 1) |
|
725 self.assertEqual(reader.read1(1), b'') |
|
726 self.assertEqual(b._read_count, 2) |
|
727 |
|
728 def test_read_lines(self): |
|
729 cctx = zstd.ZstdCompressor() |
|
730 source = b'\n'.join(('line %d' % i).encode('ascii') for i in range(1024)) |
|
731 |
|
732 frame = cctx.compress(source) |
|
733 |
|
734 dctx = zstd.ZstdDecompressor() |
|
735 reader = dctx.stream_reader(frame) |
|
736 tr = io.TextIOWrapper(reader, encoding='utf-8') |
|
737 |
|
738 lines = [] |
|
739 for line in tr: |
|
740 lines.append(line.encode('utf-8')) |
|
741 |
|
742 self.assertEqual(len(lines), 1024) |
|
743 self.assertEqual(b''.join(lines), source) |
|
744 |
|
745 reader = dctx.stream_reader(frame) |
|
746 tr = io.TextIOWrapper(reader, encoding='utf-8') |
|
747 |
|
748 lines = tr.readlines() |
|
749 self.assertEqual(len(lines), 1024) |
|
750 self.assertEqual(''.join(lines).encode('utf-8'), source) |
|
751 |
|
752 reader = dctx.stream_reader(frame) |
|
753 tr = io.TextIOWrapper(reader, encoding='utf-8') |
|
754 |
|
755 lines = [] |
|
756 while True: |
|
757 line = tr.readline() |
|
758 if not line: |
|
759 break |
|
760 |
|
761 lines.append(line.encode('utf-8')) |
|
762 |
|
763 self.assertEqual(len(lines), 1024) |
|
764 self.assertEqual(b''.join(lines), source) |
|
533 |
765 |
534 |
766 |
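The block of new tests above exercises partial reads, multi-frame sources, readinto()/readinto1(), readall(), read1(), and line-oriented access via io.TextIOWrapper. A compact usage sketch combining two of those additions (the read_across_frames flag and the TextIOWrapper pairing), on the assumption that the reader behaves as asserted in this hunk:

    import io
    import zstandard as zstd

    cctx = zstd.ZstdCompressor()
    buf = io.BytesIO()
    writer = cctx.stream_writer(buf)
    writer.write(b'hello\n')
    writer.flush(zstd.FLUSH_FRAME)   # end the first frame
    writer.write(b'world\n')
    writer.flush(zstd.FLUSH_FRAME)   # end the second frame

    dctx = zstd.ZstdDecompressor()

    # By default a read stops at a frame boundary; read_across_frames=True
    # lets a single read() continue into the following frame.
    reader = dctx.stream_reader(buf.getvalue(), read_across_frames=True)
    assert reader.read(64) == b'hello\nworld\n'

    # The reader is a plain binary stream, so io.TextIOWrapper adds lines.
    reader = dctx.stream_reader(buf.getvalue(), read_across_frames=True)
    for line in io.TextIOWrapper(reader, encoding='utf-8'):
        print(line.rstrip())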
535 @make_cffi |
767 @make_cffi |
536 class TestDecompressor_decompressobj(unittest.TestCase): |
768 class TestDecompressor_decompressobj(unittest.TestCase): |
537 def test_simple(self): |
769 def test_simple(self): |
538 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
770 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
539 |
771 |
540 dctx = zstd.ZstdDecompressor() |
772 dctx = zstd.ZstdDecompressor() |
541 dobj = dctx.decompressobj() |
773 dobj = dctx.decompressobj() |
542 self.assertEqual(dobj.decompress(data), b'foobar') |
774 self.assertEqual(dobj.decompress(data), b'foobar') |
775 self.assertIsNone(dobj.flush()) |
|
776 self.assertIsNone(dobj.flush(10)) |
|
777 self.assertIsNone(dobj.flush(length=100)) |
|
543 |
778 |
544 def test_input_types(self): |
779 def test_input_types(self): |
545 compressed = zstd.ZstdCompressor(level=1).compress(b'foo') |
780 compressed = zstd.ZstdCompressor(level=1).compress(b'foo') |
546 |
781 |
547 dctx = zstd.ZstdDecompressor() |
782 dctx = zstd.ZstdDecompressor() |
555 mutable_array, |
790 mutable_array, |
556 ] |
791 ] |
557 |
792 |
558 for source in sources: |
793 for source in sources: |
559 dobj = dctx.decompressobj() |
794 dobj = dctx.decompressobj() |
795 self.assertIsNone(dobj.flush()) |
|
796 self.assertIsNone(dobj.flush(10)) |
|
797 self.assertIsNone(dobj.flush(length=100)) |
|
560 self.assertEqual(dobj.decompress(source), b'foo') |
798 self.assertEqual(dobj.decompress(source), b'foo') |
799 self.assertIsNone(dobj.flush()) |
|
561 |
800 |
562 def test_reuse(self): |
801 def test_reuse(self): |
563 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
802 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
564 |
803 |
565 dctx = zstd.ZstdDecompressor() |
804 dctx = zstd.ZstdDecompressor() |
566 dobj = dctx.decompressobj() |
805 dobj = dctx.decompressobj() |
567 dobj.decompress(data) |
806 dobj.decompress(data) |
568 |
807 |
569 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): |
808 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): |
570 dobj.decompress(data) |
809 dobj.decompress(data) |
810 self.assertIsNone(dobj.flush()) |
|
571 |
811 |
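The added assertions around decompressobj() establish that flush() can be called before or after decompress(), accepts an optional length argument (positional or keyword), and returns None, apparently for zlib-style API compatibility rather than to produce output. A minimal sketch of that call pattern, using the same inputs as these tests:

    import zstandard as zstd

    data = zstd.ZstdCompressor(level=1).compress(b'foobar')

    dctx = zstd.ZstdDecompressor()
    dobj = dctx.decompressobj()

    # flush() is tolerated at any point and, per these assertions, is a
    # no-op that returns None.
    assert dobj.flush() is None
    assert dobj.decompress(data) == b'foobar'
    assert dobj.flush(10) is None
    assert dobj.flush(length=100) is None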
572 def test_bad_write_size(self): |
812 def test_bad_write_size(self): |
573 dctx = zstd.ZstdDecompressor() |
813 dctx = zstd.ZstdDecompressor() |
574 |
814 |
575 with self.assertRaisesRegexp(ValueError, 'write_size must be positive'): |
815 with self.assertRaisesRegexp(ValueError, 'write_size must be positive'): |
583 |
823 |
584 for i in range(128): |
824 for i in range(128): |
585 dobj = dctx.decompressobj(write_size=i + 1) |
825 dobj = dctx.decompressobj(write_size=i + 1) |
586 self.assertEqual(dobj.decompress(data), source) |
826 self.assertEqual(dobj.decompress(data), source) |
587 |
827 |
828 |
|
588 def decompress_via_writer(data): |
829 def decompress_via_writer(data): |
589 buffer = io.BytesIO() |
830 buffer = io.BytesIO() |
590 dctx = zstd.ZstdDecompressor() |
831 dctx = zstd.ZstdDecompressor() |
591 with dctx.stream_writer(buffer) as decompressor: |
832 decompressor = dctx.stream_writer(buffer) |
592 decompressor.write(data) |
833 decompressor.write(data) |
834 |
|
593 return buffer.getvalue() |
835 return buffer.getvalue() |
594 |
836 |
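decompress_via_writer drops its with statement because, as test_close below establishes, closing the writer (explicitly or by leaving the context manager) now also closes the wrapped stream, which would break the subsequent buffer.getvalue(). A sketch of both usage patterns under that behavior, assuming a plain io.BytesIO target:

    import io
    import zstandard as zstd

    frame = zstd.ZstdCompressor().compress(b'foo')
    dctx = zstd.ZstdDecompressor()

    # Without a context manager the target stays open and readable.
    buffer = io.BytesIO()
    decompressor = dctx.stream_writer(buffer)
    decompressor.write(frame)
    assert buffer.getvalue() == b'foo'

    # With a context manager, read the result before the block exits
    # (or pass in a close-tolerant buffer such as the suite's
    # NonClosingBytesIO), since exit closes writer and target alike.
    buffer = io.BytesIO()
    with dctx.stream_writer(buffer) as decompressor:
        decompressor.write(frame)
        result = buffer.getvalue()

    assert result == b'foo'
    assert buffer.closed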
595 |
837 |
596 @make_cffi |
838 @make_cffi |
597 class TestDecompressor_stream_writer(unittest.TestCase): |
839 class TestDecompressor_stream_writer(unittest.TestCase): |
840 def test_io_api(self): |
|
841 buffer = io.BytesIO() |
|
842 dctx = zstd.ZstdDecompressor() |
|
843 writer = dctx.stream_writer(buffer) |
|
844 |
|
845 self.assertFalse(writer.closed) |
|
846 self.assertFalse(writer.isatty()) |
|
847 self.assertFalse(writer.readable()) |
|
848 |
|
849 with self.assertRaises(io.UnsupportedOperation): |
|
850 writer.readline() |
|
851 |
|
852 with self.assertRaises(io.UnsupportedOperation): |
|
853 writer.readline(42) |
|
854 |
|
855 with self.assertRaises(io.UnsupportedOperation): |
|
856 writer.readline(size=42) |
|
857 |
|
858 with self.assertRaises(io.UnsupportedOperation): |
|
859 writer.readlines() |
|
860 |
|
861 with self.assertRaises(io.UnsupportedOperation): |
|
862 writer.readlines(42) |
|
863 |
|
864 with self.assertRaises(io.UnsupportedOperation): |
|
865 writer.readlines(hint=42) |
|
866 |
|
867 with self.assertRaises(io.UnsupportedOperation): |
|
868 writer.seek(0) |
|
869 |
|
870 with self.assertRaises(io.UnsupportedOperation): |
|
871 writer.seek(10, os.SEEK_SET) |
|
872 |
|
873 self.assertFalse(writer.seekable()) |
|
874 |
|
875 with self.assertRaises(io.UnsupportedOperation): |
|
876 writer.tell() |
|
877 |
|
878 with self.assertRaises(io.UnsupportedOperation): |
|
879 writer.truncate() |
|
880 |
|
881 with self.assertRaises(io.UnsupportedOperation): |
|
882 writer.truncate(42) |
|
883 |
|
884 with self.assertRaises(io.UnsupportedOperation): |
|
885 writer.truncate(size=42) |
|
886 |
|
887 self.assertTrue(writer.writable()) |
|
888 |
|
889 with self.assertRaises(io.UnsupportedOperation): |
|
890 writer.writelines([]) |
|
891 |
|
892 with self.assertRaises(io.UnsupportedOperation): |
|
893 writer.read() |
|
894 |
|
895 with self.assertRaises(io.UnsupportedOperation): |
|
896 writer.read(42) |
|
897 |
|
898 with self.assertRaises(io.UnsupportedOperation): |
|
899 writer.read(size=42) |
|
900 |
|
901 with self.assertRaises(io.UnsupportedOperation): |
|
902 writer.readall() |
|
903 |
|
904 with self.assertRaises(io.UnsupportedOperation): |
|
905 writer.readinto(None) |
|
906 |
|
907 with self.assertRaises(io.UnsupportedOperation): |
|
908 writer.fileno() |
|
909 |
|
910 def test_fileno_file(self): |
|
911 with tempfile.TemporaryFile('wb') as tf: |
|
912 dctx = zstd.ZstdDecompressor() |
|
913 writer = dctx.stream_writer(tf) |
|
914 |
|
915 self.assertEqual(writer.fileno(), tf.fileno()) |
|
916 |
|
917 def test_close(self): |
|
918 foo = zstd.ZstdCompressor().compress(b'foo') |
|
919 |
|
920 buffer = NonClosingBytesIO() |
|
921 dctx = zstd.ZstdDecompressor() |
|
922 writer = dctx.stream_writer(buffer) |
|
923 |
|
924 writer.write(foo) |
|
925 self.assertFalse(writer.closed) |
|
926 self.assertFalse(buffer.closed) |
|
927 writer.close() |
|
928 self.assertTrue(writer.closed) |
|
929 self.assertTrue(buffer.closed) |
|
930 |
|
931 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
|
932 writer.write(b'') |
|
933 |
|
934 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
|
935 writer.flush() |
|
936 |
|
937 with self.assertRaisesRegexp(ValueError, 'stream is closed'): |
|
938 with writer: |
|
939 pass |
|
940 |
|
941 self.assertEqual(buffer.getvalue(), b'foo') |
|
942 |
|
943 # Context manager exit should close stream. |
|
944 buffer = NonClosingBytesIO() |
|
945 writer = dctx.stream_writer(buffer) |
|
946 |
|
947 with writer: |
|
948 writer.write(foo) |
|
949 |
|
950 self.assertTrue(writer.closed) |
|
951 self.assertEqual(buffer.getvalue(), b'foo') |
|
952 |
|
953 def test_flush(self): |
|
954 buffer = OpCountingBytesIO() |
|
955 dctx = zstd.ZstdDecompressor() |
|
956 writer = dctx.stream_writer(buffer) |
|
957 |
|
958 writer.flush() |
|
959 self.assertEqual(buffer._flush_count, 1) |
|
960 writer.flush() |
|
961 self.assertEqual(buffer._flush_count, 2) |
|
962 |
|
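The new test_io_api, test_fileno_file, test_close and test_flush pin down that the decompressing writer behaves like a standard binary io object: write-only, not seekable, with fileno() and flush() delegated to the wrapped file. A short sketch of treating it that way, with the tempfile usage mirroring test_fileno_file:

    import tempfile
    import zstandard as zstd

    frame = zstd.ZstdCompressor().compress(b'foo')
    dctx = zstd.ZstdDecompressor()

    with tempfile.TemporaryFile('w+b') as tf:
        writer = dctx.stream_writer(tf)

        assert writer.writable()
        assert not writer.readable()
        assert not writer.seekable()
        assert writer.fileno() == tf.fileno()  # delegated to the real file

        writer.write(frame)   # decompressed bytes are written into tf
        writer.flush()        # forwarded to tf.flush()

        tf.seek(0)
        assert tf.read() == b'foo'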
598 def test_empty_roundtrip(self): |
963 def test_empty_roundtrip(self): |
599 cctx = zstd.ZstdCompressor() |
964 cctx = zstd.ZstdCompressor() |
600 empty = cctx.compress(b'') |
965 empty = cctx.compress(b'') |
601 self.assertEqual(decompress_via_writer(empty), b'') |
966 self.assertEqual(decompress_via_writer(empty), b'') |
602 |
967 |
614 ] |
979 ] |
615 |
980 |
616 dctx = zstd.ZstdDecompressor() |
981 dctx = zstd.ZstdDecompressor() |
617 for source in sources: |
982 for source in sources: |
618 buffer = io.BytesIO() |
983 buffer = io.BytesIO() |
984 |
|
985 decompressor = dctx.stream_writer(buffer) |
|
986 decompressor.write(source) |
|
987 self.assertEqual(buffer.getvalue(), b'foo') |
|
988 |
|
989 buffer = NonClosingBytesIO() |
|
990 |
|
619 with dctx.stream_writer(buffer) as decompressor: |
991 with dctx.stream_writer(buffer) as decompressor: |
620 decompressor.write(source) |
992 self.assertEqual(decompressor.write(source), 3) |
621 |
993 |
994 self.assertEqual(buffer.getvalue(), b'foo') |
|
995 |
|
996 buffer = io.BytesIO() |
|
997 writer = dctx.stream_writer(buffer, write_return_read=True) |
|
998 self.assertEqual(writer.write(source), len(source)) |
|
622 self.assertEqual(buffer.getvalue(), b'foo') |
999 self.assertEqual(buffer.getvalue(), b'foo') |
623 |
1000 |
624 def test_large_roundtrip(self): |
1001 def test_large_roundtrip(self): |
625 chunks = [] |
1002 chunks = [] |
626 for i in range(255): |
1003 for i in range(255): |
639 |
1016 |
640 orig = b''.join(chunks) |
1017 orig = b''.join(chunks) |
641 cctx = zstd.ZstdCompressor() |
1018 cctx = zstd.ZstdCompressor() |
642 compressed = cctx.compress(orig) |
1019 compressed = cctx.compress(orig) |
643 |
1020 |
644 buffer = io.BytesIO() |
1021 buffer = NonClosingBytesIO() |
645 dctx = zstd.ZstdDecompressor() |
1022 dctx = zstd.ZstdDecompressor() |
646 with dctx.stream_writer(buffer) as decompressor: |
1023 with dctx.stream_writer(buffer) as decompressor: |
647 pos = 0 |
1024 pos = 0 |
648 while pos < len(compressed): |
1025 while pos < len(compressed): |
649 pos2 = pos + 8192 |
1026 pos2 = pos + 8192 |
650 decompressor.write(compressed[pos:pos2]) |
1027 decompressor.write(compressed[pos:pos2]) |
651 pos += 8192 |
1028 pos += 8192 |
652 self.assertEqual(buffer.getvalue(), orig) |
1029 self.assertEqual(buffer.getvalue(), orig) |
653 |
1030 |
1031 # Again with write_return_read=True |
|
1032 buffer = io.BytesIO() |
|
1033 writer = dctx.stream_writer(buffer, write_return_read=True) |
|
1034 pos = 0 |
|
1035 while pos < len(compressed): |
|
1036 pos2 = pos + 8192 |
|
1037 chunk = compressed[pos:pos2] |
|
1038 self.assertEqual(writer.write(chunk), len(chunk)) |
|
1039 pos += 8192 |
|
1040 self.assertEqual(buffer.getvalue(), orig) |
|
1041 |
|
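The reworked roundtrips above also cover the new write_return_read flag: by default write() reports how many decompressed bytes were emitted to the destination (3 for b'foo' in test_simple), while write_return_read=True makes it report how many compressed bytes were consumed, which is the convenient value when looping over input chunks. A sketch of the difference:

    import io
    import zstandard as zstd

    frame = zstd.ZstdCompressor().compress(b'foo')
    dctx = zstd.ZstdDecompressor()

    # Default: the return value counts bytes written to the destination.
    dest = io.BytesIO()
    writer = dctx.stream_writer(dest)
    emitted = writer.write(frame)      # 3, i.e. len(b'foo')

    # write_return_read=True: the return value counts input consumed.
    dest = io.BytesIO()
    writer = dctx.stream_writer(dest, write_return_read=True)
    consumed = writer.write(frame)     # len(frame)

    assert dest.getvalue() == b'foo'
    print(emitted, consumed)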
654 def test_dictionary(self): |
1042 def test_dictionary(self): |
655 samples = [] |
1043 samples = [] |
656 for i in range(128): |
1044 for i in range(128): |
657 samples.append(b'foo' * 64) |
1045 samples.append(b'foo' * 64) |
658 samples.append(b'bar' * 64) |
1046 samples.append(b'bar' * 64) |
659 samples.append(b'foobar' * 64) |
1047 samples.append(b'foobar' * 64) |
660 |
1048 |
661 d = zstd.train_dictionary(8192, samples) |
1049 d = zstd.train_dictionary(8192, samples) |
662 |
1050 |
663 orig = b'foobar' * 16384 |
1051 orig = b'foobar' * 16384 |
664 buffer = io.BytesIO() |
1052 buffer = NonClosingBytesIO() |
665 cctx = zstd.ZstdCompressor(dict_data=d) |
1053 cctx = zstd.ZstdCompressor(dict_data=d) |
666 with cctx.stream_writer(buffer) as compressor: |
1054 with cctx.stream_writer(buffer) as compressor: |
667 self.assertEqual(compressor.write(orig), 0) |
1055 self.assertEqual(compressor.write(orig), 0) |
668 |
1056 |
669 compressed = buffer.getvalue() |
1057 compressed = buffer.getvalue() |
670 buffer = io.BytesIO() |
1058 buffer = io.BytesIO() |
671 |
1059 |
672 dctx = zstd.ZstdDecompressor(dict_data=d) |
1060 dctx = zstd.ZstdDecompressor(dict_data=d) |
1061 decompressor = dctx.stream_writer(buffer) |
|
1062 self.assertEqual(decompressor.write(compressed), len(orig)) |
|
1063 self.assertEqual(buffer.getvalue(), orig) |
|
1064 |
|
1065 buffer = NonClosingBytesIO() |
|
1066 |
|
673 with dctx.stream_writer(buffer) as decompressor: |
1067 with dctx.stream_writer(buffer) as decompressor: |
674 self.assertEqual(decompressor.write(compressed), len(orig)) |
1068 self.assertEqual(decompressor.write(compressed), len(orig)) |
675 |
1069 |
676 self.assertEqual(buffer.getvalue(), orig) |
1070 self.assertEqual(buffer.getvalue(), orig) |
677 |
1071 |
678 def test_memory_size(self): |
1072 def test_memory_size(self): |
679 dctx = zstd.ZstdDecompressor() |
1073 dctx = zstd.ZstdDecompressor() |
680 buffer = io.BytesIO() |
1074 buffer = io.BytesIO() |
1075 |
|
1076 decompressor = dctx.stream_writer(buffer) |
|
1077 size = decompressor.memory_size() |
|
1078 self.assertGreater(size, 100000) |
|
1079 |
|
681 with dctx.stream_writer(buffer) as decompressor: |
1080 with dctx.stream_writer(buffer) as decompressor: |
682 size = decompressor.memory_size() |
1081 size = decompressor.memory_size() |
683 |
1082 |
684 self.assertGreater(size, 100000) |
1083 self.assertGreater(size, 100000) |
685 |
1084 |
808 self.assertEqual(decompressed, source.getvalue()) |
1207 self.assertEqual(decompressed, source.getvalue()) |
809 |
1208 |
810 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') |
1209 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') |
811 def test_large_input(self): |
1210 def test_large_input(self): |
812 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) |
1211 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) |
813 compressed = io.BytesIO() |
1212 compressed = NonClosingBytesIO() |
814 input_size = 0 |
1213 input_size = 0 |
815 cctx = zstd.ZstdCompressor(level=1) |
1214 cctx = zstd.ZstdCompressor(level=1) |
816 with cctx.stream_writer(compressed) as compressor: |
1215 with cctx.stream_writer(compressed) as compressor: |
817 while True: |
1216 while True: |
818 compressor.write(random.choice(bytes)) |
1217 compressor.write(random.choice(bytes)) |
821 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
1220 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
822 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 |
1221 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 |
823 if have_compressed and have_raw: |
1222 if have_compressed and have_raw: |
824 break |
1223 break |
825 |
1224 |
826 compressed.seek(0) |
1225 compressed = io.BytesIO(compressed.getvalue()) |
827 self.assertGreater(len(compressed.getvalue()), |
1226 self.assertGreater(len(compressed.getvalue()), |
828 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) |
1227 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) |
829 |
1228 |
830 dctx = zstd.ZstdDecompressor() |
1229 dctx = zstd.ZstdDecompressor() |
831 it = dctx.read_to_iter(compressed) |
1230 it = dctx.read_to_iter(compressed) |
859 # Found this edge case via fuzzing. |
1258 # Found this edge case via fuzzing. |
860 cctx = zstd.ZstdCompressor(level=1) |
1259 cctx = zstd.ZstdCompressor(level=1) |
861 |
1260 |
862 source = io.BytesIO() |
1261 source = io.BytesIO() |
863 |
1262 |
864 compressed = io.BytesIO() |
1263 compressed = NonClosingBytesIO() |
865 with cctx.stream_writer(compressed) as compressor: |
1264 with cctx.stream_writer(compressed) as compressor: |
866 for i in range(256): |
1265 for i in range(256): |
867 chunk = b'\0' * 1024 |
1266 chunk = b'\0' * 1024 |
868 compressor.write(chunk) |
1267 compressor.write(chunk) |
869 source.write(chunk) |
1268 source.write(chunk) |
872 |
1271 |
873 simple = dctx.decompress(compressed.getvalue(), |
1272 simple = dctx.decompress(compressed.getvalue(), |
874 max_output_size=len(source.getvalue())) |
1273 max_output_size=len(source.getvalue())) |
875 self.assertEqual(simple, source.getvalue()) |
1274 self.assertEqual(simple, source.getvalue()) |
876 |
1275 |
877 compressed.seek(0) |
1276 compressed = io.BytesIO(compressed.getvalue()) |
878 streamed = b''.join(dctx.read_to_iter(compressed)) |
1277 streamed = b''.join(dctx.read_to_iter(compressed)) |
879 self.assertEqual(streamed, source.getvalue()) |
1278 self.assertEqual(streamed, source.getvalue()) |
880 |
1279 |
881 def test_read_write_size(self): |
1280 def test_read_write_size(self): |
882 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) |
1281 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) |
999 # TODO enable for CFFI |
1398 # TODO enable for CFFI |
1000 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): |
1399 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): |
1001 def test_invalid_inputs(self): |
1400 def test_invalid_inputs(self): |
1002 dctx = zstd.ZstdDecompressor() |
1401 dctx = zstd.ZstdDecompressor() |
1003 |
1402 |
1403 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1404 self.skipTest('multi_decompress_to_buffer not available') |
|
1405 |
|
1004 with self.assertRaises(TypeError): |
1406 with self.assertRaises(TypeError): |
1005 dctx.multi_decompress_to_buffer(True) |
1407 dctx.multi_decompress_to_buffer(True) |
1006 |
1408 |
1007 with self.assertRaises(TypeError): |
1409 with self.assertRaises(TypeError): |
1008 dctx.multi_decompress_to_buffer((1, 2)) |
1410 dctx.multi_decompress_to_buffer((1, 2)) |
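Each multi_decompress_to_buffer test now starts by probing the decompressor, since (per the TODO above) the CFFI backend does not provide that API yet. The same guard works outside the test suite for callers that want to degrade gracefully; a sketch with a per-frame fallback:

    import zstandard as zstd

    originals = [b'foo' * 4, b'bar' * 6]
    frames = [zstd.ZstdCompressor().compress(d) for d in originals]

    dctx = zstd.ZstdDecompressor()

    if hasattr(dctx, 'multi_decompress_to_buffer'):
        # C backend: decompress every frame in one call (optionally threaded).
        result = dctx.multi_decompress_to_buffer(frames)
        outputs = [segment.tobytes() for segment in result]
    else:
        # Backend without the API (e.g. CFFI): fall back to a simple loop.
        outputs = [dctx.decompress(f) for f in frames]

    assert outputs == originals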
1018 |
1420 |
1019 original = [b'foo' * 4, b'bar' * 6] |
1421 original = [b'foo' * 4, b'bar' * 6] |
1020 frames = [cctx.compress(d) for d in original] |
1422 frames = [cctx.compress(d) for d in original] |
1021 |
1423 |
1022 dctx = zstd.ZstdDecompressor() |
1424 dctx = zstd.ZstdDecompressor() |
1425 |
|
1426 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1427 self.skipTest('multi_decompress_to_buffer not available') |
|
1428 |
|
1023 result = dctx.multi_decompress_to_buffer(frames) |
1429 result = dctx.multi_decompress_to_buffer(frames) |
1024 |
1430 |
1025 self.assertEqual(len(result), len(frames)) |
1431 self.assertEqual(len(result), len(frames)) |
1026 self.assertEqual(result.size(), sum(map(len, original))) |
1432 self.assertEqual(result.size(), sum(map(len, original))) |
1027 |
1433 |
1039 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1445 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1040 frames = [cctx.compress(d) for d in original] |
1446 frames = [cctx.compress(d) for d in original] |
1041 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1447 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1042 |
1448 |
1043 dctx = zstd.ZstdDecompressor() |
1449 dctx = zstd.ZstdDecompressor() |
1450 |
|
1451 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1452 self.skipTest('multi_decompress_to_buffer not available') |
|
1453 |
|
1044 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) |
1454 result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) |
1045 |
1455 |
1046 self.assertEqual(len(result), len(frames)) |
1456 self.assertEqual(len(result), len(frames)) |
1047 self.assertEqual(result.size(), sum(map(len, original))) |
1457 self.assertEqual(result.size(), sum(map(len, original))) |
1048 |
1458 |
1054 |
1464 |
1055 original = [b'foo' * 4, b'bar' * 6] |
1465 original = [b'foo' * 4, b'bar' * 6] |
1056 frames = [cctx.compress(d) for d in original] |
1466 frames = [cctx.compress(d) for d in original] |
1057 |
1467 |
1058 dctx = zstd.ZstdDecompressor() |
1468 dctx = zstd.ZstdDecompressor() |
1469 |
|
1470 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1471 self.skipTest('multi_decompress_to_buffer not available') |
|
1059 |
1472 |
1060 segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1])) |
1473 segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1])) |
1061 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1474 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1062 |
1475 |
1063 result = dctx.multi_decompress_to_buffer(b) |
1476 result = dctx.multi_decompress_to_buffer(b) |
1072 cctx = zstd.ZstdCompressor(write_content_size=False) |
1485 cctx = zstd.ZstdCompressor(write_content_size=False) |
1073 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1486 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
1074 frames = [cctx.compress(d) for d in original] |
1487 frames = [cctx.compress(d) for d in original] |
1075 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1488 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
1076 |
1489 |
1490 dctx = zstd.ZstdDecompressor() |
|
1491 |
|
1492 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1493 self.skipTest('multi_decompress_to_buffer not available') |
|
1494 |
|
1077 segments = struct.pack('=QQQQQQ', 0, len(frames[0]), |
1495 segments = struct.pack('=QQQQQQ', 0, len(frames[0]), |
1078 len(frames[0]), len(frames[1]), |
1496 len(frames[0]), len(frames[1]), |
1079 len(frames[0]) + len(frames[1]), len(frames[2])) |
1497 len(frames[0]) + len(frames[1]), len(frames[2])) |
1080 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1498 b = zstd.BufferWithSegments(b''.join(frames), segments) |
1081 |
1499 |
1082 dctx = zstd.ZstdDecompressor() |
|
1083 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) |
1500 result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) |
1084 |
1501 |
1085 self.assertEqual(len(result), len(frames)) |
1502 self.assertEqual(len(result), len(frames)) |
1086 self.assertEqual(result.size(), sum(map(len, original))) |
1503 self.assertEqual(result.size(), sum(map(len, original))) |
1087 |
1504 |
1097 b'foo2' * 4, |
1514 b'foo2' * 4, |
1098 b'foo3' * 5, |
1515 b'foo3' * 5, |
1099 b'foo4' * 6, |
1516 b'foo4' * 6, |
1100 ] |
1517 ] |
1101 |
1518 |
1519 if not hasattr(cctx, 'multi_compress_to_buffer'): |
|
1520 self.skipTest('multi_compress_to_buffer not available') |
|
1521 |
|
1102 frames = cctx.multi_compress_to_buffer(original) |
1522 frames = cctx.multi_compress_to_buffer(original) |
1103 |
1523 |
1104 # Check round trip. |
1524 # Check round trip. |
1105 dctx = zstd.ZstdDecompressor() |
1525 dctx = zstd.ZstdDecompressor() |
1526 |
|
1106 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) |
1527 decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) |
1107 |
1528 |
1108 self.assertEqual(len(decompressed), len(original)) |
1529 self.assertEqual(len(decompressed), len(original)) |
1109 |
1530 |
1110 for i, data in enumerate(original): |
1531 for i, data in enumerate(original): |
1136 |
1557 |
1137 cctx = zstd.ZstdCompressor(dict_data=d, level=1) |
1558 cctx = zstd.ZstdCompressor(dict_data=d, level=1) |
1138 frames = [cctx.compress(s) for s in generate_samples()] |
1559 frames = [cctx.compress(s) for s in generate_samples()] |
1139 |
1560 |
1140 dctx = zstd.ZstdDecompressor(dict_data=d) |
1561 dctx = zstd.ZstdDecompressor(dict_data=d) |
1562 |
|
1563 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1564 self.skipTest('multi_decompress_to_buffer not available') |
|
1565 |
|
1141 result = dctx.multi_decompress_to_buffer(frames) |
1566 result = dctx.multi_decompress_to_buffer(frames) |
1567 |
|
1142 self.assertEqual([o.tobytes() for o in result], generate_samples()) |
1568 self.assertEqual([o.tobytes() for o in result], generate_samples()) |
1143 |
1569 |
1144 def test_multiple_threads(self): |
1570 def test_multiple_threads(self): |
1145 cctx = zstd.ZstdCompressor() |
1571 cctx = zstd.ZstdCompressor() |
1146 |
1572 |
1147 frames = [] |
1573 frames = [] |
1148 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) |
1574 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) |
1149 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) |
1575 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) |
1150 |
1576 |
1151 dctx = zstd.ZstdDecompressor() |
1577 dctx = zstd.ZstdDecompressor() |
1578 |
|
1579 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1580 self.skipTest('multi_decompress_to_buffer not available') |
|
1581 |
|
1152 result = dctx.multi_decompress_to_buffer(frames, threads=-1) |
1582 result = dctx.multi_decompress_to_buffer(frames, threads=-1) |
1153 |
1583 |
1154 self.assertEqual(len(result), len(frames)) |
1584 self.assertEqual(len(result), len(frames)) |
1155 self.assertEqual(result.size(), 2 * 64 * 256) |
1585 self.assertEqual(result.size(), 2 * 64 * 256) |
1156 self.assertEqual(result[0].tobytes(), b'x' * 64) |
1586 self.assertEqual(result[0].tobytes(), b'x' * 64) |
1161 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] |
1591 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] |
1162 |
1592 |
1163 frames[1] = frames[1][0:15] + b'extra' + frames[1][15:] |
1593 frames[1] = frames[1][0:15] + b'extra' + frames[1][15:] |
1164 |
1594 |
1165 dctx = zstd.ZstdDecompressor() |
1595 dctx = zstd.ZstdDecompressor() |
1596 |
|
1597 if not hasattr(dctx, 'multi_decompress_to_buffer'): |
|
1598 self.skipTest('multi_decompress_to_buffer not available') |
|
1166 |
1599 |
1167 with self.assertRaisesRegexp(zstd.ZstdError, |
1600 with self.assertRaisesRegexp(zstd.ZstdError, |
1168 'error decompressing item 1: (' |
1601 'error decompressing item 1: (' |
1169 'Corrupted block|' |
1602 'Corrupted block|' |
1170 'Destination buffer is too small)'): |
1603 'Destination buffer is too small)'): |