contrib/python-zstandard/tests/test_decompressor_fuzzing.py
changeset 42070 675775c33ab6
parent 37495 b1fb341d8a61
child 43994 de7838053207
equal deleted inserted replaced
42069:668eff08387f 42070:675775c33ab6
    10 
    10 
    11 import zstandard as zstd
    11 import zstandard as zstd
    12 
    12 
    13 from . common import (
    13 from . common import (
    14     make_cffi,
    14     make_cffi,
       
    15     NonClosingBytesIO,
    15     random_input_data,
    16     random_input_data,
    16 )
    17 )
    17 
    18 
    18 
    19 
    19 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
    20 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
    21 class TestDecompressor_stream_reader_fuzzing(unittest.TestCase):
    22 class TestDecompressor_stream_reader_fuzzing(unittest.TestCase):
    22     @hypothesis.settings(
    23     @hypothesis.settings(
    23         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
    24         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
    24     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
    25     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
    25                       level=strategies.integers(min_value=1, max_value=5),
    26                       level=strategies.integers(min_value=1, max_value=5),
    26                       source_read_size=strategies.integers(1, 16384),
    27                       streaming=strategies.booleans(),
       
    28                       source_read_size=strategies.integers(1, 1048576),
    27                       read_sizes=strategies.data())
    29                       read_sizes=strategies.data())
    28     def test_stream_source_read_variance(self, original, level, source_read_size,
    30     def test_stream_source_read_variance(self, original, level, streaming,
    29                                          read_sizes):
    31                                          source_read_size, read_sizes):
    30         cctx = zstd.ZstdCompressor(level=level)
    32         cctx = zstd.ZstdCompressor(level=level)
    31         frame = cctx.compress(original)
    33 
    32 
    34         if streaming:
    33         dctx = zstd.ZstdDecompressor()
    35             source = io.BytesIO()
    34         source = io.BytesIO(frame)
    36             writer = cctx.stream_writer(source)
       
    37             writer.write(original)
       
    38             writer.flush(zstd.FLUSH_FRAME)
       
    39             source.seek(0)
       
    40         else:
       
    41             frame = cctx.compress(original)
       
    42             source = io.BytesIO(frame)
       
    43 
       
    44         dctx = zstd.ZstdDecompressor()
    35 
    45 
    36         chunks = []
    46         chunks = []
    37         with dctx.stream_reader(source, read_size=source_read_size) as reader:
    47         with dctx.stream_reader(source, read_size=source_read_size) as reader:
    38             while True:
    48             while True:
    39                 read_size = read_sizes.draw(strategies.integers(1, 16384))
    49                 read_size = read_sizes.draw(strategies.integers(-1, 131072))
    40                 chunk = reader.read(read_size)
    50                 chunk = reader.read(read_size)
    41                 if not chunk:
    51                 if not chunk and read_size:
    42                     break
    52                     break
    43 
    53 
    44                 chunks.append(chunk)
    54                 chunks.append(chunk)
    45 
    55 
    46         self.assertEqual(b''.join(chunks), original)
    56         self.assertEqual(b''.join(chunks), original)
    47 
    57 
    48     @hypothesis.settings(
    58     # Similar to above except we have a constant read() size.
    49         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
    59     @hypothesis.settings(
    50     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
    60         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
    51                       level=strategies.integers(min_value=1, max_value=5),
    61     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
    52                       source_read_size=strategies.integers(1, 16384),
    62                       level=strategies.integers(min_value=1, max_value=5),
       
    63                       streaming=strategies.booleans(),
       
    64                       source_read_size=strategies.integers(1, 1048576),
       
    65                       read_size=strategies.integers(-1, 131072))
       
    66     def test_stream_source_read_size(self, original, level, streaming,
       
    67                                      source_read_size, read_size):
       
    68         if read_size == 0:
       
    69             read_size = 1
       
    70 
       
    71         cctx = zstd.ZstdCompressor(level=level)
       
    72 
       
    73         if streaming:
       
    74             source = io.BytesIO()
       
    75             writer = cctx.stream_writer(source)
       
    76             writer.write(original)
       
    77             writer.flush(zstd.FLUSH_FRAME)
       
    78             source.seek(0)
       
    79         else:
       
    80             frame = cctx.compress(original)
       
    81             source = io.BytesIO(frame)
       
    82 
       
    83         dctx = zstd.ZstdDecompressor()
       
    84 
       
    85         chunks = []
       
    86         reader = dctx.stream_reader(source, read_size=source_read_size)
       
    87         while True:
       
    88             chunk = reader.read(read_size)
       
    89             if not chunk and read_size:
       
    90                 break
       
    91 
       
    92             chunks.append(chunk)
       
    93 
       
    94         self.assertEqual(b''.join(chunks), original)
       
    95 
       
    96     @hypothesis.settings(
       
    97         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
    98     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
       
    99                       level=strategies.integers(min_value=1, max_value=5),
       
   100                       streaming=strategies.booleans(),
       
   101                       source_read_size=strategies.integers(1, 1048576),
    53                       read_sizes=strategies.data())
   102                       read_sizes=strategies.data())
    54     def test_buffer_source_read_variance(self, original, level, source_read_size,
   103     def test_buffer_source_read_variance(self, original, level, streaming,
    55                                          read_sizes):
   104                                          source_read_size, read_sizes):
    56         cctx = zstd.ZstdCompressor(level=level)
   105         cctx = zstd.ZstdCompressor(level=level)
    57         frame = cctx.compress(original)
   106 
       
   107         if streaming:
       
   108             source = io.BytesIO()
       
   109             writer = cctx.stream_writer(source)
       
   110             writer.write(original)
       
   111             writer.flush(zstd.FLUSH_FRAME)
       
   112             frame = source.getvalue()
       
   113         else:
       
   114             frame = cctx.compress(original)
    58 
   115 
    59         dctx = zstd.ZstdDecompressor()
   116         dctx = zstd.ZstdDecompressor()
    60         chunks = []
   117         chunks = []
    61 
   118 
    62         with dctx.stream_reader(frame, read_size=source_read_size) as reader:
   119         with dctx.stream_reader(frame, read_size=source_read_size) as reader:
    63             while True:
   120             while True:
    64                 read_size = read_sizes.draw(strategies.integers(1, 16384))
   121                 read_size = read_sizes.draw(strategies.integers(-1, 131072))
    65                 chunk = reader.read(read_size)
   122                 chunk = reader.read(read_size)
    66                 if not chunk:
   123                 if not chunk and read_size:
    67                     break
   124                     break
    68 
   125 
    69                 chunks.append(chunk)
   126                 chunks.append(chunk)
       
   127 
       
   128         self.assertEqual(b''.join(chunks), original)
       
   129 
       
   130     # Similar to above except we have a constant read() size.
       
   131     @hypothesis.settings(
       
   132         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
   133     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
       
   134                       level=strategies.integers(min_value=1, max_value=5),
       
   135                       streaming=strategies.booleans(),
       
   136                       source_read_size=strategies.integers(1, 1048576),
       
   137                       read_size=strategies.integers(-1, 131072))
       
   138     def test_buffer_source_constant_read_size(self, original, level, streaming,
       
   139                                               source_read_size, read_size):
       
   140         if read_size == 0:
       
   141             read_size = -1
       
   142 
       
   143         cctx = zstd.ZstdCompressor(level=level)
       
   144 
       
   145         if streaming:
       
   146             source = io.BytesIO()
       
   147             writer = cctx.stream_writer(source)
       
   148             writer.write(original)
       
   149             writer.flush(zstd.FLUSH_FRAME)
       
   150             frame = source.getvalue()
       
   151         else:
       
   152             frame = cctx.compress(original)
       
   153 
       
   154         dctx = zstd.ZstdDecompressor()
       
   155         chunks = []
       
   156 
       
   157         reader = dctx.stream_reader(frame, read_size=source_read_size)
       
   158         while True:
       
   159             chunk = reader.read(read_size)
       
   160             if not chunk and read_size:
       
   161                 break
       
   162 
       
   163             chunks.append(chunk)
       
   164 
       
   165         self.assertEqual(b''.join(chunks), original)
       
   166 
       
   167     @hypothesis.settings(
       
   168         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
   169     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
       
   170                       level=strategies.integers(min_value=1, max_value=5),
       
   171                       streaming=strategies.booleans(),
       
   172                       source_read_size=strategies.integers(1, 1048576))
       
   173     def test_stream_source_readall(self, original, level, streaming,
       
   174                                          source_read_size):
       
   175         cctx = zstd.ZstdCompressor(level=level)
       
   176 
       
   177         if streaming:
       
   178             source = io.BytesIO()
       
   179             writer = cctx.stream_writer(source)
       
   180             writer.write(original)
       
   181             writer.flush(zstd.FLUSH_FRAME)
       
   182             source.seek(0)
       
   183         else:
       
   184             frame = cctx.compress(original)
       
   185             source = io.BytesIO(frame)
       
   186 
       
   187         dctx = zstd.ZstdDecompressor()
       
   188 
       
   189         data = dctx.stream_reader(source, read_size=source_read_size).readall()
       
   190         self.assertEqual(data, original)
       
   191 
       
   192     @hypothesis.settings(
       
   193         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
   194     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
       
   195                       level=strategies.integers(min_value=1, max_value=5),
       
   196                       streaming=strategies.booleans(),
       
   197                       source_read_size=strategies.integers(1, 1048576),
       
   198                       read_sizes=strategies.data())
       
   199     def test_stream_source_read1_variance(self, original, level, streaming,
       
   200                                           source_read_size, read_sizes):
       
   201         cctx = zstd.ZstdCompressor(level=level)
       
   202 
       
   203         if streaming:
       
   204             source = io.BytesIO()
       
   205             writer = cctx.stream_writer(source)
       
   206             writer.write(original)
       
   207             writer.flush(zstd.FLUSH_FRAME)
       
   208             source.seek(0)
       
   209         else:
       
   210             frame = cctx.compress(original)
       
   211             source = io.BytesIO(frame)
       
   212 
       
   213         dctx = zstd.ZstdDecompressor()
       
   214 
       
   215         chunks = []
       
   216         with dctx.stream_reader(source, read_size=source_read_size) as reader:
       
   217             while True:
       
   218                 read_size = read_sizes.draw(strategies.integers(-1, 131072))
       
   219                 chunk = reader.read1(read_size)
       
   220                 if not chunk and read_size:
       
   221                     break
       
   222 
       
   223                 chunks.append(chunk)
       
   224 
       
   225         self.assertEqual(b''.join(chunks), original)
       
   226 
       
   227     @hypothesis.settings(
       
   228         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
   229     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
       
   230                       level=strategies.integers(min_value=1, max_value=5),
       
   231                       streaming=strategies.booleans(),
       
   232                       source_read_size=strategies.integers(1, 1048576),
       
   233                       read_sizes=strategies.data())
       
   234     def test_stream_source_readinto1_variance(self, original, level, streaming,
       
   235                                           source_read_size, read_sizes):
       
   236         cctx = zstd.ZstdCompressor(level=level)
       
   237 
       
   238         if streaming:
       
   239             source = io.BytesIO()
       
   240             writer = cctx.stream_writer(source)
       
   241             writer.write(original)
       
   242             writer.flush(zstd.FLUSH_FRAME)
       
   243             source.seek(0)
       
   244         else:
       
   245             frame = cctx.compress(original)
       
   246             source = io.BytesIO(frame)
       
   247 
       
   248         dctx = zstd.ZstdDecompressor()
       
   249 
       
   250         chunks = []
       
   251         with dctx.stream_reader(source, read_size=source_read_size) as reader:
       
   252             while True:
       
   253                 read_size = read_sizes.draw(strategies.integers(1, 131072))
       
   254                 b = bytearray(read_size)
       
   255                 count = reader.readinto1(b)
       
   256 
       
   257                 if not count:
       
   258                     break
       
   259 
       
   260                 chunks.append(bytes(b[0:count]))
    70 
   261 
    71         self.assertEqual(b''.join(chunks), original)
   262         self.assertEqual(b''.join(chunks), original)
    72 
   263 
    73     @hypothesis.settings(
   264     @hypothesis.settings(
    74         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
   265         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
    75     @hypothesis.given(
   266     @hypothesis.given(
    76         original=strategies.sampled_from(random_input_data()),
   267         original=strategies.sampled_from(random_input_data()),
    77         level=strategies.integers(min_value=1, max_value=5),
   268         level=strategies.integers(min_value=1, max_value=5),
    78         source_read_size=strategies.integers(1, 16384),
   269         source_read_size=strategies.integers(1, 1048576),
    79         seek_amounts=strategies.data(),
   270         seek_amounts=strategies.data(),
    80         read_sizes=strategies.data())
   271         read_sizes=strategies.data())
    81     def test_relative_seeks(self, original, level, source_read_size, seek_amounts,
   272     def test_relative_seeks(self, original, level, source_read_size, seek_amounts,
    82                             read_sizes):
   273                             read_sizes):
    83         cctx = zstd.ZstdCompressor(level=level)
   274         cctx = zstd.ZstdCompressor(level=level)
    97                 if not chunk:
   288                 if not chunk:
    98                     break
   289                     break
    99 
   290 
   100                 self.assertEqual(original[offset:offset + len(chunk)], chunk)
   291                 self.assertEqual(original[offset:offset + len(chunk)], chunk)
   101 
   292 
       
   293     @hypothesis.settings(
       
   294         suppress_health_check=[hypothesis.HealthCheck.large_base_example])
       
   295     @hypothesis.given(
       
   296         originals=strategies.data(),
       
   297         frame_count=strategies.integers(min_value=2, max_value=10),
       
   298         level=strategies.integers(min_value=1, max_value=5),
       
   299         source_read_size=strategies.integers(1, 1048576),
       
   300         read_sizes=strategies.data())
       
   301     def test_multiple_frames(self, originals, frame_count, level,
       
   302                              source_read_size, read_sizes):
       
   303 
       
   304         cctx = zstd.ZstdCompressor(level=level)
       
   305         source = io.BytesIO()
       
   306         buffer = io.BytesIO()
       
   307         writer = cctx.stream_writer(buffer)
       
   308 
       
   309         for i in range(frame_count):
       
   310             data = originals.draw(strategies.sampled_from(random_input_data()))
       
   311             source.write(data)
       
   312             writer.write(data)
       
   313             writer.flush(zstd.FLUSH_FRAME)
       
   314 
       
   315         dctx = zstd.ZstdDecompressor()
       
   316         buffer.seek(0)
       
   317         reader = dctx.stream_reader(buffer, read_size=source_read_size,
       
   318                                     read_across_frames=True)
       
   319 
       
   320         chunks = []
       
   321 
       
   322         while True:
       
   323             read_amount = read_sizes.draw(strategies.integers(-1, 16384))
       
   324             chunk = reader.read(read_amount)
       
   325 
       
   326             if not chunk and read_amount:
       
   327                 break
       
   328 
       
   329             chunks.append(chunk)
       
   330 
       
   331         self.assertEqual(source.getvalue(), b''.join(chunks))
       
   332 
   102 
   333 
   103 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
   334 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set')
   104 @make_cffi
   335 @make_cffi
   105 class TestDecompressor_stream_writer_fuzzing(unittest.TestCase):
   336 class TestDecompressor_stream_writer_fuzzing(unittest.TestCase):
   106     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
   337     @hypothesis.given(original=strategies.sampled_from(random_input_data()),
   111         cctx = zstd.ZstdCompressor(level=level)
   342         cctx = zstd.ZstdCompressor(level=level)
   112         frame = cctx.compress(original)
   343         frame = cctx.compress(original)
   113 
   344 
   114         dctx = zstd.ZstdDecompressor()
   345         dctx = zstd.ZstdDecompressor()
   115         source = io.BytesIO(frame)
   346         source = io.BytesIO(frame)
   116         dest = io.BytesIO()
   347         dest = NonClosingBytesIO()
   117 
   348 
   118         with dctx.stream_writer(dest, write_size=write_size) as decompressor:
   349         with dctx.stream_writer(dest, write_size=write_size) as decompressor:
   119             while True:
   350             while True:
   120                 input_size = input_sizes.draw(strategies.integers(1, 4096))
   351                 input_size = input_sizes.draw(strategies.integers(1, 4096))
   121                 chunk = source.read(input_size)
   352                 chunk = source.read(input_size)
   232         cctx = zstd.ZstdCompressor(level=1,
   463         cctx = zstd.ZstdCompressor(level=1,
   233                                    write_content_size=True,
   464                                    write_content_size=True,
   234                                    write_checksum=True,
   465                                    write_checksum=True,
   235                                    **kwargs)
   466                                    **kwargs)
   236 
   467 
       
   468         if not hasattr(cctx, 'multi_compress_to_buffer'):
       
   469             self.skipTest('multi_compress_to_buffer not available')
       
   470 
   237         frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)
   471         frames_buffer = cctx.multi_compress_to_buffer(original, threads=-1)
   238 
   472 
   239         dctx = zstd.ZstdDecompressor(**kwargs)
   473         dctx = zstd.ZstdDecompressor(**kwargs)
   240 
       
   241         result = dctx.multi_decompress_to_buffer(frames_buffer)
   474         result = dctx.multi_decompress_to_buffer(frames_buffer)
   242 
   475 
   243         self.assertEqual(len(result), len(original))
   476         self.assertEqual(len(result), len(original))
   244         for i, frame in enumerate(result):
   477         for i, frame in enumerate(result):
   245             self.assertEqual(frame.tobytes(), original[i])
   478             self.assertEqual(frame.tobytes(), original[i])