contrib/python-zstandard/tests/test_decompressor.py
changeset 42070 675775c33ab6
parent 40121 73fef626dae3
child 43994 de7838053207
equal deleted inserted replaced
42069:668eff08387f 42070:675775c33ab6
     1 import io
     1 import io
     2 import os
     2 import os
     3 import random
     3 import random
     4 import struct
     4 import struct
     5 import sys
     5 import sys
       
     6 import tempfile
     6 import unittest
     7 import unittest
     7 
     8 
     8 import zstandard as zstd
     9 import zstandard as zstd
     9 
    10 
    10 from .common import (
    11 from .common import (
    11     generate_samples,
    12     generate_samples,
    12     make_cffi,
    13     make_cffi,
       
    14     NonClosingBytesIO,
    13     OpCountingBytesIO,
    15     OpCountingBytesIO,
    14 )
    16 )
    15 
    17 
    16 
    18 
    17 if sys.version_info[0] >= 3:
    19 if sys.version_info[0] >= 3:
   217         # If we write a content size, the decompressor engages single pass
   219         # If we write a content size, the decompressor engages single pass
   218         # mode and the window size doesn't come into play.
   220         # mode and the window size doesn't come into play.
   219         cctx = zstd.ZstdCompressor(write_content_size=False)
   221         cctx = zstd.ZstdCompressor(write_content_size=False)
   220         frame = cctx.compress(source)
   222         frame = cctx.compress(source)
   221 
   223 
   222         dctx = zstd.ZstdDecompressor(max_window_size=1)
   224         dctx = zstd.ZstdDecompressor(max_window_size=2**zstd.WINDOWLOG_MIN)
   223 
   225 
   224         with self.assertRaisesRegexp(
   226         with self.assertRaisesRegexp(
   225             zstd.ZstdError, 'decompression error: Frame requires too much memory'):
   227             zstd.ZstdError, 'decompression error: Frame requires too much memory'):
   226             dctx.decompress(frame, max_output_size=len(source))
   228             dctx.decompress(frame, max_output_size=len(source))
   227 
   229 
   300 
   302 
   301     def test_not_implemented(self):
   303     def test_not_implemented(self):
   302         dctx = zstd.ZstdDecompressor()
   304         dctx = zstd.ZstdDecompressor()
   303 
   305 
   304         with dctx.stream_reader(b'foo') as reader:
   306         with dctx.stream_reader(b'foo') as reader:
   305             with self.assertRaises(NotImplementedError):
   307             with self.assertRaises(io.UnsupportedOperation):
   306                 reader.readline()
   308                 reader.readline()
   307 
   309 
   308             with self.assertRaises(NotImplementedError):
   310             with self.assertRaises(io.UnsupportedOperation):
   309                 reader.readlines()
   311                 reader.readlines()
   310 
   312 
   311             with self.assertRaises(NotImplementedError):
   313             with self.assertRaises(io.UnsupportedOperation):
   312                 reader.readall()
       
   313 
       
   314             with self.assertRaises(NotImplementedError):
       
   315                 iter(reader)
   314                 iter(reader)
   316 
   315 
   317             with self.assertRaises(NotImplementedError):
   316             with self.assertRaises(io.UnsupportedOperation):
   318                 next(reader)
   317                 next(reader)
   319 
   318 
   320             with self.assertRaises(io.UnsupportedOperation):
   319             with self.assertRaises(io.UnsupportedOperation):
   321                 reader.write(b'foo')
   320                 reader.write(b'foo')
   322 
   321 
   345             reader.close()
   344             reader.close()
   346             self.assertTrue(reader.closed)
   345             self.assertTrue(reader.closed)
   347             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   346             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   348                 reader.read(1)
   347                 reader.read(1)
   349 
   348 
   350     def test_bad_read_size(self):
   349     def test_read_sizes(self):
   351         dctx = zstd.ZstdDecompressor()
   350         cctx = zstd.ZstdCompressor()
   352 
   351         foo = cctx.compress(b'foo')
   353         with dctx.stream_reader(b'foo') as reader:
   352 
   354             with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'):
   353         dctx = zstd.ZstdDecompressor()
   355                 reader.read(-1)
   354 
   356 
   355         with dctx.stream_reader(foo) as reader:
   357             with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'):
   356             with self.assertRaisesRegexp(ValueError, 'cannot read negative amounts less than -1'):
   358                 reader.read(0)
   357                 reader.read(-2)
       
   358 
       
   359             self.assertEqual(reader.read(0), b'')
       
   360             self.assertEqual(reader.read(), b'foo')
   359 
   361 
   360     def test_read_buffer(self):
   362     def test_read_buffer(self):
   361         cctx = zstd.ZstdCompressor()
   363         cctx = zstd.ZstdCompressor()
   362 
   364 
   363         source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
   365         source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
   522         dctx = zstd.ZstdDecompressor()
   524         dctx = zstd.ZstdDecompressor()
   523 
   525 
   524         reader = dctx.stream_reader(source)
   526         reader = dctx.stream_reader(source)
   525 
   527 
   526         with reader:
   528         with reader:
   527             with self.assertRaises(TypeError):
   529             reader.read(0)
   528                 reader.read()
       
   529 
   530 
   530         with reader:
   531         with reader:
   531             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   532             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   532                 reader.read(100)
   533                 reader.read(100)
       
   534 
       
   535     def test_partial_read(self):
       
   536         # Inspired by https://github.com/indygreg/python-zstandard/issues/71.
       
   537         buffer = io.BytesIO()
       
   538         cctx = zstd.ZstdCompressor()
       
   539         writer = cctx.stream_writer(buffer)
       
   540         writer.write(bytearray(os.urandom(1000000)))
       
   541         writer.flush(zstd.FLUSH_FRAME)
       
   542         buffer.seek(0)
       
   543 
       
   544         dctx = zstd.ZstdDecompressor()
       
   545         reader = dctx.stream_reader(buffer)
       
   546 
       
   547         while True:
       
   548             chunk = reader.read(8192)
       
   549             if not chunk:
       
   550                 break
       
   551 
       
   552     def test_read_multiple_frames(self):
       
   553         cctx = zstd.ZstdCompressor()
       
   554         source = io.BytesIO()
       
   555         writer = cctx.stream_writer(source)
       
   556         writer.write(b'foo')
       
   557         writer.flush(zstd.FLUSH_FRAME)
       
   558         writer.write(b'bar')
       
   559         writer.flush(zstd.FLUSH_FRAME)
       
   560 
       
   561         dctx = zstd.ZstdDecompressor()
       
   562 
       
   563         reader = dctx.stream_reader(source.getvalue())
       
   564         self.assertEqual(reader.read(2), b'fo')
       
   565         self.assertEqual(reader.read(2), b'o')
       
   566         self.assertEqual(reader.read(2), b'ba')
       
   567         self.assertEqual(reader.read(2), b'r')
       
   568 
       
   569         source.seek(0)
       
   570         reader = dctx.stream_reader(source)
       
   571         self.assertEqual(reader.read(2), b'fo')
       
   572         self.assertEqual(reader.read(2), b'o')
       
   573         self.assertEqual(reader.read(2), b'ba')
       
   574         self.assertEqual(reader.read(2), b'r')
       
   575 
       
   576         reader = dctx.stream_reader(source.getvalue())
       
   577         self.assertEqual(reader.read(3), b'foo')
       
   578         self.assertEqual(reader.read(3), b'bar')
       
   579 
       
   580         source.seek(0)
       
   581         reader = dctx.stream_reader(source)
       
   582         self.assertEqual(reader.read(3), b'foo')
       
   583         self.assertEqual(reader.read(3), b'bar')
       
   584 
       
   585         reader = dctx.stream_reader(source.getvalue())
       
   586         self.assertEqual(reader.read(4), b'foo')
       
   587         self.assertEqual(reader.read(4), b'bar')
       
   588 
       
   589         source.seek(0)
       
   590         reader = dctx.stream_reader(source)
       
   591         self.assertEqual(reader.read(4), b'foo')
       
   592         self.assertEqual(reader.read(4), b'bar')
       
   593 
       
   594         reader = dctx.stream_reader(source.getvalue())
       
   595         self.assertEqual(reader.read(128), b'foo')
       
   596         self.assertEqual(reader.read(128), b'bar')
       
   597 
       
   598         source.seek(0)
       
   599         reader = dctx.stream_reader(source)
       
   600         self.assertEqual(reader.read(128), b'foo')
       
   601         self.assertEqual(reader.read(128), b'bar')
       
   602 
       
   603         # Now tests for reads spanning frames.
       
   604         reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
       
   605         self.assertEqual(reader.read(3), b'foo')
       
   606         self.assertEqual(reader.read(3), b'bar')
       
   607 
       
   608         source.seek(0)
       
   609         reader = dctx.stream_reader(source, read_across_frames=True)
       
   610         self.assertEqual(reader.read(3), b'foo')
       
   611         self.assertEqual(reader.read(3), b'bar')
       
   612 
       
   613         reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
       
   614         self.assertEqual(reader.read(6), b'foobar')
       
   615 
       
   616         source.seek(0)
       
   617         reader = dctx.stream_reader(source, read_across_frames=True)
       
   618         self.assertEqual(reader.read(6), b'foobar')
       
   619 
       
   620         reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
       
   621         self.assertEqual(reader.read(7), b'foobar')
       
   622 
       
   623         source.seek(0)
       
   624         reader = dctx.stream_reader(source, read_across_frames=True)
       
   625         self.assertEqual(reader.read(7), b'foobar')
       
   626 
       
   627         reader = dctx.stream_reader(source.getvalue(), read_across_frames=True)
       
   628         self.assertEqual(reader.read(128), b'foobar')
       
   629 
       
   630         source.seek(0)
       
   631         reader = dctx.stream_reader(source, read_across_frames=True)
       
   632         self.assertEqual(reader.read(128), b'foobar')
       
   633 
       
   634     def test_readinto(self):
       
   635         cctx = zstd.ZstdCompressor()
       
   636         foo = cctx.compress(b'foo')
       
   637 
       
   638         dctx = zstd.ZstdDecompressor()
       
   639 
       
   640         # Attempting to readinto() a non-writable buffer fails.
       
   641         # The exact exception varies based on the backend.
       
   642         reader = dctx.stream_reader(foo)
       
   643         with self.assertRaises(Exception):
       
   644             reader.readinto(b'foobar')
       
   645 
       
   646         # readinto() with sufficiently large destination.
       
   647         b = bytearray(1024)
       
   648         reader = dctx.stream_reader(foo)
       
   649         self.assertEqual(reader.readinto(b), 3)
       
   650         self.assertEqual(b[0:3], b'foo')
       
   651         self.assertEqual(reader.readinto(b), 0)
       
   652         self.assertEqual(b[0:3], b'foo')
       
   653 
       
   654         # readinto() with small reads.
       
   655         b = bytearray(1024)
       
   656         reader = dctx.stream_reader(foo, read_size=1)
       
   657         self.assertEqual(reader.readinto(b), 3)
       
   658         self.assertEqual(b[0:3], b'foo')
       
   659 
       
   660         # Too small destination buffer.
       
   661         b = bytearray(2)
       
   662         reader = dctx.stream_reader(foo)
       
   663         self.assertEqual(reader.readinto(b), 2)
       
   664         self.assertEqual(b[:], b'fo')
       
   665 
       
   666     def test_readinto1(self):
       
   667         cctx = zstd.ZstdCompressor()
       
   668         foo = cctx.compress(b'foo')
       
   669 
       
   670         dctx = zstd.ZstdDecompressor()
       
   671 
       
   672         reader = dctx.stream_reader(foo)
       
   673         with self.assertRaises(Exception):
       
   674             reader.readinto1(b'foobar')
       
   675 
       
   676         # Sufficiently large destination.
       
   677         b = bytearray(1024)
       
   678         reader = dctx.stream_reader(foo)
       
   679         self.assertEqual(reader.readinto1(b), 3)
       
   680         self.assertEqual(b[0:3], b'foo')
       
   681         self.assertEqual(reader.readinto1(b), 0)
       
   682         self.assertEqual(b[0:3], b'foo')
       
   683 
       
   684         # readinto() with small reads.
       
   685         b = bytearray(1024)
       
   686         reader = dctx.stream_reader(foo, read_size=1)
       
   687         self.assertEqual(reader.readinto1(b), 3)
       
   688         self.assertEqual(b[0:3], b'foo')
       
   689 
       
   690         # Too small destination buffer.
       
   691         b = bytearray(2)
       
   692         reader = dctx.stream_reader(foo)
       
   693         self.assertEqual(reader.readinto1(b), 2)
       
   694         self.assertEqual(b[:], b'fo')
       
   695 
       
   696     def test_readall(self):
       
   697         cctx = zstd.ZstdCompressor()
       
   698         foo = cctx.compress(b'foo')
       
   699 
       
   700         dctx = zstd.ZstdDecompressor()
       
   701         reader = dctx.stream_reader(foo)
       
   702 
       
   703         self.assertEqual(reader.readall(), b'foo')
       
   704 
       
   705     def test_read1(self):
       
   706         cctx = zstd.ZstdCompressor()
       
   707         foo = cctx.compress(b'foo')
       
   708 
       
   709         dctx = zstd.ZstdDecompressor()
       
   710 
       
   711         b = OpCountingBytesIO(foo)
       
   712         reader = dctx.stream_reader(b)
       
   713 
       
   714         self.assertEqual(reader.read1(), b'foo')
       
   715         self.assertEqual(b._read_count, 1)
       
   716 
       
   717         b = OpCountingBytesIO(foo)
       
   718         reader = dctx.stream_reader(b)
       
   719 
       
   720         self.assertEqual(reader.read1(0), b'')
       
   721         self.assertEqual(reader.read1(2), b'fo')
       
   722         self.assertEqual(b._read_count, 1)
       
   723         self.assertEqual(reader.read1(1), b'o')
       
   724         self.assertEqual(b._read_count, 1)
       
   725         self.assertEqual(reader.read1(1), b'')
       
   726         self.assertEqual(b._read_count, 2)
       
   727 
       
   728     def test_read_lines(self):
       
   729         cctx = zstd.ZstdCompressor()
       
   730         source = b'\n'.join(('line %d' % i).encode('ascii') for i in range(1024))
       
   731 
       
   732         frame = cctx.compress(source)
       
   733 
       
   734         dctx = zstd.ZstdDecompressor()
       
   735         reader = dctx.stream_reader(frame)
       
   736         tr = io.TextIOWrapper(reader, encoding='utf-8')
       
   737 
       
   738         lines = []
       
   739         for line in tr:
       
   740             lines.append(line.encode('utf-8'))
       
   741 
       
   742         self.assertEqual(len(lines), 1024)
       
   743         self.assertEqual(b''.join(lines), source)
       
   744 
       
   745         reader = dctx.stream_reader(frame)
       
   746         tr = io.TextIOWrapper(reader, encoding='utf-8')
       
   747 
       
   748         lines = tr.readlines()
       
   749         self.assertEqual(len(lines), 1024)
       
   750         self.assertEqual(''.join(lines).encode('utf-8'), source)
       
   751 
       
   752         reader = dctx.stream_reader(frame)
       
   753         tr = io.TextIOWrapper(reader, encoding='utf-8')
       
   754 
       
   755         lines = []
       
   756         while True:
       
   757             line = tr.readline()
       
   758             if not line:
       
   759                 break
       
   760 
       
   761             lines.append(line.encode('utf-8'))
       
   762 
       
   763         self.assertEqual(len(lines), 1024)
       
   764         self.assertEqual(b''.join(lines), source)
   533 
   765 
   534 
   766 
   535 @make_cffi
   767 @make_cffi
   536 class TestDecompressor_decompressobj(unittest.TestCase):
   768 class TestDecompressor_decompressobj(unittest.TestCase):
   537     def test_simple(self):
   769     def test_simple(self):
   538         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   770         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   539 
   771 
   540         dctx = zstd.ZstdDecompressor()
   772         dctx = zstd.ZstdDecompressor()
   541         dobj = dctx.decompressobj()
   773         dobj = dctx.decompressobj()
   542         self.assertEqual(dobj.decompress(data), b'foobar')
   774         self.assertEqual(dobj.decompress(data), b'foobar')
       
   775         self.assertIsNone(dobj.flush())
       
   776         self.assertIsNone(dobj.flush(10))
       
   777         self.assertIsNone(dobj.flush(length=100))
   543 
   778 
   544     def test_input_types(self):
   779     def test_input_types(self):
   545         compressed = zstd.ZstdCompressor(level=1).compress(b'foo')
   780         compressed = zstd.ZstdCompressor(level=1).compress(b'foo')
   546 
   781 
   547         dctx = zstd.ZstdDecompressor()
   782         dctx = zstd.ZstdDecompressor()
   555             mutable_array,
   790             mutable_array,
   556         ]
   791         ]
   557 
   792 
   558         for source in sources:
   793         for source in sources:
   559             dobj = dctx.decompressobj()
   794             dobj = dctx.decompressobj()
       
   795             self.assertIsNone(dobj.flush())
       
   796             self.assertIsNone(dobj.flush(10))
       
   797             self.assertIsNone(dobj.flush(length=100))
   560             self.assertEqual(dobj.decompress(source), b'foo')
   798             self.assertEqual(dobj.decompress(source), b'foo')
       
   799             self.assertIsNone(dobj.flush())
   561 
   800 
   562     def test_reuse(self):
   801     def test_reuse(self):
   563         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   802         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   564 
   803 
   565         dctx = zstd.ZstdDecompressor()
   804         dctx = zstd.ZstdDecompressor()
   566         dobj = dctx.decompressobj()
   805         dobj = dctx.decompressobj()
   567         dobj.decompress(data)
   806         dobj.decompress(data)
   568 
   807 
   569         with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
   808         with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
   570             dobj.decompress(data)
   809             dobj.decompress(data)
       
   810             self.assertIsNone(dobj.flush())
   571 
   811 
   572     def test_bad_write_size(self):
   812     def test_bad_write_size(self):
   573         dctx = zstd.ZstdDecompressor()
   813         dctx = zstd.ZstdDecompressor()
   574 
   814 
   575         with self.assertRaisesRegexp(ValueError, 'write_size must be positive'):
   815         with self.assertRaisesRegexp(ValueError, 'write_size must be positive'):
   583 
   823 
   584         for i in range(128):
   824         for i in range(128):
   585             dobj = dctx.decompressobj(write_size=i + 1)
   825             dobj = dctx.decompressobj(write_size=i + 1)
   586             self.assertEqual(dobj.decompress(data), source)
   826             self.assertEqual(dobj.decompress(data), source)
   587 
   827 
       
   828 
   588 def decompress_via_writer(data):
   829 def decompress_via_writer(data):
   589     buffer = io.BytesIO()
   830     buffer = io.BytesIO()
   590     dctx = zstd.ZstdDecompressor()
   831     dctx = zstd.ZstdDecompressor()
   591     with dctx.stream_writer(buffer) as decompressor:
   832     decompressor = dctx.stream_writer(buffer)
   592         decompressor.write(data)
   833     decompressor.write(data)
       
   834 
   593     return buffer.getvalue()
   835     return buffer.getvalue()
   594 
   836 
   595 
   837 
   596 @make_cffi
   838 @make_cffi
   597 class TestDecompressor_stream_writer(unittest.TestCase):
   839 class TestDecompressor_stream_writer(unittest.TestCase):
       
   840     def test_io_api(self):
       
   841         buffer = io.BytesIO()
       
   842         dctx = zstd.ZstdDecompressor()
       
   843         writer = dctx.stream_writer(buffer)
       
   844 
       
   845         self.assertFalse(writer.closed)
       
   846         self.assertFalse(writer.isatty())
       
   847         self.assertFalse(writer.readable())
       
   848 
       
   849         with self.assertRaises(io.UnsupportedOperation):
       
   850             writer.readline()
       
   851 
       
   852         with self.assertRaises(io.UnsupportedOperation):
       
   853             writer.readline(42)
       
   854 
       
   855         with self.assertRaises(io.UnsupportedOperation):
       
   856             writer.readline(size=42)
       
   857 
       
   858         with self.assertRaises(io.UnsupportedOperation):
       
   859             writer.readlines()
       
   860 
       
   861         with self.assertRaises(io.UnsupportedOperation):
       
   862             writer.readlines(42)
       
   863 
       
   864         with self.assertRaises(io.UnsupportedOperation):
       
   865             writer.readlines(hint=42)
       
   866 
       
   867         with self.assertRaises(io.UnsupportedOperation):
       
   868             writer.seek(0)
       
   869 
       
   870         with self.assertRaises(io.UnsupportedOperation):
       
   871             writer.seek(10, os.SEEK_SET)
       
   872 
       
   873         self.assertFalse(writer.seekable())
       
   874 
       
   875         with self.assertRaises(io.UnsupportedOperation):
       
   876             writer.tell()
       
   877 
       
   878         with self.assertRaises(io.UnsupportedOperation):
       
   879             writer.truncate()
       
   880 
       
   881         with self.assertRaises(io.UnsupportedOperation):
       
   882             writer.truncate(42)
       
   883 
       
   884         with self.assertRaises(io.UnsupportedOperation):
       
   885             writer.truncate(size=42)
       
   886 
       
   887         self.assertTrue(writer.writable())
       
   888 
       
   889         with self.assertRaises(io.UnsupportedOperation):
       
   890             writer.writelines([])
       
   891 
       
   892         with self.assertRaises(io.UnsupportedOperation):
       
   893             writer.read()
       
   894 
       
   895         with self.assertRaises(io.UnsupportedOperation):
       
   896             writer.read(42)
       
   897 
       
   898         with self.assertRaises(io.UnsupportedOperation):
       
   899             writer.read(size=42)
       
   900 
       
   901         with self.assertRaises(io.UnsupportedOperation):
       
   902             writer.readall()
       
   903 
       
   904         with self.assertRaises(io.UnsupportedOperation):
       
   905             writer.readinto(None)
       
   906 
       
   907         with self.assertRaises(io.UnsupportedOperation):
       
   908             writer.fileno()
       
   909 
       
   910     def test_fileno_file(self):
       
   911         with tempfile.TemporaryFile('wb') as tf:
       
   912             dctx = zstd.ZstdDecompressor()
       
   913             writer = dctx.stream_writer(tf)
       
   914 
       
   915             self.assertEqual(writer.fileno(), tf.fileno())
       
   916 
       
   917     def test_close(self):
       
   918         foo = zstd.ZstdCompressor().compress(b'foo')
       
   919 
       
   920         buffer = NonClosingBytesIO()
       
   921         dctx = zstd.ZstdDecompressor()
       
   922         writer = dctx.stream_writer(buffer)
       
   923 
       
   924         writer.write(foo)
       
   925         self.assertFalse(writer.closed)
       
   926         self.assertFalse(buffer.closed)
       
   927         writer.close()
       
   928         self.assertTrue(writer.closed)
       
   929         self.assertTrue(buffer.closed)
       
   930 
       
   931         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
       
   932             writer.write(b'')
       
   933 
       
   934         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
       
   935             writer.flush()
       
   936 
       
   937         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
       
   938             with writer:
       
   939                 pass
       
   940 
       
   941         self.assertEqual(buffer.getvalue(), b'foo')
       
   942 
       
   943         # Context manager exit should close stream.
       
   944         buffer = NonClosingBytesIO()
       
   945         writer = dctx.stream_writer(buffer)
       
   946 
       
   947         with writer:
       
   948             writer.write(foo)
       
   949 
       
   950         self.assertTrue(writer.closed)
       
   951         self.assertEqual(buffer.getvalue(), b'foo')
       
   952 
       
   953     def test_flush(self):
       
   954         buffer = OpCountingBytesIO()
       
   955         dctx = zstd.ZstdDecompressor()
       
   956         writer = dctx.stream_writer(buffer)
       
   957 
       
   958         writer.flush()
       
   959         self.assertEqual(buffer._flush_count, 1)
       
   960         writer.flush()
       
   961         self.assertEqual(buffer._flush_count, 2)
       
   962 
   598     def test_empty_roundtrip(self):
   963     def test_empty_roundtrip(self):
   599         cctx = zstd.ZstdCompressor()
   964         cctx = zstd.ZstdCompressor()
   600         empty = cctx.compress(b'')
   965         empty = cctx.compress(b'')
   601         self.assertEqual(decompress_via_writer(empty), b'')
   966         self.assertEqual(decompress_via_writer(empty), b'')
   602 
   967 
   614         ]
   979         ]
   615 
   980 
   616         dctx = zstd.ZstdDecompressor()
   981         dctx = zstd.ZstdDecompressor()
   617         for source in sources:
   982         for source in sources:
   618             buffer = io.BytesIO()
   983             buffer = io.BytesIO()
       
   984 
       
   985             decompressor = dctx.stream_writer(buffer)
       
   986             decompressor.write(source)
       
   987             self.assertEqual(buffer.getvalue(), b'foo')
       
   988 
       
   989             buffer = NonClosingBytesIO()
       
   990 
   619             with dctx.stream_writer(buffer) as decompressor:
   991             with dctx.stream_writer(buffer) as decompressor:
   620                 decompressor.write(source)
   992                 self.assertEqual(decompressor.write(source), 3)
   621 
   993 
       
   994             self.assertEqual(buffer.getvalue(), b'foo')
       
   995 
       
   996             buffer = io.BytesIO()
       
   997             writer = dctx.stream_writer(buffer, write_return_read=True)
       
   998             self.assertEqual(writer.write(source), len(source))
   622             self.assertEqual(buffer.getvalue(), b'foo')
   999             self.assertEqual(buffer.getvalue(), b'foo')
   623 
  1000 
   624     def test_large_roundtrip(self):
  1001     def test_large_roundtrip(self):
   625         chunks = []
  1002         chunks = []
   626         for i in range(255):
  1003         for i in range(255):
   639 
  1016 
   640         orig = b''.join(chunks)
  1017         orig = b''.join(chunks)
   641         cctx = zstd.ZstdCompressor()
  1018         cctx = zstd.ZstdCompressor()
   642         compressed = cctx.compress(orig)
  1019         compressed = cctx.compress(orig)
   643 
  1020 
   644         buffer = io.BytesIO()
  1021         buffer = NonClosingBytesIO()
   645         dctx = zstd.ZstdDecompressor()
  1022         dctx = zstd.ZstdDecompressor()
   646         with dctx.stream_writer(buffer) as decompressor:
  1023         with dctx.stream_writer(buffer) as decompressor:
   647             pos = 0
  1024             pos = 0
   648             while pos < len(compressed):
  1025             while pos < len(compressed):
   649                 pos2 = pos + 8192
  1026                 pos2 = pos + 8192
   650                 decompressor.write(compressed[pos:pos2])
  1027                 decompressor.write(compressed[pos:pos2])
   651                 pos += 8192
  1028                 pos += 8192
   652         self.assertEqual(buffer.getvalue(), orig)
  1029         self.assertEqual(buffer.getvalue(), orig)
   653 
  1030 
       
  1031         # Again with write_return_read=True
       
  1032         buffer = io.BytesIO()
       
  1033         writer = dctx.stream_writer(buffer, write_return_read=True)
       
  1034         pos = 0
       
  1035         while pos < len(compressed):
       
  1036             pos2 = pos + 8192
       
  1037             chunk = compressed[pos:pos2]
       
  1038             self.assertEqual(writer.write(chunk), len(chunk))
       
  1039             pos += 8192
       
  1040         self.assertEqual(buffer.getvalue(), orig)
       
  1041 
   654     def test_dictionary(self):
  1042     def test_dictionary(self):
   655         samples = []
  1043         samples = []
   656         for i in range(128):
  1044         for i in range(128):
   657             samples.append(b'foo' * 64)
  1045             samples.append(b'foo' * 64)
   658             samples.append(b'bar' * 64)
  1046             samples.append(b'bar' * 64)
   659             samples.append(b'foobar' * 64)
  1047             samples.append(b'foobar' * 64)
   660 
  1048 
   661         d = zstd.train_dictionary(8192, samples)
  1049         d = zstd.train_dictionary(8192, samples)
   662 
  1050 
   663         orig = b'foobar' * 16384
  1051         orig = b'foobar' * 16384
   664         buffer = io.BytesIO()
  1052         buffer = NonClosingBytesIO()
   665         cctx = zstd.ZstdCompressor(dict_data=d)
  1053         cctx = zstd.ZstdCompressor(dict_data=d)
   666         with cctx.stream_writer(buffer) as compressor:
  1054         with cctx.stream_writer(buffer) as compressor:
   667             self.assertEqual(compressor.write(orig), 0)
  1055             self.assertEqual(compressor.write(orig), 0)
   668 
  1056 
   669         compressed = buffer.getvalue()
  1057         compressed = buffer.getvalue()
   670         buffer = io.BytesIO()
  1058         buffer = io.BytesIO()
   671 
  1059 
   672         dctx = zstd.ZstdDecompressor(dict_data=d)
  1060         dctx = zstd.ZstdDecompressor(dict_data=d)
       
  1061         decompressor = dctx.stream_writer(buffer)
       
  1062         self.assertEqual(decompressor.write(compressed), len(orig))
       
  1063         self.assertEqual(buffer.getvalue(), orig)
       
  1064 
       
  1065         buffer = NonClosingBytesIO()
       
  1066 
   673         with dctx.stream_writer(buffer) as decompressor:
  1067         with dctx.stream_writer(buffer) as decompressor:
   674             self.assertEqual(decompressor.write(compressed), len(orig))
  1068             self.assertEqual(decompressor.write(compressed), len(orig))
   675 
  1069 
   676         self.assertEqual(buffer.getvalue(), orig)
  1070         self.assertEqual(buffer.getvalue(), orig)
   677 
  1071 
   678     def test_memory_size(self):
  1072     def test_memory_size(self):
   679         dctx = zstd.ZstdDecompressor()
  1073         dctx = zstd.ZstdDecompressor()
   680         buffer = io.BytesIO()
  1074         buffer = io.BytesIO()
       
  1075 
       
  1076         decompressor = dctx.stream_writer(buffer)
       
  1077         size = decompressor.memory_size()
       
  1078         self.assertGreater(size, 100000)
       
  1079 
   681         with dctx.stream_writer(buffer) as decompressor:
  1080         with dctx.stream_writer(buffer) as decompressor:
   682             size = decompressor.memory_size()
  1081             size = decompressor.memory_size()
   683 
  1082 
   684         self.assertGreater(size, 100000)
  1083         self.assertGreater(size, 100000)
   685 
  1084 
   808         self.assertEqual(decompressed, source.getvalue())
  1207         self.assertEqual(decompressed, source.getvalue())
   809 
  1208 
   810     @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
  1209     @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
   811     def test_large_input(self):
  1210     def test_large_input(self):
   812         bytes = list(struct.Struct('>B').pack(i) for i in range(256))
  1211         bytes = list(struct.Struct('>B').pack(i) for i in range(256))
   813         compressed = io.BytesIO()
  1212         compressed = NonClosingBytesIO()
   814         input_size = 0
  1213         input_size = 0
   815         cctx = zstd.ZstdCompressor(level=1)
  1214         cctx = zstd.ZstdCompressor(level=1)
   816         with cctx.stream_writer(compressed) as compressor:
  1215         with cctx.stream_writer(compressed) as compressor:
   817             while True:
  1216             while True:
   818                 compressor.write(random.choice(bytes))
  1217                 compressor.write(random.choice(bytes))
   821                 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
  1220                 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
   822                 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
  1221                 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
   823                 if have_compressed and have_raw:
  1222                 if have_compressed and have_raw:
   824                     break
  1223                     break
   825 
  1224 
   826         compressed.seek(0)
  1225         compressed = io.BytesIO(compressed.getvalue())
   827         self.assertGreater(len(compressed.getvalue()),
  1226         self.assertGreater(len(compressed.getvalue()),
   828                            zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
  1227                            zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
   829 
  1228 
   830         dctx = zstd.ZstdDecompressor()
  1229         dctx = zstd.ZstdDecompressor()
   831         it = dctx.read_to_iter(compressed)
  1230         it = dctx.read_to_iter(compressed)
   859         # Found this edge case via fuzzing.
  1258         # Found this edge case via fuzzing.
   860         cctx = zstd.ZstdCompressor(level=1)
  1259         cctx = zstd.ZstdCompressor(level=1)
   861 
  1260 
   862         source = io.BytesIO()
  1261         source = io.BytesIO()
   863 
  1262 
   864         compressed = io.BytesIO()
  1263         compressed = NonClosingBytesIO()
   865         with cctx.stream_writer(compressed) as compressor:
  1264         with cctx.stream_writer(compressed) as compressor:
   866             for i in range(256):
  1265             for i in range(256):
   867                 chunk = b'\0' * 1024
  1266                 chunk = b'\0' * 1024
   868                 compressor.write(chunk)
  1267                 compressor.write(chunk)
   869                 source.write(chunk)
  1268                 source.write(chunk)
   872 
  1271 
   873         simple = dctx.decompress(compressed.getvalue(),
  1272         simple = dctx.decompress(compressed.getvalue(),
   874                                  max_output_size=len(source.getvalue()))
  1273                                  max_output_size=len(source.getvalue()))
   875         self.assertEqual(simple, source.getvalue())
  1274         self.assertEqual(simple, source.getvalue())
   876 
  1275 
   877         compressed.seek(0)
  1276         compressed = io.BytesIO(compressed.getvalue())
   878         streamed = b''.join(dctx.read_to_iter(compressed))
  1277         streamed = b''.join(dctx.read_to_iter(compressed))
   879         self.assertEqual(streamed, source.getvalue())
  1278         self.assertEqual(streamed, source.getvalue())
   880 
  1279 
   881     def test_read_write_size(self):
  1280     def test_read_write_size(self):
   882         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
  1281         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
   999 # TODO enable for CFFI
  1398 # TODO enable for CFFI
  1000 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase):
  1399 class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase):
  1001     def test_invalid_inputs(self):
  1400     def test_invalid_inputs(self):
  1002         dctx = zstd.ZstdDecompressor()
  1401         dctx = zstd.ZstdDecompressor()
  1003 
  1402 
       
  1403         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1404             self.skipTest('multi_decompress_to_buffer not available')
       
  1405 
  1004         with self.assertRaises(TypeError):
  1406         with self.assertRaises(TypeError):
  1005             dctx.multi_decompress_to_buffer(True)
  1407             dctx.multi_decompress_to_buffer(True)
  1006 
  1408 
  1007         with self.assertRaises(TypeError):
  1409         with self.assertRaises(TypeError):
  1008             dctx.multi_decompress_to_buffer((1, 2))
  1410             dctx.multi_decompress_to_buffer((1, 2))
  1018 
  1420 
  1019         original = [b'foo' * 4, b'bar' * 6]
  1421         original = [b'foo' * 4, b'bar' * 6]
  1020         frames = [cctx.compress(d) for d in original]
  1422         frames = [cctx.compress(d) for d in original]
  1021 
  1423 
  1022         dctx = zstd.ZstdDecompressor()
  1424         dctx = zstd.ZstdDecompressor()
       
  1425 
       
  1426         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1427             self.skipTest('multi_decompress_to_buffer not available')
       
  1428 
  1023         result = dctx.multi_decompress_to_buffer(frames)
  1429         result = dctx.multi_decompress_to_buffer(frames)
  1024 
  1430 
  1025         self.assertEqual(len(result), len(frames))
  1431         self.assertEqual(len(result), len(frames))
  1026         self.assertEqual(result.size(), sum(map(len, original)))
  1432         self.assertEqual(result.size(), sum(map(len, original)))
  1027 
  1433 
  1039         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
  1445         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
  1040         frames = [cctx.compress(d) for d in original]
  1446         frames = [cctx.compress(d) for d in original]
  1041         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
  1447         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
  1042 
  1448 
  1043         dctx = zstd.ZstdDecompressor()
  1449         dctx = zstd.ZstdDecompressor()
       
  1450 
       
  1451         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1452             self.skipTest('multi_decompress_to_buffer not available')
       
  1453 
  1044         result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
  1454         result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
  1045 
  1455 
  1046         self.assertEqual(len(result), len(frames))
  1456         self.assertEqual(len(result), len(frames))
  1047         self.assertEqual(result.size(), sum(map(len, original)))
  1457         self.assertEqual(result.size(), sum(map(len, original)))
  1048 
  1458 
  1054 
  1464 
  1055         original = [b'foo' * 4, b'bar' * 6]
  1465         original = [b'foo' * 4, b'bar' * 6]
  1056         frames = [cctx.compress(d) for d in original]
  1466         frames = [cctx.compress(d) for d in original]
  1057 
  1467 
  1058         dctx = zstd.ZstdDecompressor()
  1468         dctx = zstd.ZstdDecompressor()
       
  1469 
       
  1470         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1471             self.skipTest('multi_decompress_to_buffer not available')
  1059 
  1472 
  1060         segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1]))
  1473         segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1]))
  1061         b = zstd.BufferWithSegments(b''.join(frames), segments)
  1474         b = zstd.BufferWithSegments(b''.join(frames), segments)
  1062 
  1475 
  1063         result = dctx.multi_decompress_to_buffer(b)
  1476         result = dctx.multi_decompress_to_buffer(b)
  1072         cctx = zstd.ZstdCompressor(write_content_size=False)
  1485         cctx = zstd.ZstdCompressor(write_content_size=False)
  1073         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
  1486         original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
  1074         frames = [cctx.compress(d) for d in original]
  1487         frames = [cctx.compress(d) for d in original]
  1075         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
  1488         sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
  1076 
  1489 
       
  1490         dctx = zstd.ZstdDecompressor()
       
  1491 
       
  1492         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1493             self.skipTest('multi_decompress_to_buffer not available')
       
  1494 
  1077         segments = struct.pack('=QQQQQQ', 0, len(frames[0]),
  1495         segments = struct.pack('=QQQQQQ', 0, len(frames[0]),
  1078                                len(frames[0]), len(frames[1]),
  1496                                len(frames[0]), len(frames[1]),
  1079                                len(frames[0]) + len(frames[1]), len(frames[2]))
  1497                                len(frames[0]) + len(frames[1]), len(frames[2]))
  1080         b = zstd.BufferWithSegments(b''.join(frames), segments)
  1498         b = zstd.BufferWithSegments(b''.join(frames), segments)
  1081 
  1499 
  1082         dctx = zstd.ZstdDecompressor()
       
  1083         result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
  1500         result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
  1084 
  1501 
  1085         self.assertEqual(len(result), len(frames))
  1502         self.assertEqual(len(result), len(frames))
  1086         self.assertEqual(result.size(), sum(map(len, original)))
  1503         self.assertEqual(result.size(), sum(map(len, original)))
  1087 
  1504 
  1097             b'foo2' * 4,
  1514             b'foo2' * 4,
  1098             b'foo3' * 5,
  1515             b'foo3' * 5,
  1099             b'foo4' * 6,
  1516             b'foo4' * 6,
  1100         ]
  1517         ]
  1101 
  1518 
       
  1519         if not hasattr(cctx, 'multi_compress_to_buffer'):
       
  1520             self.skipTest('multi_compress_to_buffer not available')
       
  1521 
  1102         frames = cctx.multi_compress_to_buffer(original)
  1522         frames = cctx.multi_compress_to_buffer(original)
  1103 
  1523 
  1104         # Check round trip.
  1524         # Check round trip.
  1105         dctx = zstd.ZstdDecompressor()
  1525         dctx = zstd.ZstdDecompressor()
       
  1526 
  1106         decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
  1527         decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
  1107 
  1528 
  1108         self.assertEqual(len(decompressed), len(original))
  1529         self.assertEqual(len(decompressed), len(original))
  1109 
  1530 
  1110         for i, data in enumerate(original):
  1531         for i, data in enumerate(original):
  1136 
  1557 
  1137         cctx = zstd.ZstdCompressor(dict_data=d, level=1)
  1558         cctx = zstd.ZstdCompressor(dict_data=d, level=1)
  1138         frames = [cctx.compress(s) for s in generate_samples()]
  1559         frames = [cctx.compress(s) for s in generate_samples()]
  1139 
  1560 
  1140         dctx = zstd.ZstdDecompressor(dict_data=d)
  1561         dctx = zstd.ZstdDecompressor(dict_data=d)
       
  1562 
       
  1563         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1564             self.skipTest('multi_decompress_to_buffer not available')
       
  1565 
  1141         result = dctx.multi_decompress_to_buffer(frames)
  1566         result = dctx.multi_decompress_to_buffer(frames)
       
  1567 
  1142         self.assertEqual([o.tobytes() for o in result], generate_samples())
  1568         self.assertEqual([o.tobytes() for o in result], generate_samples())
  1143 
  1569 
  1144     def test_multiple_threads(self):
  1570     def test_multiple_threads(self):
  1145         cctx = zstd.ZstdCompressor()
  1571         cctx = zstd.ZstdCompressor()
  1146 
  1572 
  1147         frames = []
  1573         frames = []
  1148         frames.extend(cctx.compress(b'x' * 64) for i in range(256))
  1574         frames.extend(cctx.compress(b'x' * 64) for i in range(256))
  1149         frames.extend(cctx.compress(b'y' * 64) for i in range(256))
  1575         frames.extend(cctx.compress(b'y' * 64) for i in range(256))
  1150 
  1576 
  1151         dctx = zstd.ZstdDecompressor()
  1577         dctx = zstd.ZstdDecompressor()
       
  1578 
       
  1579         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1580             self.skipTest('multi_decompress_to_buffer not available')
       
  1581 
  1152         result = dctx.multi_decompress_to_buffer(frames, threads=-1)
  1582         result = dctx.multi_decompress_to_buffer(frames, threads=-1)
  1153 
  1583 
  1154         self.assertEqual(len(result), len(frames))
  1584         self.assertEqual(len(result), len(frames))
  1155         self.assertEqual(result.size(), 2 * 64 * 256)
  1585         self.assertEqual(result.size(), 2 * 64 * 256)
  1156         self.assertEqual(result[0].tobytes(), b'x' * 64)
  1586         self.assertEqual(result[0].tobytes(), b'x' * 64)
  1161         frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)]
  1591         frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)]
  1162 
  1592 
  1163         frames[1] = frames[1][0:15] + b'extra' + frames[1][15:]
  1593         frames[1] = frames[1][0:15] + b'extra' + frames[1][15:]
  1164 
  1594 
  1165         dctx = zstd.ZstdDecompressor()
  1595         dctx = zstd.ZstdDecompressor()
       
  1596 
       
  1597         if not hasattr(dctx, 'multi_decompress_to_buffer'):
       
  1598             self.skipTest('multi_decompress_to_buffer not available')
  1166 
  1599 
  1167         with self.assertRaisesRegexp(zstd.ZstdError,
  1600         with self.assertRaisesRegexp(zstd.ZstdError,
  1168                                      'error decompressing item 1: ('
  1601                                      'error decompressing item 1: ('
  1169                                      'Corrupted block|'
  1602                                      'Corrupted block|'
  1170                                      'Destination buffer is too small)'):
  1603                                      'Destination buffer is too small)'):