comparison contrib/python-zstandard/tests/test_decompressor.py @ 42070:675775c33ab6

zstandard: vendor python-zstandard 0.11

The upstream source distribution from PyPI was extracted. Unwanted files
were removed.

The clang-format ignore list was updated to reflect the new source of files.

The project contains a vendored copy of zstandard 1.3.8. The old version
was 1.3.6. This should result in some minor performance wins.

test-check-py3-compat.t was updated to reflect now-passing tests on
Python 3.8.

Some HTTP tests were updated to reflect new zstd compression output.

# no-check-commit because 3rd party code has different style guidelines

Differential Revision: https://phab.mercurial-scm.org/D6199
author Gregory Szorc <gregory.szorc@gmail.com>
date Thu, 04 Apr 2019 17:34:43 -0700
parents 73fef626dae3
children de7838053207
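The zstandard module imported by these tests reports at runtime both the bindings version and the bundled libzstd version. The snippet below is a minimal sketch, not taken from the changeset, for checking locally that the vendored copy really is python-zstandard 0.11 carrying zstd 1.3.8; it only relies on attributes the package exposes (__version__, ZSTD_VERSION, backend).

    # Quick check of which python-zstandard build is actually loaded.
    # Assumes the vendored `zstandard` package is importable.
    import zstandard as zstd

    print('python-zstandard:', zstd.__version__)               # expected 0.11.x
    print('libzstd:', '.'.join(map(str, zstd.ZSTD_VERSION)))   # expected 1.3.8
    print('backend:', zstd.backend)                            # 'cext' or 'cffi'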
comparison of 42069:668eff08387f and 42070:675775c33ab6

--- a/contrib/python-zstandard/tests/test_decompressor.py
+++ b/contrib/python-zstandard/tests/test_decompressor.py
@@ -1,17 +1,19 @@
 import io
 import os
 import random
 import struct
 import sys
+import tempfile
 import unittest
 
 import zstandard as zstd
 
 from .common import (
     generate_samples,
     make_cffi,
+    NonClosingBytesIO,
     OpCountingBytesIO,
 )
 
 
 if sys.version_info[0] >= 3:
@@ -217,11 +219,11 @@
         # If we write a content size, the decompressor engages single pass
         # mode and the window size doesn't come into play.
         cctx = zstd.ZstdCompressor(write_content_size=False)
         frame = cctx.compress(source)
 
-        dctx = zstd.ZstdDecompressor(max_window_size=1)
+        dctx = zstd.ZstdDecompressor(max_window_size=2**zstd.WINDOWLOG_MIN)
 
         with self.assertRaisesRegexp(
                 zstd.ZstdError, 'decompression error: Frame requires too much memory'):
             dctx.decompress(frame, max_output_size=len(source))
 
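This hunk updates the window-size test to use 2**zstd.WINDOWLOG_MIN, the smallest window the format permits, as the decompressor's max_window_size. A rough sketch of the behaviour under test, assuming the 0.11 API and default compression parameters; it is illustrative rather than code from the changeset:

    import zstandard as zstd

    data = b'x' * 1048576
    # Without a stored content size the frame header advertises the full
    # window needed to decode, which exceeds the tiny limit configured below.
    frame = zstd.ZstdCompressor(write_content_size=False).compress(data)

    dctx = zstd.ZstdDecompressor(max_window_size=2**zstd.WINDOWLOG_MIN)
    try:
        dctx.decompress(frame, max_output_size=len(data))
    except zstd.ZstdError as exc:
        print('frame rejected:', exc)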
@@ -300,23 +302,20 @@
 
     def test_not_implemented(self):
         dctx = zstd.ZstdDecompressor()
 
         with dctx.stream_reader(b'foo') as reader:
-            with self.assertRaises(NotImplementedError):
+            with self.assertRaises(io.UnsupportedOperation):
                 reader.readline()
 
-            with self.assertRaises(NotImplementedError):
+            with self.assertRaises(io.UnsupportedOperation):
                 reader.readlines()
 
-            with self.assertRaises(NotImplementedError):
-                reader.readall()
-
-            with self.assertRaises(NotImplementedError):
+            with self.assertRaises(io.UnsupportedOperation):
                 iter(reader)
 
-            with self.assertRaises(NotImplementedError):
+            with self.assertRaises(io.UnsupportedOperation):
                 next(reader)
 
             with self.assertRaises(io.UnsupportedOperation):
                 reader.write(b'foo')
 
@@ -345,19 +344,22 @@
         reader.close()
         self.assertTrue(reader.closed)
         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
             reader.read(1)
 
-    def test_bad_read_size(self):
-        dctx = zstd.ZstdDecompressor()
-
-        with dctx.stream_reader(b'foo') as reader:
-            with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'):
-                reader.read(-1)
-
-            with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'):
-                reader.read(0)
+    def test_read_sizes(self):
+        cctx = zstd.ZstdCompressor()
+        foo = cctx.compress(b'foo')
+
+        dctx = zstd.ZstdDecompressor()
+
+        with dctx.stream_reader(foo) as reader:
+            with self.assertRaisesRegexp(ValueError, 'cannot read negative amounts less than -1'):
+                reader.read(-2)
+
+            self.assertEqual(reader.read(0), b'')
+            self.assertEqual(reader.read(), b'foo')
 
     def test_read_buffer(self):
         cctx = zstd.ZstdCompressor()
 
         source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
@@ -522,26 +524,259 @@
         dctx = zstd.ZstdDecompressor()
 
         reader = dctx.stream_reader(source)
 
         with reader:
-            with self.assertRaises(TypeError):
-                reader.read()
+            reader.read(0)
 
         with reader:
             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                 reader.read(100)
+
+    def test_partial_read(self):
+        # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
+        buffer = io.BytesIO()
+        cctx = zstd.ZstdCompressor()
+        writer = cctx.stream_writer(buffer)
+        writer.write(bytearray(os.urandom(1000000)))
+        writer.flush(zstd.FLUSH_FRAME)
+        buffer.seek(0)
+
+        dctx = zstd.ZstdDecompressor()
+        reader = dctx.stream_reader(buffer)
+
+        while True:
+            chunk = reader.read(8192)
+            if not chunk:
+                break
+
+    def test_read_multiple_frames(self):
+        cctx = zstd.ZstdCompressor()
+        source = io.BytesIO()
+        writer = cctx.stream_writer(source)
+        writer.write(b'foo')
+        writer.flush(zstd.FLUSH_FRAME)
+        writer.write(b'bar')
+        writer.flush(zstd.FLUSH_FRAME)
+
+        dctx = zstd.ZstdDecompressor()
+
+        reader = dctx.stream_reader(source.getvalue())
+        self.assertEqual(reader.read(2), b'fo')
+        self.assertEqual(reader.read(2), b'o')
+        self.assertEqual(reader.read(2), b'ba')
+        self.assertEqual(reader.read(2), b'r')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source)
+        self.assertEqual(reader.read(2), b'fo')
+        self.assertEqual(reader.read(2), b'o')
+        self.assertEqual(reader.read(2), b'ba')
+        self.assertEqual(reader.read(2), b'r')
+
+        reader = dctx.stream_reader(source.getvalue())
+        self.assertEqual(reader.read(3), b'foo')
+        self.assertEqual(reader.read(3), b'bar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source)
+        self.assertEqual(reader.read(3), b'foo')
+        self.assertEqual(reader.read(3), b'bar')
+
+        reader = dctx.stream_reader(source.getvalue())
+        self.assertEqual(reader.read(4), b'foo')
+        self.assertEqual(reader.read(4), b'bar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source)
+        self.assertEqual(reader.read(4), b'foo')
+        self.assertEqual(reader.read(4), b'bar')
+
+        reader = dctx.stream_reader(source.getvalue())
+        self.assertEqual(reader.read(128), b'foo')
+        self.assertEqual(reader.read(128), b'bar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source)
+        self.assertEqual(reader.read(128), b'foo')
+        self.assertEqual(reader.read(128), b'bar')
+
+        # Now tests for reads spanning frames.
+        reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
+        self.assertEqual(reader.read(3), b'foo')
+        self.assertEqual(reader.read(3), b'bar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source, read_across_frames=True)
+        self.assertEqual(reader.read(3), b'foo')
+        self.assertEqual(reader.read(3), b'bar')
+
+        reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
+        self.assertEqual(reader.read(6), b'foobar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source, read_across_frames=True)
+        self.assertEqual(reader.read(6), b'foobar')
+
+        reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
+        self.assertEqual(reader.read(7), b'foobar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source, read_across_frames=True)
+        self.assertEqual(reader.read(7), b'foobar')
+
+        reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
+        self.assertEqual(reader.read(128), b'foobar')
+
+        source.seek(0)
+        reader = dctx.stream_reader(source, read_across_frames=True)
+        self.assertEqual(reader.read(128), b'foobar')
+
+    def test_readinto(self):
+        cctx = zstd.ZstdCompressor()
+        foo = cctx.compress(b'foo')
+
+        dctx = zstd.ZstdDecompressor()
+
+        # Attempting to readinto() a non-writable buffer fails.
+        # The exact exception varies based on the backend.
+        reader = dctx.stream_reader(foo)
+        with self.assertRaises(Exception):
+            reader.readinto(b'foobar')
+
+        # readinto() with sufficiently large destination.
+        b = bytearray(1024)
+        reader = dctx.stream_reader(foo)
+        self.assertEqual(reader.readinto(b), 3)
+        self.assertEqual(b[0:3], b'foo')
+        self.assertEqual(reader.readinto(b), 0)
+        self.assertEqual(b[0:3], b'foo')
+
+        # readinto() with small reads.
+        b = bytearray(1024)
+        reader = dctx.stream_reader(foo, read_size=1)
+        self.assertEqual(reader.readinto(b), 3)
+        self.assertEqual(b[0:3], b'foo')
+
+        # Too small destination buffer.
+        b = bytearray(2)
+        reader = dctx.stream_reader(foo)
+        self.assertEqual(reader.readinto(b), 2)
+        self.assertEqual(b[:], b'fo')
+
+    def test_readinto1(self):
+        cctx = zstd.ZstdCompressor()
+        foo = cctx.compress(b'foo')
+
+        dctx = zstd.ZstdDecompressor()
+
+        reader = dctx.stream_reader(foo)
+        with self.assertRaises(Exception):
+            reader.readinto1(b'foobar')
+
+        # Sufficiently large destination.
+        b = bytearray(1024)
+        reader = dctx.stream_reader(foo)
+        self.assertEqual(reader.readinto1(b), 3)
+        self.assertEqual(b[0:3], b'foo')
+        self.assertEqual(reader.readinto1(b), 0)
+        self.assertEqual(b[0:3], b'foo')
+
+        # readinto() with small reads.
+        b = bytearray(1024)
+        reader = dctx.stream_reader(foo, read_size=1)
+        self.assertEqual(reader.readinto1(b), 3)
+        self.assertEqual(b[0:3], b'foo')
+
+        # Too small destination buffer.
+        b = bytearray(2)
+        reader = dctx.stream_reader(foo)
+        self.assertEqual(reader.readinto1(b), 2)
+        self.assertEqual(b[:], b'fo')
+
+    def test_readall(self):
+        cctx = zstd.ZstdCompressor()
+        foo = cctx.compress(b'foo')
+
+        dctx = zstd.ZstdDecompressor()
+        reader = dctx.stream_reader(foo)
+
+        self.assertEqual(reader.readall(), b'foo')
+
+    def test_read1(self):
+        cctx = zstd.ZstdCompressor()
+        foo = cctx.compress(b'foo')
+
+        dctx = zstd.ZstdDecompressor()
+
+        b = OpCountingBytesIO(foo)
+        reader = dctx.stream_reader(b)
+
+        self.assertEqual(reader.read1(), b'foo')
+        self.assertEqual(b._read_count, 1)
+
+        b = OpCountingBytesIO(foo)
+        reader = dctx.stream_reader(b)
+
+        self.assertEqual(reader.read1(0), b'')
+        self.assertEqual(reader.read1(2), b'fo')
+        self.assertEqual(b._read_count, 1)
+        self.assertEqual(reader.read1(1), b'o')
+        self.assertEqual(b._read_count, 1)
+        self.assertEqual(reader.read1(1), b'')
+        self.assertEqual(b._read_count, 2)
+
+    def test_read_lines(self):
+        cctx = zstd.ZstdCompressor()
+        source = b'\n'.join(('line %d' % i).encode('ascii') for i in range(1024))
+
+        frame = cctx.compress(source)
+
+        dctx = zstd.ZstdDecompressor()
+        reader = dctx.stream_reader(frame)
+        tr = io.TextIOWrapper(reader, encoding='utf-8')
+
+        lines = []
+        for line in tr:
+            lines.append(line.encode('utf-8'))
+
+        self.assertEqual(len(lines), 1024)
+        self.assertEqual(b''.join(lines), source)
+
+        reader = dctx.stream_reader(frame)
+        tr = io.TextIOWrapper(reader, encoding='utf-8')
+
+        lines = tr.readlines()
+        self.assertEqual(len(lines), 1024)
+        self.assertEqual(''.join(lines).encode('utf-8'), source)
+
+        reader = dctx.stream_reader(frame)
+        tr = io.TextIOWrapper(reader, encoding='utf-8')
+
+        lines = []
+        while True:
+            line = tr.readline()
+            if not line:
+                break
+
+            lines.append(line.encode('utf-8'))
+
+        self.assertEqual(len(lines), 1024)
+        self.assertEqual(b''.join(lines), source)
 
 
 @make_cffi
 class TestDecompressor_decompressobj(unittest.TestCase):
     def test_simple(self):
         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
 
         dctx = zstd.ZstdDecompressor()
         dobj = dctx.decompressobj()
         self.assertEqual(dobj.decompress(data), b'foobar')
+        self.assertIsNone(dobj.flush())
+        self.assertIsNone(dobj.flush(10))
+        self.assertIsNone(dobj.flush(length=100))
 
     def test_input_types(self):
         compressed = zstd.ZstdCompressor(level=1).compress(b'foo')
 
         dctx = zstd.ZstdDecompressor()
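The block of new tests above gives stream_reader() a fuller io.RawIOBase-style surface in 0.11: readinto()/readinto1(), read1(), readall(), read(0)/read(-1), reads spanning frames via read_across_frames=True, and enough of the protocol for io.TextIOWrapper to sit on top. A compact usage sketch written against that API as the tests exercise it; the data and variable names here are illustrative, not from the changeset:

    import io
    import zstandard as zstd

    cctx = zstd.ZstdCompressor()
    raw = io.BytesIO()
    writer = cctx.stream_writer(raw)
    writer.write(b'line one\n')
    writer.flush(zstd.FLUSH_FRAME)      # end the first frame
    writer.write(b'line two\n')
    writer.flush(zstd.FLUSH_FRAME)      # end the second frame

    dctx = zstd.ZstdDecompressor()

    # read_across_frames=True lets a single reader continue past a frame
    # boundary instead of reporting EOF at the end of the first frame.
    reader = dctx.stream_reader(raw.getvalue(), read_across_frames=True)
    print(reader.readall())             # b'line one\nline two\n'

    # The reader behaves like a binary file object, so text-mode iteration
    # works by wrapping it in io.TextIOWrapper.
    reader = dctx.stream_reader(raw.getvalue(), read_across_frames=True)
    for line in io.TextIOWrapper(reader, encoding='utf-8'):
        print(line.rstrip())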
@@ -555,21 +790,26 @@
             mutable_array,
         ]
 
         for source in sources:
             dobj = dctx.decompressobj()
+            self.assertIsNone(dobj.flush())
+            self.assertIsNone(dobj.flush(10))
+            self.assertIsNone(dobj.flush(length=100))
             self.assertEqual(dobj.decompress(source), b'foo')
+            self.assertIsNone(dobj.flush())
 
     def test_reuse(self):
         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
 
         dctx = zstd.ZstdDecompressor()
         dobj = dctx.decompressobj()
         dobj.decompress(data)
 
         with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
             dobj.decompress(data)
+            self.assertIsNone(dobj.flush())
 
     def test_bad_write_size(self):
         dctx = zstd.ZstdDecompressor()
 
         with self.assertRaisesRegexp(ValueError, 'write_size must be positive'):
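The decompressobj() changes above add a flush() method that is accepted purely for zlib/bz2 API compatibility; the tests pin it to returning None whether or not a length is passed. A short sketch of how the object is used under those assumptions:

    import zstandard as zstd

    frame = zstd.ZstdCompressor(level=1).compress(b'foobar')

    dctx = zstd.ZstdDecompressor()
    dobj = dctx.decompressobj()

    data = dobj.decompress(frame)   # b'foobar'
    dobj.flush()                    # no-op, kept for zlib-style callers

    # A decompressobj is single use; create a fresh one per frame.
    assert dctx.decompressobj().decompress(frame) == data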
@@ -583,20 +823,145 @@
 
         for i in range(128):
             dobj = dctx.decompressobj(write_size=i + 1)
             self.assertEqual(dobj.decompress(data), source)
 
+
 def decompress_via_writer(data):
     buffer = io.BytesIO()
     dctx = zstd.ZstdDecompressor()
-    with dctx.stream_writer(buffer) as decompressor:
-        decompressor.write(data)
+    decompressor = dctx.stream_writer(buffer)
+    decompressor.write(data)
+
     return buffer.getvalue()
 
 
 @make_cffi
 class TestDecompressor_stream_writer(unittest.TestCase):
+    def test_io_api(self):
+        buffer = io.BytesIO()
+        dctx = zstd.ZstdDecompressor()
+        writer = dctx.stream_writer(buffer)
+
+        self.assertFalse(writer.closed)
+        self.assertFalse(writer.isatty())
+        self.assertFalse(writer.readable())
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readline()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readline(42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readline(size=42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readlines()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readlines(42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readlines(hint=42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.seek(0)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.seek(10, os.SEEK_SET)
+
+        self.assertFalse(writer.seekable())
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.tell()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.truncate()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.truncate(42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.truncate(size=42)
+
+        self.assertTrue(writer.writable())
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.writelines([])
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.read()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.read(42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.read(size=42)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readall()
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.readinto(None)
+
+        with self.assertRaises(io.UnsupportedOperation):
+            writer.fileno()
+
+    def test_fileno_file(self):
+        with tempfile.TemporaryFile('wb') as tf:
+            dctx = zstd.ZstdDecompressor()
+            writer = dctx.stream_writer(tf)
+
+            self.assertEqual(writer.fileno(), tf.fileno())
+
+    def test_close(self):
+        foo = zstd.ZstdCompressor().compress(b'foo')
+
+        buffer = NonClosingBytesIO()
+        dctx = zstd.ZstdDecompressor()
+        writer = dctx.stream_writer(buffer)
+
+        writer.write(foo)
+        self.assertFalse(writer.closed)
+        self.assertFalse(buffer.closed)
+        writer.close()
+        self.assertTrue(writer.closed)
+        self.assertTrue(buffer.closed)
+
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+            writer.write(b'')
+
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+            writer.flush()
+
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+            with writer:
+                pass
+
+        self.assertEqual(buffer.getvalue(), b'foo')
+
+        # Context manager exit should close stream.
+        buffer = NonClosingBytesIO()
+        writer = dctx.stream_writer(buffer)
+
+        with writer:
+            writer.write(foo)
+
+        self.assertTrue(writer.closed)
+        self.assertEqual(buffer.getvalue(), b'foo')
+
+    def test_flush(self):
+        buffer = OpCountingBytesIO()
+        dctx = zstd.ZstdDecompressor()
+        writer = dctx.stream_writer(buffer)
+
+        writer.flush()
+        self.assertEqual(buffer._flush_count, 1)
+        writer.flush()
+        self.assertEqual(buffer._flush_count, 2)
+
     def test_empty_roundtrip(self):
         cctx = zstd.ZstdCompressor()
         empty = cctx.compress(b'')
         self.assertEqual(decompress_via_writer(empty), b'')
 
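The new test_io_api, test_close and test_flush tests above pin down the 0.11 behaviour of the decompressing stream_writer(): it is a write-only file object, close() now also closes the wrapped stream (hence NonClosingBytesIO in the tests), flush() is forwarded, and write_return_read=True makes write() return how many compressed bytes were consumed. A small sketch under those assumptions, not code from the changeset:

    import io
    import zstandard as zstd

    frame = zstd.ZstdCompressor().compress(b'foo')

    dctx = zstd.ZstdDecompressor()
    dest = io.BytesIO()

    # write_return_read=True: write() reports input bytes consumed, which is
    # what io.BufferedIOBase-style callers expect from a writable stream.
    writer = dctx.stream_writer(dest, write_return_read=True)
    assert writer.write(frame) == len(frame)
    assert dest.getvalue() == b'foo'

    # Note: writer.close() (or leaving a `with` block) would close `dest` as
    # well, so read the buffer before closing or keep the writer open.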
@@ -614,13 +979,25 @@
         ]
 
         dctx = zstd.ZstdDecompressor()
         for source in sources:
             buffer = io.BytesIO()
+
+            decompressor = dctx.stream_writer(buffer)
+            decompressor.write(source)
+            self.assertEqual(buffer.getvalue(), b'foo')
+
+            buffer = NonClosingBytesIO()
+
             with dctx.stream_writer(buffer) as decompressor:
-                decompressor.write(source)
+                self.assertEqual(decompressor.write(source), 3)
 
+            self.assertEqual(buffer.getvalue(), b'foo')
+
+            buffer = io.BytesIO()
+            writer = dctx.stream_writer(buffer, write_return_read=True)
+            self.assertEqual(writer.write(source), len(source))
             self.assertEqual(buffer.getvalue(), b'foo')
 
     def test_large_roundtrip(self):
         chunks = []
         for i in range(255):
@@ -639,47 +1016,69 @@
 
         orig = b''.join(chunks)
         cctx = zstd.ZstdCompressor()
         compressed = cctx.compress(orig)
 
-        buffer = io.BytesIO()
+        buffer = NonClosingBytesIO()
         dctx = zstd.ZstdDecompressor()
         with dctx.stream_writer(buffer) as decompressor:
             pos = 0
             while pos < len(compressed):
                 pos2 = pos + 8192
                 decompressor.write(compressed[pos:pos2])
                 pos += 8192
         self.assertEqual(buffer.getvalue(), orig)
 
+        # Again with write_return_read=True
+        buffer = io.BytesIO()
+        writer = dctx.stream_writer(buffer, write_return_read=True)
+        pos = 0
+        while pos < len(compressed):
+            pos2 = pos + 8192
+            chunk = compressed[pos:pos2]
+            self.assertEqual(writer.write(chunk), len(chunk))
+            pos += 8192
+        self.assertEqual(buffer.getvalue(), orig)
+
     def test_dictionary(self):
         samples = []
         for i in range(128):
             samples.append(b'foo' * 64)
             samples.append(b'bar' * 64)
             samples.append(b'foobar' * 64)
 
         d = zstd.train_dictionary(8192, samples)
 
         orig = b'foobar' * 16384
-        buffer = io.BytesIO()
+        buffer = NonClosingBytesIO()
         cctx = zstd.ZstdCompressor(dict_data=d)
         with cctx.stream_writer(buffer) as compressor:
             self.assertEqual(compressor.write(orig), 0)
 
         compressed = buffer.getvalue()
         buffer = io.BytesIO()
 
         dctx = zstd.ZstdDecompressor(dict_data=d)
+        decompressor = dctx.stream_writer(buffer)
+        self.assertEqual(decompressor.write(compressed), len(orig))
+        self.assertEqual(buffer.getvalue(), orig)
+
+        buffer = NonClosingBytesIO()
+
         with dctx.stream_writer(buffer) as decompressor:
             self.assertEqual(decompressor.write(compressed), len(orig))
 
         self.assertEqual(buffer.getvalue(), orig)
 
     def test_memory_size(self):
         dctx = zstd.ZstdDecompressor()
         buffer = io.BytesIO()
+
+        decompressor = dctx.stream_writer(buffer)
+        size = decompressor.memory_size()
+        self.assertGreater(size, 100000)
+
         with dctx.stream_writer(buffer) as decompressor:
             size = decompressor.memory_size()
 
         self.assertGreater(size, 100000)
 
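For the dictionary round trip exercised above, the shape of the API as these tests use it looks roughly like the following sketch; the training samples are artificial and the sizes are illustrative:

    import io
    import zstandard as zstd

    # Train a dictionary from repetitive samples, as the test does.
    samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
    d = zstd.train_dictionary(8192, samples)

    cctx = zstd.ZstdCompressor(dict_data=d)
    compressed = cctx.compress(b'foobar' * 16384)

    # The same dictionary must be supplied for decompression.
    dctx = zstd.ZstdDecompressor(dict_data=d)
    dest = io.BytesIO()
    decompressor = dctx.stream_writer(dest)
    decompressor.write(compressed)
    assert dest.getvalue() == b'foobar' * 16384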
@@ -808,11 +1207,11 @@
         self.assertEqual(decompressed, source.getvalue())
 
     @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
     def test_large_input(self):
         bytes = list(struct.Struct('>B').pack(i) for i in range(256))
-        compressed = io.BytesIO()
+        compressed = NonClosingBytesIO()
         input_size = 0
         cctx = zstd.ZstdCompressor(level=1)
         with cctx.stream_writer(compressed) as compressor:
             while True:
                 compressor.write(random.choice(bytes))
@@ -821,11 +1220,11 @@
                 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
                 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
                 if have_compressed and have_raw:
                     break
 
-        compressed.seek(0)
+        compressed = io.BytesIO(compressed.getvalue())
         self.assertGreater(len(compressed.getvalue()),
                            zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
 
         dctx = zstd.ZstdDecompressor()
         it = dctx.read_to_iter(compressed)
@@ -859,11 +1258,11 @@
         # Found this edge case via fuzzing.
         cctx = zstd.ZstdCompressor(level=1)
 
         source = io.BytesIO()
 
-        compressed = io.BytesIO()
+        compressed = NonClosingBytesIO()
         with cctx.stream_writer(compressed) as compressor:
             for i in range(256):
                 chunk = b'\0' * 1024
                 compressor.write(chunk)
                 source.write(chunk)
@@ -872,11 +1271,11 @@
 
         simple = dctx.decompress(compressed.getvalue(),
                                  max_output_size=len(source.getvalue()))
         self.assertEqual(simple, source.getvalue())
 
-        compressed.seek(0)
+        compressed = io.BytesIO(compressed.getvalue())
         streamed = b''.join(dctx.read_to_iter(compressed))
         self.assertEqual(streamed, source.getvalue())
 
     def test_read_write_size(self):
         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
@@ -999,10 +1398,13 @@
 # TODO enable for CFFI
 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase):
     def test_invalid_inputs(self):
         dctx = zstd.ZstdDecompressor()
 
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         with self.assertRaises(TypeError):
             dctx.multi_decompress_to_buffer(True)
 
         with self.assertRaises(TypeError):
             dctx.multi_decompress_to_buffer((1, 2))
@@ -1018,10 +1420,14 @@
 
         original = [b'foo' * 4, b'bar' * 6]
         frames = [cctx.compress(d) for d in original]
 
         dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         result = dctx.multi_decompress_to_buffer(frames)
 
         self.assertEqual(len(result), len(frames))
         self.assertEqual(result.size(), sum(map(len, original)))
 
@@ -1039,10 +1445,14 @@
         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
         frames = [cctx.compress(d) for d in original]
         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
 
         dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
 
         self.assertEqual(len(result), len(frames))
         self.assertEqual(result.size(), sum(map(len, original)))
 
@@ -1054,10 +1464,13 @@
 
         original = [b'foo' * 4, b'bar' * 6]
         frames = [cctx.compress(d) for d in original]
 
         dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
 
         segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1]))
         b = zstd.BufferWithSegments(b''.join(frames), segments)
 
         result = dctx.multi_decompress_to_buffer(b)
@@ -1072,16 +1485,20 @@
         cctx = zstd.ZstdCompressor(write_content_size=False)
         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
         frames = [cctx.compress(d) for d in original]
         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
 
+        dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         segments = struct.pack('=QQQQQQ', 0, len(frames[0]),
                                len(frames[0]), len(frames[1]),
                                len(frames[0]) + len(frames[1]), len(frames[2]))
         b = zstd.BufferWithSegments(b''.join(frames), segments)
 
-        dctx = zstd.ZstdDecompressor()
         result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
 
         self.assertEqual(len(result), len(frames))
         self.assertEqual(result.size(), sum(map(len, original)))
 
@@ -1097,14 +1514,18 @@
             b'foo2' * 4,
             b'foo3' * 5,
             b'foo4' * 6,
         ]
 
+        if not hasattr(cctx, 'multi_compress_to_buffer'):
+            self.skipTest('multi_compress_to_buffer not available')
+
         frames = cctx.multi_compress_to_buffer(original)
 
         # Check round trip.
         dctx = zstd.ZstdDecompressor()
+
         decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
 
         self.assertEqual(len(decompressed), len(original))
 
         for i, data in enumerate(original):
@@ -1136,21 +1557,30 @@
 
         cctx = zstd.ZstdCompressor(dict_data=d, level=1)
         frames = [cctx.compress(s) for s in generate_samples()]
 
         dctx = zstd.ZstdDecompressor(dict_data=d)
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         result = dctx.multi_decompress_to_buffer(frames)
+
         self.assertEqual([o.tobytes() for o in result], generate_samples())
 
     def test_multiple_threads(self):
         cctx = zstd.ZstdCompressor()
 
         frames = []
         frames.extend(cctx.compress(b'x' * 64) for i in range(256))
         frames.extend(cctx.compress(b'y' * 64) for i in range(256))
 
         dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
+
         result = dctx.multi_decompress_to_buffer(frames, threads=-1)
 
         self.assertEqual(len(result), len(frames))
         self.assertEqual(result.size(), 2 * 64 * 256)
         self.assertEqual(result[0].tobytes(), b'x' * 64)
@@ -1161,10 +1591,13 @@
         frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)]
 
         frames[1] = frames[1][0:15] + b'extra' + frames[1][15:]
 
         dctx = zstd.ZstdDecompressor()
+
+        if not hasattr(dctx, 'multi_decompress_to_buffer'):
+            self.skipTest('multi_decompress_to_buffer not available')
 
         with self.assertRaisesRegexp(zstd.ZstdError,
                                      'error decompressing item 1: ('
                                      'Corrupted block|'
                                      'Destination buffer is too small)'):
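The hasattr() guards added throughout this class exist because multi_decompress_to_buffer() is only implemented by the C-extension backend; the cffi backend does not provide it, so these tests now skip there instead of erroring. Callers can use the same feature test, roughly as in this sketch:

    import zstandard as zstd

    cctx = zstd.ZstdCompressor()
    frames = [cctx.compress(b'x' * 64), cctx.compress(b'y' * 64)]

    dctx = zstd.ZstdDecompressor()

    if hasattr(dctx, 'multi_decompress_to_buffer'):
        # C extension only: decompress many frames in one call, optionally
        # spreading the work across threads.
        result = dctx.multi_decompress_to_buffer(frames, threads=-1)
        decoded = [segment.tobytes() for segment in result]
    else:
        # cffi backend: fall back to decompressing one frame at a time.
        decoded = [dctx.decompress(frame) for frame in frames]

    assert decoded == [b'x' * 64, b'y' * 64]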