Mercurial > public > mercurial-scm > hg
comparison contrib/python-zstandard/tests/test_decompressor.py @ 37495:b1fb341d8a61
zstandard: vendor python-zstandard 0.9.0
This was just released. It features a number of goodies. More info at
https://gregoryszorc.com/blog/2018/04/09/release-of-python-zstandard-0.9/.
The clang-format ignore list was updated to reflect the new set of
source files.
The project contains a vendored copy of zstandard 1.3.4. The old
version was 1.1.3. One of the changes between those versions is that
zstandard is now dual licensed BSD + GPLv2 and the patent rights grant
has been removed. Good riddance.
The API should be backwards compatible, so no changes in core
should be needed. However, there were a number of changes in the
library that we'll want to adapt to. Those will be addressed in
subsequent commits.
Differential Revision: https://phab.mercurial-scm.org/D3198
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Mon, 09 Apr 2018 10:13:29 -0700 |
parents | e0dc40530c5a |
children | 73fef626dae3 |
comparison
equal
deleted
inserted
replaced
37494:1ce7a55b09d1 | 37495:b1fb341d8a61 |
---|---|
1 import io | 1 import io |
2 import os | |
2 import random | 3 import random |
3 import struct | 4 import struct |
4 import sys | 5 import sys |
5 | 6 import unittest |
6 try: | 7 |
7 import unittest2 as unittest | 8 import zstandard as zstd |
8 except ImportError: | |
9 import unittest | |
10 | |
11 import zstd | |
12 | 9 |
13 from .common import ( | 10 from .common import ( |
11 generate_samples, | |
14 make_cffi, | 12 make_cffi, |
15 OpCountingBytesIO, | 13 OpCountingBytesIO, |
16 ) | 14 ) |
17 | 15 |
18 | 16 |
21 else: | 19 else: |
22 next = lambda it: it.next() | 20 next = lambda it: it.next() |
23 | 21 |
24 | 22 |
25 @make_cffi | 23 @make_cffi |
24 class TestFrameHeaderSize(unittest.TestCase): | |
25 def test_empty(self): | |
26 with self.assertRaisesRegexp( | |
27 zstd.ZstdError, 'could not determine frame header size: Src size ' | |
28 'is incorrect'): | |
29 zstd.frame_header_size(b'') | |
30 | |
31 def test_too_small(self): | |
32 with self.assertRaisesRegexp( | |
33 zstd.ZstdError, 'could not determine frame header size: Src size ' | |
34 'is incorrect'): | |
35 zstd.frame_header_size(b'foob') | |
36 | |
37 def test_basic(self): | |
38 # It doesn't matter that it isn't a valid frame. | |
39 self.assertEqual(zstd.frame_header_size(b'long enough but no magic'), 6) | |
40 | |
41 | |
42 @make_cffi | |
43 class TestFrameContentSize(unittest.TestCase): | |
44 def test_empty(self): | |
45 with self.assertRaisesRegexp(zstd.ZstdError, | |
46 'error when determining content size'): | |
47 zstd.frame_content_size(b'') | |
48 | |
49 def test_too_small(self): | |
50 with self.assertRaisesRegexp(zstd.ZstdError, | |
51 'error when determining content size'): | |
52 zstd.frame_content_size(b'foob') | |
53 | |
54 def test_bad_frame(self): | |
55 with self.assertRaisesRegexp(zstd.ZstdError, | |
56 'error when determining content size'): | |
57 zstd.frame_content_size(b'invalid frame header') | |
58 | |
59 def test_unknown(self): | |
60 cctx = zstd.ZstdCompressor(write_content_size=False) | |
61 frame = cctx.compress(b'foobar') | |
62 | |
63 self.assertEqual(zstd.frame_content_size(frame), -1) | |
64 | |
65 def test_empty(self): | |
66 cctx = zstd.ZstdCompressor() | |
67 frame = cctx.compress(b'') | |
68 | |
69 self.assertEqual(zstd.frame_content_size(frame), 0) | |
70 | |
71 def test_basic(self): | |
72 cctx = zstd.ZstdCompressor() | |
73 frame = cctx.compress(b'foobar') | |
74 | |
75 self.assertEqual(zstd.frame_content_size(frame), 6) | |
76 | |
77 | |
78 @make_cffi | |
79 class TestDecompressor(unittest.TestCase): | |
80 def test_memory_size(self): | |
81 dctx = zstd.ZstdDecompressor() | |
82 | |
83 self.assertGreater(dctx.memory_size(), 100) | |
84 | |
85 | |
86 @make_cffi | |
26 class TestDecompressor_decompress(unittest.TestCase): | 87 class TestDecompressor_decompress(unittest.TestCase): |
27 def test_empty_input(self): | 88 def test_empty_input(self): |
28 dctx = zstd.ZstdDecompressor() | 89 dctx = zstd.ZstdDecompressor() |
29 | 90 |
30 with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): | 91 with self.assertRaisesRegexp(zstd.ZstdError, 'error determining content size from frame header'): |
31 dctx.decompress(b'') | 92 dctx.decompress(b'') |
32 | 93 |
33 def test_invalid_input(self): | 94 def test_invalid_input(self): |
34 dctx = zstd.ZstdDecompressor() | 95 dctx = zstd.ZstdDecompressor() |
35 | 96 |
36 with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): | 97 with self.assertRaisesRegexp(zstd.ZstdError, 'error determining content size from frame header'): |
37 dctx.decompress(b'foobar') | 98 dctx.decompress(b'foobar') |
99 | |
100 def test_input_types(self): | |
101 cctx = zstd.ZstdCompressor(level=1) | |
102 compressed = cctx.compress(b'foo') | |
103 | |
104 mutable_array = bytearray(len(compressed)) | |
105 mutable_array[:] = compressed | |
106 | |
107 sources = [ | |
108 memoryview(compressed), | |
109 bytearray(compressed), | |
110 mutable_array, | |
111 ] | |
112 | |
113 dctx = zstd.ZstdDecompressor() | |
114 for source in sources: | |
115 self.assertEqual(dctx.decompress(source), b'foo') | |
38 | 116 |
39 def test_no_content_size_in_frame(self): | 117 def test_no_content_size_in_frame(self): |
40 cctx = zstd.ZstdCompressor(write_content_size=False) | 118 cctx = zstd.ZstdCompressor(write_content_size=False) |
41 compressed = cctx.compress(b'foobar') | 119 compressed = cctx.compress(b'foobar') |
42 | 120 |
43 dctx = zstd.ZstdDecompressor() | 121 dctx = zstd.ZstdDecompressor() |
44 with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'): | 122 with self.assertRaisesRegexp(zstd.ZstdError, 'could not determine content size in frame header'): |
45 dctx.decompress(compressed) | 123 dctx.decompress(compressed) |
46 | 124 |
47 def test_content_size_present(self): | 125 def test_content_size_present(self): |
48 cctx = zstd.ZstdCompressor(write_content_size=True) | 126 cctx = zstd.ZstdCompressor() |
49 compressed = cctx.compress(b'foobar') | 127 compressed = cctx.compress(b'foobar') |
50 | 128 |
51 dctx = zstd.ZstdDecompressor() | 129 dctx = zstd.ZstdDecompressor() |
52 decompressed = dctx.decompress(compressed) | 130 decompressed = dctx.decompress(compressed) |
53 self.assertEqual(decompressed, b'foobar') | 131 self.assertEqual(decompressed, b'foobar') |
132 | |
133 def test_empty_roundtrip(self): | |
134 cctx = zstd.ZstdCompressor() | |
135 compressed = cctx.compress(b'') | |
136 | |
137 dctx = zstd.ZstdDecompressor() | |
138 decompressed = dctx.decompress(compressed) | |
139 | |
140 self.assertEqual(decompressed, b'') | |
54 | 141 |
55 def test_max_output_size(self): | 142 def test_max_output_size(self): |
56 cctx = zstd.ZstdCompressor(write_content_size=False) | 143 cctx = zstd.ZstdCompressor(write_content_size=False) |
57 source = b'foobar' * 256 | 144 source = b'foobar' * 256 |
58 compressed = cctx.compress(source) | 145 compressed = cctx.compress(source) |
61 # Will fit into buffer exactly the size of input. | 148 # Will fit into buffer exactly the size of input. |
62 decompressed = dctx.decompress(compressed, max_output_size=len(source)) | 149 decompressed = dctx.decompress(compressed, max_output_size=len(source)) |
63 self.assertEqual(decompressed, source) | 150 self.assertEqual(decompressed, source) |
64 | 151 |
65 # Input size - 1 fails | 152 # Input size - 1 fails |
66 with self.assertRaisesRegexp(zstd.ZstdError, 'Destination buffer is too small'): | 153 with self.assertRaisesRegexp(zstd.ZstdError, |
154 'decompression error: did not decompress full frame'): | |
67 dctx.decompress(compressed, max_output_size=len(source) - 1) | 155 dctx.decompress(compressed, max_output_size=len(source) - 1) |
68 | 156 |
69 # Input size + 1 works | 157 # Input size + 1 works |
70 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) | 158 decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) |
71 self.assertEqual(decompressed, source) | 159 self.assertEqual(decompressed, source) |
92 samples.append(b'foobar' * 64) | 180 samples.append(b'foobar' * 64) |
93 | 181 |
94 d = zstd.train_dictionary(8192, samples) | 182 d = zstd.train_dictionary(8192, samples) |
95 | 183 |
96 orig = b'foobar' * 16384 | 184 orig = b'foobar' * 16384 |
97 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True) | 185 cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
98 compressed = cctx.compress(orig) | 186 compressed = cctx.compress(orig) |
99 | 187 |
100 dctx = zstd.ZstdDecompressor(dict_data=d) | 188 dctx = zstd.ZstdDecompressor(dict_data=d) |
101 decompressed = dctx.decompress(compressed) | 189 decompressed = dctx.decompress(compressed) |
102 | 190 |
111 | 199 |
112 d = zstd.train_dictionary(8192, samples) | 200 d = zstd.train_dictionary(8192, samples) |
113 | 201 |
114 sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192) | 202 sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192) |
115 compressed = [] | 203 compressed = [] |
116 cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True) | 204 cctx = zstd.ZstdCompressor(level=1, dict_data=d) |
117 for source in sources: | 205 for source in sources: |
118 compressed.append(cctx.compress(source)) | 206 compressed.append(cctx.compress(source)) |
119 | 207 |
120 dctx = zstd.ZstdDecompressor(dict_data=d) | 208 dctx = zstd.ZstdDecompressor(dict_data=d) |
121 for i in range(len(sources)): | 209 for i in range(len(sources)): |
122 decompressed = dctx.decompress(compressed[i]) | 210 decompressed = dctx.decompress(compressed[i]) |
123 self.assertEqual(decompressed, sources[i]) | 211 self.assertEqual(decompressed, sources[i]) |
212 | |
213 def test_max_window_size(self): | |
214 with open(__file__, 'rb') as fh: | |
215 source = fh.read() | |
216 | |
217 # If we write a content size, the decompressor engages single pass | |
218 # mode and the window size doesn't come into play. | |
219 cctx = zstd.ZstdCompressor(write_content_size=False) | |
220 frame = cctx.compress(source) | |
221 | |
222 dctx = zstd.ZstdDecompressor(max_window_size=1) | |
223 | |
224 with self.assertRaisesRegexp( | |
225 zstd.ZstdError, 'decompression error: Frame requires too much memory'): | |
226 dctx.decompress(frame, max_output_size=len(source)) | |
124 | 227 |
125 | 228 |
126 @make_cffi | 229 @make_cffi |
127 class TestDecompressor_copy_stream(unittest.TestCase): | 230 class TestDecompressor_copy_stream(unittest.TestCase): |
128 def test_no_read(self): | 231 def test_no_read(self): |
184 self.assertEqual(source._read_count, len(source.getvalue()) + 1) | 287 self.assertEqual(source._read_count, len(source.getvalue()) + 1) |
185 self.assertEqual(dest._write_count, len(dest.getvalue())) | 288 self.assertEqual(dest._write_count, len(dest.getvalue())) |
186 | 289 |
187 | 290 |
188 @make_cffi | 291 @make_cffi |
292 class TestDecompressor_stream_reader(unittest.TestCase): | |
293 def test_context_manager(self): | |
294 dctx = zstd.ZstdDecompressor() | |
295 | |
296 reader = dctx.stream_reader(b'foo') | |
297 with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'): | |
298 reader.read(1) | |
299 | |
300 with dctx.stream_reader(b'foo') as reader: | |
301 with self.assertRaisesRegexp(ValueError, 'cannot __enter__ multiple times'): | |
302 with reader as reader2: | |
303 pass | |
304 | |
305 def test_not_implemented(self): | |
306 dctx = zstd.ZstdDecompressor() | |
307 | |
308 with dctx.stream_reader(b'foo') as reader: | |
309 with self.assertRaises(NotImplementedError): | |
310 reader.readline() | |
311 | |
312 with self.assertRaises(NotImplementedError): | |
313 reader.readlines() | |
314 | |
315 with self.assertRaises(NotImplementedError): | |
316 reader.readall() | |
317 | |
318 with self.assertRaises(NotImplementedError): | |
319 iter(reader) | |
320 | |
321 with self.assertRaises(NotImplementedError): | |
322 next(reader) | |
323 | |
324 with self.assertRaises(io.UnsupportedOperation): | |
325 reader.write(b'foo') | |
326 | |
327 with self.assertRaises(io.UnsupportedOperation): | |
328 reader.writelines([]) | |
329 | |
330 def test_constant_methods(self): | |
331 dctx = zstd.ZstdDecompressor() | |
332 | |
333 with dctx.stream_reader(b'foo') as reader: | |
334 self.assertTrue(reader.readable()) | |
335 self.assertFalse(reader.writable()) | |
336 self.assertTrue(reader.seekable()) | |
337 self.assertFalse(reader.isatty()) | |
338 self.assertIsNone(reader.flush()) | |
339 | |
340 def test_read_closed(self): | |
341 dctx = zstd.ZstdDecompressor() | |
342 | |
343 with dctx.stream_reader(b'foo') as reader: | |
344 reader.close() | |
345 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | |
346 reader.read(1) | |
347 | |
348 def test_bad_read_size(self): | |
349 dctx = zstd.ZstdDecompressor() | |
350 | |
351 with dctx.stream_reader(b'foo') as reader: | |
352 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): | |
353 reader.read(-1) | |
354 | |
355 with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): | |
356 reader.read(0) | |
357 | |
358 def test_read_buffer(self): | |
359 cctx = zstd.ZstdCompressor() | |
360 | |
361 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) | |
362 frame = cctx.compress(source) | |
363 | |
364 dctx = zstd.ZstdDecompressor() | |
365 | |
366 with dctx.stream_reader(frame) as reader: | |
367 self.assertEqual(reader.tell(), 0) | |
368 | |
369 # We should get entire frame in one read. | |
370 result = reader.read(8192) | |
371 self.assertEqual(result, source) | |
372 self.assertEqual(reader.tell(), len(source)) | |
373 | |
374 # Read after EOF should return empty bytes. | |
375 self.assertEqual(reader.read(), b'') | |
376 self.assertEqual(reader.tell(), len(result)) | |
377 | |
378 self.assertTrue(reader.closed()) | |
379 | |
380 def test_read_buffer_small_chunks(self): | |
381 cctx = zstd.ZstdCompressor() | |
382 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) | |
383 frame = cctx.compress(source) | |
384 | |
385 dctx = zstd.ZstdDecompressor() | |
386 chunks = [] | |
387 | |
388 with dctx.stream_reader(frame, read_size=1) as reader: | |
389 while True: | |
390 chunk = reader.read(1) | |
391 if not chunk: | |
392 break | |
393 | |
394 chunks.append(chunk) | |
395 self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
396 | |
397 self.assertEqual(b''.join(chunks), source) | |
398 | |
399 def test_read_stream(self): | |
400 cctx = zstd.ZstdCompressor() | |
401 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) | |
402 frame = cctx.compress(source) | |
403 | |
404 dctx = zstd.ZstdDecompressor() | |
405 with dctx.stream_reader(io.BytesIO(frame)) as reader: | |
406 self.assertEqual(reader.tell(), 0) | |
407 | |
408 chunk = reader.read(8192) | |
409 self.assertEqual(chunk, source) | |
410 self.assertEqual(reader.tell(), len(source)) | |
411 self.assertEqual(reader.read(), b'') | |
412 self.assertEqual(reader.tell(), len(source)) | |
413 | |
414 def test_read_stream_small_chunks(self): | |
415 cctx = zstd.ZstdCompressor() | |
416 source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) | |
417 frame = cctx.compress(source) | |
418 | |
419 dctx = zstd.ZstdDecompressor() | |
420 chunks = [] | |
421 | |
422 with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: | |
423 while True: | |
424 chunk = reader.read(1) | |
425 if not chunk: | |
426 break | |
427 | |
428 chunks.append(chunk) | |
429 self.assertEqual(reader.tell(), sum(map(len, chunks))) | |
430 | |
431 self.assertEqual(b''.join(chunks), source) | |
432 | |
433 def test_read_after_exit(self): | |
434 cctx = zstd.ZstdCompressor() | |
435 frame = cctx.compress(b'foo' * 60) | |
436 | |
437 dctx = zstd.ZstdDecompressor() | |
438 | |
439 with dctx.stream_reader(frame) as reader: | |
440 while reader.read(16): | |
441 pass | |
442 | |
443 with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'): | |
444 reader.read(10) | |
445 | |
446 def test_illegal_seeks(self): | |
447 cctx = zstd.ZstdCompressor() | |
448 frame = cctx.compress(b'foo' * 60) | |
449 | |
450 dctx = zstd.ZstdDecompressor() | |
451 | |
452 with dctx.stream_reader(frame) as reader: | |
453 with self.assertRaisesRegexp(ValueError, | |
454 'cannot seek to negative position'): | |
455 reader.seek(-1, os.SEEK_SET) | |
456 | |
457 reader.read(1) | |
458 | |
459 with self.assertRaisesRegexp( | |
460 ValueError, 'cannot seek zstd decompression stream backwards'): | |
461 reader.seek(0, os.SEEK_SET) | |
462 | |
463 with self.assertRaisesRegexp( | |
464 ValueError, 'cannot seek zstd decompression stream backwards'): | |
465 reader.seek(-1, os.SEEK_CUR) | |
466 | |
467 with self.assertRaisesRegexp( | |
468 ValueError, | |
469 'zstd decompression streams cannot be seeked with SEEK_END'): | |
470 reader.seek(0, os.SEEK_END) | |
471 | |
472 reader.close() | |
473 | |
474 with self.assertRaisesRegexp(ValueError, 'stream is closed'): | |
475 reader.seek(4, os.SEEK_SET) | |
476 | |
477 with self.assertRaisesRegexp( | |
478 zstd.ZstdError, 'seek\(\) must be called from an active context'): | |
479 reader.seek(0) | |
480 | |
481 def test_seek(self): | |
482 source = b'foobar' * 60 | |
483 cctx = zstd.ZstdCompressor() | |
484 frame = cctx.compress(source) | |
485 | |
486 dctx = zstd.ZstdDecompressor() | |
487 | |
488 with dctx.stream_reader(frame) as reader: | |
489 reader.seek(3) | |
490 self.assertEqual(reader.read(3), b'bar') | |
491 | |
492 reader.seek(4, os.SEEK_CUR) | |
493 self.assertEqual(reader.read(2), b'ar') | |
494 | |
495 | |
496 @make_cffi | |
189 class TestDecompressor_decompressobj(unittest.TestCase): | 497 class TestDecompressor_decompressobj(unittest.TestCase): |
190 def test_simple(self): | 498 def test_simple(self): |
191 data = zstd.ZstdCompressor(level=1).compress(b'foobar') | 499 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
192 | 500 |
193 dctx = zstd.ZstdDecompressor() | 501 dctx = zstd.ZstdDecompressor() |
194 dobj = dctx.decompressobj() | 502 dobj = dctx.decompressobj() |
195 self.assertEqual(dobj.decompress(data), b'foobar') | 503 self.assertEqual(dobj.decompress(data), b'foobar') |
196 | 504 |
505 def test_input_types(self): | |
506 compressed = zstd.ZstdCompressor(level=1).compress(b'foo') | |
507 | |
508 dctx = zstd.ZstdDecompressor() | |
509 | |
510 mutable_array = bytearray(len(compressed)) | |
511 mutable_array[:] = compressed | |
512 | |
513 sources = [ | |
514 memoryview(compressed), | |
515 bytearray(compressed), | |
516 mutable_array, | |
517 ] | |
518 | |
519 for source in sources: | |
520 dobj = dctx.decompressobj() | |
521 self.assertEqual(dobj.decompress(source), b'foo') | |
522 | |
197 def test_reuse(self): | 523 def test_reuse(self): |
198 data = zstd.ZstdCompressor(level=1).compress(b'foobar') | 524 data = zstd.ZstdCompressor(level=1).compress(b'foobar') |
199 | 525 |
200 dctx = zstd.ZstdDecompressor() | 526 dctx = zstd.ZstdDecompressor() |
201 dobj = dctx.decompressobj() | 527 dobj = dctx.decompressobj() |
202 dobj.decompress(data) | 528 dobj.decompress(data) |
203 | 529 |
204 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): | 530 with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'): |
205 dobj.decompress(data) | 531 dobj.decompress(data) |
206 | 532 |
533 def test_bad_write_size(self): | |
534 dctx = zstd.ZstdDecompressor() | |
535 | |
536 with self.assertRaisesRegexp(ValueError, 'write_size must be positive'): | |
537 dctx.decompressobj(write_size=0) | |
538 | |
539 def test_write_size(self): | |
540 source = b'foo' * 64 + b'bar' * 128 | |
541 data = zstd.ZstdCompressor(level=1).compress(source) | |
542 | |
543 dctx = zstd.ZstdDecompressor() | |
544 | |
545 for i in range(128): | |
546 dobj = dctx.decompressobj(write_size=i + 1) | |
547 self.assertEqual(dobj.decompress(data), source) | |
207 | 548 |
208 def decompress_via_writer(data): | 549 def decompress_via_writer(data): |
209 buffer = io.BytesIO() | 550 buffer = io.BytesIO() |
210 dctx = zstd.ZstdDecompressor() | 551 dctx = zstd.ZstdDecompressor() |
211 with dctx.write_to(buffer) as decompressor: | 552 with dctx.stream_writer(buffer) as decompressor: |
212 decompressor.write(data) | 553 decompressor.write(data) |
213 return buffer.getvalue() | 554 return buffer.getvalue() |
214 | 555 |
215 | 556 |
216 @make_cffi | 557 @make_cffi |
217 class TestDecompressor_write_to(unittest.TestCase): | 558 class TestDecompressor_stream_writer(unittest.TestCase): |
218 def test_empty_roundtrip(self): | 559 def test_empty_roundtrip(self): |
219 cctx = zstd.ZstdCompressor() | 560 cctx = zstd.ZstdCompressor() |
220 empty = cctx.compress(b'') | 561 empty = cctx.compress(b'') |
221 self.assertEqual(decompress_via_writer(empty), b'') | 562 self.assertEqual(decompress_via_writer(empty), b'') |
563 | |
564 def test_input_types(self): | |
565 cctx = zstd.ZstdCompressor(level=1) | |
566 compressed = cctx.compress(b'foo') | |
567 | |
568 mutable_array = bytearray(len(compressed)) | |
569 mutable_array[:] = compressed | |
570 | |
571 sources = [ | |
572 memoryview(compressed), | |
573 bytearray(compressed), | |
574 mutable_array, | |
575 ] | |
576 | |
577 dctx = zstd.ZstdDecompressor() | |
578 for source in sources: | |
579 buffer = io.BytesIO() | |
580 with dctx.stream_writer(buffer) as decompressor: | |
581 decompressor.write(source) | |
582 | |
583 self.assertEqual(buffer.getvalue(), b'foo') | |
222 | 584 |
223 def test_large_roundtrip(self): | 585 def test_large_roundtrip(self): |
224 chunks = [] | 586 chunks = [] |
225 for i in range(255): | 587 for i in range(255): |
226 chunks.append(struct.Struct('>B').pack(i) * 16384) | 588 chunks.append(struct.Struct('>B').pack(i) * 16384) |
240 cctx = zstd.ZstdCompressor() | 602 cctx = zstd.ZstdCompressor() |
241 compressed = cctx.compress(orig) | 603 compressed = cctx.compress(orig) |
242 | 604 |
243 buffer = io.BytesIO() | 605 buffer = io.BytesIO() |
244 dctx = zstd.ZstdDecompressor() | 606 dctx = zstd.ZstdDecompressor() |
245 with dctx.write_to(buffer) as decompressor: | 607 with dctx.stream_writer(buffer) as decompressor: |
246 pos = 0 | 608 pos = 0 |
247 while pos < len(compressed): | 609 while pos < len(compressed): |
248 pos2 = pos + 8192 | 610 pos2 = pos + 8192 |
249 decompressor.write(compressed[pos:pos2]) | 611 decompressor.write(compressed[pos:pos2]) |
250 pos += 8192 | 612 pos += 8192 |
260 d = zstd.train_dictionary(8192, samples) | 622 d = zstd.train_dictionary(8192, samples) |
261 | 623 |
262 orig = b'foobar' * 16384 | 624 orig = b'foobar' * 16384 |
263 buffer = io.BytesIO() | 625 buffer = io.BytesIO() |
264 cctx = zstd.ZstdCompressor(dict_data=d) | 626 cctx = zstd.ZstdCompressor(dict_data=d) |
265 with cctx.write_to(buffer) as compressor: | 627 with cctx.stream_writer(buffer) as compressor: |
266 self.assertEqual(compressor.write(orig), 1544) | 628 self.assertEqual(compressor.write(orig), 0) |
267 | 629 |
268 compressed = buffer.getvalue() | 630 compressed = buffer.getvalue() |
269 buffer = io.BytesIO() | 631 buffer = io.BytesIO() |
270 | 632 |
271 dctx = zstd.ZstdDecompressor(dict_data=d) | 633 dctx = zstd.ZstdDecompressor(dict_data=d) |
272 with dctx.write_to(buffer) as decompressor: | 634 with dctx.stream_writer(buffer) as decompressor: |
273 self.assertEqual(decompressor.write(compressed), len(orig)) | 635 self.assertEqual(decompressor.write(compressed), len(orig)) |
274 | 636 |
275 self.assertEqual(buffer.getvalue(), orig) | 637 self.assertEqual(buffer.getvalue(), orig) |
276 | 638 |
277 def test_memory_size(self): | 639 def test_memory_size(self): |
278 dctx = zstd.ZstdDecompressor() | 640 dctx = zstd.ZstdDecompressor() |
279 buffer = io.BytesIO() | 641 buffer = io.BytesIO() |
280 with dctx.write_to(buffer) as decompressor: | 642 with dctx.stream_writer(buffer) as decompressor: |
281 size = decompressor.memory_size() | 643 size = decompressor.memory_size() |
282 | 644 |
283 self.assertGreater(size, 100000) | 645 self.assertGreater(size, 100000) |
284 | 646 |
285 def test_write_size(self): | 647 def test_write_size(self): |
286 source = zstd.ZstdCompressor().compress(b'foobarfoobar') | 648 source = zstd.ZstdCompressor().compress(b'foobarfoobar') |
287 dest = OpCountingBytesIO() | 649 dest = OpCountingBytesIO() |
288 dctx = zstd.ZstdDecompressor() | 650 dctx = zstd.ZstdDecompressor() |
289 with dctx.write_to(dest, write_size=1) as decompressor: | 651 with dctx.stream_writer(dest, write_size=1) as decompressor: |
290 s = struct.Struct('>B') | 652 s = struct.Struct('>B') |
291 for c in source: | 653 for c in source: |
292 if not isinstance(c, str): | 654 if not isinstance(c, str): |
293 c = s.pack(c) | 655 c = s.pack(c) |
294 decompressor.write(c) | 656 decompressor.write(c) |
296 self.assertEqual(dest.getvalue(), b'foobarfoobar') | 658 self.assertEqual(dest.getvalue(), b'foobarfoobar') |
297 self.assertEqual(dest._write_count, len(dest.getvalue())) | 659 self.assertEqual(dest._write_count, len(dest.getvalue())) |
298 | 660 |
299 | 661 |
300 @make_cffi | 662 @make_cffi |
301 class TestDecompressor_read_from(unittest.TestCase): | 663 class TestDecompressor_read_to_iter(unittest.TestCase): |
302 def test_type_validation(self): | 664 def test_type_validation(self): |
303 dctx = zstd.ZstdDecompressor() | 665 dctx = zstd.ZstdDecompressor() |
304 | 666 |
305 # Object with read() works. | 667 # Object with read() works. |
306 dctx.read_from(io.BytesIO()) | 668 dctx.read_to_iter(io.BytesIO()) |
307 | 669 |
308 # Buffer protocol works. | 670 # Buffer protocol works. |
309 dctx.read_from(b'foobar') | 671 dctx.read_to_iter(b'foobar') |
310 | 672 |
311 with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'): | 673 with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'): |
312 b''.join(dctx.read_from(True)) | 674 b''.join(dctx.read_to_iter(True)) |
313 | 675 |
314 def test_empty_input(self): | 676 def test_empty_input(self): |
315 dctx = zstd.ZstdDecompressor() | 677 dctx = zstd.ZstdDecompressor() |
316 | 678 |
317 source = io.BytesIO() | 679 source = io.BytesIO() |
318 it = dctx.read_from(source) | 680 it = dctx.read_to_iter(source) |
319 # TODO this is arguably wrong. Should get an error about missing frame foo. | 681 # TODO this is arguably wrong. Should get an error about missing frame foo. |
320 with self.assertRaises(StopIteration): | 682 with self.assertRaises(StopIteration): |
321 next(it) | 683 next(it) |
322 | 684 |
323 it = dctx.read_from(b'') | 685 it = dctx.read_to_iter(b'') |
324 with self.assertRaises(StopIteration): | 686 with self.assertRaises(StopIteration): |
325 next(it) | 687 next(it) |
326 | 688 |
327 def test_invalid_input(self): | 689 def test_invalid_input(self): |
328 dctx = zstd.ZstdDecompressor() | 690 dctx = zstd.ZstdDecompressor() |
329 | 691 |
330 source = io.BytesIO(b'foobar') | 692 source = io.BytesIO(b'foobar') |
331 it = dctx.read_from(source) | 693 it = dctx.read_to_iter(source) |
332 with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): | 694 with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): |
333 next(it) | 695 next(it) |
334 | 696 |
335 it = dctx.read_from(b'foobar') | 697 it = dctx.read_to_iter(b'foobar') |
336 with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): | 698 with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'): |
337 next(it) | 699 next(it) |
338 | 700 |
339 def test_empty_roundtrip(self): | 701 def test_empty_roundtrip(self): |
340 cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | 702 cctx = zstd.ZstdCompressor(level=1, write_content_size=False) |
342 | 704 |
343 source = io.BytesIO(empty) | 705 source = io.BytesIO(empty) |
344 source.seek(0) | 706 source.seek(0) |
345 | 707 |
346 dctx = zstd.ZstdDecompressor() | 708 dctx = zstd.ZstdDecompressor() |
347 it = dctx.read_from(source) | 709 it = dctx.read_to_iter(source) |
348 | 710 |
349 # No chunks should be emitted since there is no data. | 711 # No chunks should be emitted since there is no data. |
350 with self.assertRaises(StopIteration): | 712 with self.assertRaises(StopIteration): |
351 next(it) | 713 next(it) |
352 | 714 |
356 | 718 |
357 def test_skip_bytes_too_large(self): | 719 def test_skip_bytes_too_large(self): |
358 dctx = zstd.ZstdDecompressor() | 720 dctx = zstd.ZstdDecompressor() |
359 | 721 |
360 with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'): | 722 with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'): |
361 b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1)) | 723 b''.join(dctx.read_to_iter(b'', skip_bytes=1, read_size=1)) |
362 | 724 |
363 with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'): | 725 with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'): |
364 b''.join(dctx.read_from(b'foobar', skip_bytes=10)) | 726 b''.join(dctx.read_to_iter(b'foobar', skip_bytes=10)) |
365 | 727 |
366 def test_skip_bytes(self): | 728 def test_skip_bytes(self): |
367 cctx = zstd.ZstdCompressor(write_content_size=False) | 729 cctx = zstd.ZstdCompressor(write_content_size=False) |
368 compressed = cctx.compress(b'foobar') | 730 compressed = cctx.compress(b'foobar') |
369 | 731 |
370 dctx = zstd.ZstdDecompressor() | 732 dctx = zstd.ZstdDecompressor() |
371 output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3)) | 733 output = b''.join(dctx.read_to_iter(b'hdr' + compressed, skip_bytes=3)) |
372 self.assertEqual(output, b'foobar') | 734 self.assertEqual(output, b'foobar') |
373 | 735 |
374 def test_large_output(self): | 736 def test_large_output(self): |
375 source = io.BytesIO() | 737 source = io.BytesIO() |
376 source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) | 738 source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) |
380 cctx = zstd.ZstdCompressor(level=1) | 742 cctx = zstd.ZstdCompressor(level=1) |
381 compressed = io.BytesIO(cctx.compress(source.getvalue())) | 743 compressed = io.BytesIO(cctx.compress(source.getvalue())) |
382 compressed.seek(0) | 744 compressed.seek(0) |
383 | 745 |
384 dctx = zstd.ZstdDecompressor() | 746 dctx = zstd.ZstdDecompressor() |
385 it = dctx.read_from(compressed) | 747 it = dctx.read_to_iter(compressed) |
386 | 748 |
387 chunks = [] | 749 chunks = [] |
388 chunks.append(next(it)) | 750 chunks.append(next(it)) |
389 chunks.append(next(it)) | 751 chunks.append(next(it)) |
390 | 752 |
393 | 755 |
394 decompressed = b''.join(chunks) | 756 decompressed = b''.join(chunks) |
395 self.assertEqual(decompressed, source.getvalue()) | 757 self.assertEqual(decompressed, source.getvalue()) |
396 | 758 |
397 # And again with buffer protocol. | 759 # And again with buffer protocol. |
398 it = dctx.read_from(compressed.getvalue()) | 760 it = dctx.read_to_iter(compressed.getvalue()) |
399 chunks = [] | 761 chunks = [] |
400 chunks.append(next(it)) | 762 chunks.append(next(it)) |
401 chunks.append(next(it)) | 763 chunks.append(next(it)) |
402 | 764 |
403 with self.assertRaises(StopIteration): | 765 with self.assertRaises(StopIteration): |
404 next(it) | 766 next(it) |
405 | 767 |
406 decompressed = b''.join(chunks) | 768 decompressed = b''.join(chunks) |
407 self.assertEqual(decompressed, source.getvalue()) | 769 self.assertEqual(decompressed, source.getvalue()) |
408 | 770 |
771 @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') | |
409 def test_large_input(self): | 772 def test_large_input(self): |
410 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) | 773 bytes = list(struct.Struct('>B').pack(i) for i in range(256)) |
411 compressed = io.BytesIO() | 774 compressed = io.BytesIO() |
412 input_size = 0 | 775 input_size = 0 |
413 cctx = zstd.ZstdCompressor(level=1) | 776 cctx = zstd.ZstdCompressor(level=1) |
414 with cctx.write_to(compressed) as compressor: | 777 with cctx.stream_writer(compressed) as compressor: |
415 while True: | 778 while True: |
416 compressor.write(random.choice(bytes)) | 779 compressor.write(random.choice(bytes)) |
417 input_size += 1 | 780 input_size += 1 |
418 | 781 |
419 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | 782 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE |
424 compressed.seek(0) | 787 compressed.seek(0) |
425 self.assertGreater(len(compressed.getvalue()), | 788 self.assertGreater(len(compressed.getvalue()), |
426 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) | 789 zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE) |
427 | 790 |
428 dctx = zstd.ZstdDecompressor() | 791 dctx = zstd.ZstdDecompressor() |
429 it = dctx.read_from(compressed) | 792 it = dctx.read_to_iter(compressed) |
430 | 793 |
431 chunks = [] | 794 chunks = [] |
432 chunks.append(next(it)) | 795 chunks.append(next(it)) |
433 chunks.append(next(it)) | 796 chunks.append(next(it)) |
434 chunks.append(next(it)) | 797 chunks.append(next(it)) |
438 | 801 |
439 decompressed = b''.join(chunks) | 802 decompressed = b''.join(chunks) |
440 self.assertEqual(len(decompressed), input_size) | 803 self.assertEqual(len(decompressed), input_size) |
441 | 804 |
442 # And again with buffer protocol. | 805 # And again with buffer protocol. |
443 it = dctx.read_from(compressed.getvalue()) | 806 it = dctx.read_to_iter(compressed.getvalue()) |
444 | 807 |
445 chunks = [] | 808 chunks = [] |
446 chunks.append(next(it)) | 809 chunks.append(next(it)) |
447 chunks.append(next(it)) | 810 chunks.append(next(it)) |
448 chunks.append(next(it)) | 811 chunks.append(next(it)) |
458 cctx = zstd.ZstdCompressor(level=1) | 821 cctx = zstd.ZstdCompressor(level=1) |
459 | 822 |
460 source = io.BytesIO() | 823 source = io.BytesIO() |
461 | 824 |
462 compressed = io.BytesIO() | 825 compressed = io.BytesIO() |
463 with cctx.write_to(compressed) as compressor: | 826 with cctx.stream_writer(compressed) as compressor: |
464 for i in range(256): | 827 for i in range(256): |
465 chunk = b'\0' * 1024 | 828 chunk = b'\0' * 1024 |
466 compressor.write(chunk) | 829 compressor.write(chunk) |
467 source.write(chunk) | 830 source.write(chunk) |
468 | 831 |
471 simple = dctx.decompress(compressed.getvalue(), | 834 simple = dctx.decompress(compressed.getvalue(), |
472 max_output_size=len(source.getvalue())) | 835 max_output_size=len(source.getvalue())) |
473 self.assertEqual(simple, source.getvalue()) | 836 self.assertEqual(simple, source.getvalue()) |
474 | 837 |
475 compressed.seek(0) | 838 compressed.seek(0) |
476 streamed = b''.join(dctx.read_from(compressed)) | 839 streamed = b''.join(dctx.read_to_iter(compressed)) |
477 self.assertEqual(streamed, source.getvalue()) | 840 self.assertEqual(streamed, source.getvalue()) |
478 | 841 |
479 def test_read_write_size(self): | 842 def test_read_write_size(self): |
480 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) | 843 source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar')) |
481 dctx = zstd.ZstdDecompressor() | 844 dctx = zstd.ZstdDecompressor() |
482 for chunk in dctx.read_from(source, read_size=1, write_size=1): | 845 for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): |
483 self.assertEqual(len(chunk), 1) | 846 self.assertEqual(len(chunk), 1) |
484 | 847 |
485 self.assertEqual(source._read_count, len(source.getvalue())) | 848 self.assertEqual(source._read_count, len(source.getvalue())) |
849 | |
850 def test_magic_less(self): | |
851 params = zstd.CompressionParameters.from_level( | |
852 1, format=zstd.FORMAT_ZSTD1_MAGICLESS) | |
853 cctx = zstd.ZstdCompressor(compression_params=params) | |
854 frame = cctx.compress(b'foobar') | |
855 | |
856 self.assertNotEqual(frame[0:4], b'\x28\xb5\x2f\xfd') | |
857 | |
858 dctx = zstd.ZstdDecompressor() | |
859 with self.assertRaisesRegexp( | |
860 zstd.ZstdError, 'error determining content size from frame header'): | |
861 dctx.decompress(frame) | |
862 | |
863 dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) | |
864 res = b''.join(dctx.read_to_iter(frame)) | |
865 self.assertEqual(res, b'foobar') | |
486 | 866 |
487 | 867 |
488 @make_cffi | 868 @make_cffi |
489 class TestDecompressor_content_dict_chain(unittest.TestCase): | 869 class TestDecompressor_content_dict_chain(unittest.TestCase): |
490 def test_bad_inputs_simple(self): | 870 def test_bad_inputs_simple(self): |
509 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | 889 dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) |
510 | 890 |
511 with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'): | 891 with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'): |
512 dctx.decompress_content_dict_chain([b'foo' * 8]) | 892 dctx.decompress_content_dict_chain([b'foo' * 8]) |
513 | 893 |
514 no_size = zstd.ZstdCompressor().compress(b'foo' * 64) | 894 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b'foo' * 64) |
515 | 895 |
516 with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'): | 896 with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'): |
517 dctx.decompress_content_dict_chain([no_size]) | 897 dctx.decompress_content_dict_chain([no_size]) |
518 | 898 |
519 # Corrupt first frame. | 899 # Corrupt first frame. |
520 frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64) | 900 frame = zstd.ZstdCompressor().compress(b'foo' * 64) |
521 frame = frame[0:12] + frame[15:] | 901 frame = frame[0:12] + frame[15:] |
522 with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'): | 902 with self.assertRaisesRegexp(zstd.ZstdError, |
903 'chunk 0 did not decompress full frame'): | |
523 dctx.decompress_content_dict_chain([frame]) | 904 dctx.decompress_content_dict_chain([frame]) |
524 | 905 |
525 def test_bad_subsequent_input(self): | 906 def test_bad_subsequent_input(self): |
526 initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64) | 907 initial = zstd.ZstdCompressor().compress(b'foo' * 64) |
527 | 908 |
528 dctx = zstd.ZstdDecompressor() | 909 dctx = zstd.ZstdDecompressor() |
529 | 910 |
530 with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'): | 911 with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'): |
531 dctx.decompress_content_dict_chain([initial, u'foo']) | 912 dctx.decompress_content_dict_chain([initial, u'foo']) |
537 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | 918 dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) |
538 | 919 |
539 with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'): | 920 with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'): |
540 dctx.decompress_content_dict_chain([initial, b'foo' * 8]) | 921 dctx.decompress_content_dict_chain([initial, b'foo' * 8]) |
541 | 922 |
542 no_size = zstd.ZstdCompressor().compress(b'foo' * 64) | 923 no_size = zstd.ZstdCompressor(write_content_size=False).compress(b'foo' * 64) |
543 | 924 |
544 with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'): | 925 with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'): |
545 dctx.decompress_content_dict_chain([initial, no_size]) | 926 dctx.decompress_content_dict_chain([initial, no_size]) |
546 | 927 |
547 # Corrupt second frame. | 928 # Corrupt second frame. |
548 cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64)) | 929 cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b'foo' * 64)) |
549 frame = cctx.compress(b'bar' * 64) | 930 frame = cctx.compress(b'bar' * 64) |
550 frame = frame[0:12] + frame[15:] | 931 frame = frame[0:12] + frame[15:] |
551 | 932 |
552 with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'): | 933 with self.assertRaisesRegexp(zstd.ZstdError, 'chunk 1 did not decompress full frame'): |
553 dctx.decompress_content_dict_chain([initial, frame]) | 934 dctx.decompress_content_dict_chain([initial, frame]) |
554 | 935 |
555 def test_simple(self): | 936 def test_simple(self): |
556 original = [ | 937 original = [ |
557 b'foo' * 64, | 938 b'foo' * 64, |
560 b'foobaz' * 64, | 941 b'foobaz' * 64, |
561 b'foobarbaz' * 64, | 942 b'foobarbaz' * 64, |
562 ] | 943 ] |
563 | 944 |
564 chunks = [] | 945 chunks = [] |
565 chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0])) | 946 chunks.append(zstd.ZstdCompressor().compress(original[0])) |
566 for i, chunk in enumerate(original[1:]): | 947 for i, chunk in enumerate(original[1:]): |
567 d = zstd.ZstdCompressionDict(original[i]) | 948 d = zstd.ZstdCompressionDict(original[i]) |
568 cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True) | 949 cctx = zstd.ZstdCompressor(dict_data=d) |
569 chunks.append(cctx.compress(chunk)) | 950 chunks.append(cctx.compress(chunk)) |
570 | 951 |
571 for i in range(1, len(original)): | 952 for i in range(1, len(original)): |
572 chain = chunks[0:i] | 953 chain = chunks[0:i] |
573 expected = original[i - 1] | 954 expected = original[i - 1] |
592 | 973 |
593 with self.assertRaisesRegexp(ValueError, 'could not determine decompressed size of item 0'): | 974 with self.assertRaisesRegexp(ValueError, 'could not determine decompressed size of item 0'): |
594 dctx.multi_decompress_to_buffer([b'foobarbaz']) | 975 dctx.multi_decompress_to_buffer([b'foobarbaz']) |
595 | 976 |
596 def test_list_input(self): | 977 def test_list_input(self): |
597 cctx = zstd.ZstdCompressor(write_content_size=True) | 978 cctx = zstd.ZstdCompressor() |
598 | 979 |
599 original = [b'foo' * 4, b'bar' * 6] | 980 original = [b'foo' * 4, b'bar' * 6] |
600 frames = [cctx.compress(d) for d in original] | 981 frames = [cctx.compress(d) for d in original] |
601 | 982 |
602 dctx = zstd.ZstdDecompressor() | 983 dctx = zstd.ZstdDecompressor() |
612 self.assertEqual(len(result[0]), 12) | 993 self.assertEqual(len(result[0]), 12) |
613 self.assertEqual(result[1].offset, 12) | 994 self.assertEqual(result[1].offset, 12) |
614 self.assertEqual(len(result[1]), 18) | 995 self.assertEqual(len(result[1]), 18) |
615 | 996 |
616 def test_list_input_frame_sizes(self): | 997 def test_list_input_frame_sizes(self): |
617 cctx = zstd.ZstdCompressor(write_content_size=False) | 998 cctx = zstd.ZstdCompressor() |
618 | 999 |
619 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] | 1000 original = [b'foo' * 4, b'bar' * 6, b'baz' * 8] |
620 frames = [cctx.compress(d) for d in original] | 1001 frames = [cctx.compress(d) for d in original] |
621 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) | 1002 sizes = struct.pack('=' + 'Q' * len(original), *map(len, original)) |
622 | 1003 |
628 | 1009 |
629 for i, data in enumerate(original): | 1010 for i, data in enumerate(original): |
630 self.assertEqual(result[i].tobytes(), data) | 1011 self.assertEqual(result[i].tobytes(), data) |
631 | 1012 |
632 def test_buffer_with_segments_input(self): | 1013 def test_buffer_with_segments_input(self): |
633 cctx = zstd.ZstdCompressor(write_content_size=True) | 1014 cctx = zstd.ZstdCompressor() |
634 | 1015 |
635 original = [b'foo' * 4, b'bar' * 6] | 1016 original = [b'foo' * 4, b'bar' * 6] |
636 frames = [cctx.compress(d) for d in original] | 1017 frames = [cctx.compress(d) for d in original] |
637 | 1018 |
638 dctx = zstd.ZstdDecompressor() | 1019 dctx = zstd.ZstdDecompressor() |
667 | 1048 |
668 for i, data in enumerate(original): | 1049 for i, data in enumerate(original): |
669 self.assertEqual(result[i].tobytes(), data) | 1050 self.assertEqual(result[i].tobytes(), data) |
670 | 1051 |
671 def test_buffer_with_segments_collection_input(self): | 1052 def test_buffer_with_segments_collection_input(self): |
672 cctx = zstd.ZstdCompressor(write_content_size=True) | 1053 cctx = zstd.ZstdCompressor() |
673 | 1054 |
674 original = [ | 1055 original = [ |
675 b'foo0' * 2, | 1056 b'foo0' * 2, |
676 b'foo1' * 3, | 1057 b'foo1' * 3, |
677 b'foo2' * 4, | 1058 b'foo2' * 4, |
709 | 1090 |
710 self.assertEqual(len(decompressed), 5) | 1091 self.assertEqual(len(decompressed), 5) |
711 for i in range(5): | 1092 for i in range(5): |
712 self.assertEqual(decompressed[i].tobytes(), original[i]) | 1093 self.assertEqual(decompressed[i].tobytes(), original[i]) |
713 | 1094 |
1095 def test_dict(self): | |
1096 d = zstd.train_dictionary(16384, generate_samples(), k=64, d=16) | |
1097 | |
1098 cctx = zstd.ZstdCompressor(dict_data=d, level=1) | |
1099 frames = [cctx.compress(s) for s in generate_samples()] | |
1100 | |
1101 dctx = zstd.ZstdDecompressor(dict_data=d) | |
1102 result = dctx.multi_decompress_to_buffer(frames) | |
1103 self.assertEqual([o.tobytes() for o in result], generate_samples()) | |
1104 | |
714 def test_multiple_threads(self): | 1105 def test_multiple_threads(self): |
715 cctx = zstd.ZstdCompressor(write_content_size=True) | 1106 cctx = zstd.ZstdCompressor() |
716 | 1107 |
717 frames = [] | 1108 frames = [] |
718 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) | 1109 frames.extend(cctx.compress(b'x' * 64) for i in range(256)) |
719 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) | 1110 frames.extend(cctx.compress(b'y' * 64) for i in range(256)) |
720 | 1111 |
725 self.assertEqual(result.size(), 2 * 64 * 256) | 1116 self.assertEqual(result.size(), 2 * 64 * 256) |
726 self.assertEqual(result[0].tobytes(), b'x' * 64) | 1117 self.assertEqual(result[0].tobytes(), b'x' * 64) |
727 self.assertEqual(result[256].tobytes(), b'y' * 64) | 1118 self.assertEqual(result[256].tobytes(), b'y' * 64) |
728 | 1119 |
729 def test_item_failure(self): | 1120 def test_item_failure(self): |
730 cctx = zstd.ZstdCompressor(write_content_size=True) | 1121 cctx = zstd.ZstdCompressor() |
731 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] | 1122 frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)] |
732 | 1123 |
733 frames[1] = frames[1] + b'extra' | 1124 frames[1] = frames[1][0:15] + b'extra' + frames[1][15:] |
734 | 1125 |
735 dctx = zstd.ZstdDecompressor() | 1126 dctx = zstd.ZstdDecompressor() |
736 | 1127 |
737 with self.assertRaisesRegexp(zstd.ZstdError, 'error decompressing item 1: Src size incorrect'): | 1128 with self.assertRaisesRegexp(zstd.ZstdError, |
1129 'error decompressing item 1: (' | |
1130 'Corrupted block|' | |
1131 'Destination buffer is too small)'): | |
738 dctx.multi_decompress_to_buffer(frames) | 1132 dctx.multi_decompress_to_buffer(frames) |
739 | 1133 |
740 with self.assertRaisesRegexp(zstd.ZstdError, 'error decompressing item 1: Src size incorrect'): | 1134 with self.assertRaisesRegexp(zstd.ZstdError, |
1135 'error decompressing item 1: (' | |
1136 'Corrupted block|' | |
1137 'Destination buffer is too small)'): | |
741 dctx.multi_decompress_to_buffer(frames, threads=2) | 1138 dctx.multi_decompress_to_buffer(frames, threads=2) |
1139 |