comparison contrib/python-zstandard/zstd_cffi.py @ 40121:73fef626dae3

zstandard: vendor python-zstandard 0.10.1 This was just released. The upstream source distribution from PyPI was extracted. Unwanted files were removed. The clang-format ignore list was updated to reflect the new source of files. setup.py was updated to pass a new argument to python-zstandard's function for returning an Extension instance. Upstream had to change to use relative paths because Python 3.7's packaging doesn't seem to like absolute paths when defining sources, includes, etc. The default relative path calculation is relative to setup_zstd.py which is different from the directory of Mercurial's setup.py. The project contains a vendored copy of zstandard 1.3.6. The old version was 1.3.4. The API should be backwards compatible and nothing in core should need to be adjusted. However, there is a new "chunker" API that we may find useful in places where we want to emit compressed chunks of a fixed size. There are a pair of bug fixes in 0.10.0 with regard to compressobj() and decompressobj() when block flushing is used. I actually found these bugs when introducing these APIs in Mercurial! But existing Mercurial code is not affected because we don't perform block flushing. # no-check-commit because 3rd party code has different style guidelines Differential Revision: https://phab.mercurial-scm.org/D4911
author Gregory Szorc <gregory.szorc@gmail.com>
date Mon, 08 Oct 2018 16:27:40 -0700
parents b1fb341d8a61
children
comparison
equal deleted inserted replaced
40120:89742f1fa6cb 40121:73fef626dae3
38 'COMPRESSION_RECOMMENDED_INPUT_SIZE', 38 'COMPRESSION_RECOMMENDED_INPUT_SIZE',
39 'COMPRESSION_RECOMMENDED_OUTPUT_SIZE', 39 'COMPRESSION_RECOMMENDED_OUTPUT_SIZE',
40 'DECOMPRESSION_RECOMMENDED_INPUT_SIZE', 40 'DECOMPRESSION_RECOMMENDED_INPUT_SIZE',
41 'DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE', 41 'DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE',
42 'MAGIC_NUMBER', 42 'MAGIC_NUMBER',
43 'BLOCKSIZELOG_MAX',
44 'BLOCKSIZE_MAX',
43 'WINDOWLOG_MIN', 45 'WINDOWLOG_MIN',
44 'WINDOWLOG_MAX', 46 'WINDOWLOG_MAX',
45 'CHAINLOG_MIN', 47 'CHAINLOG_MIN',
46 'CHAINLOG_MAX', 48 'CHAINLOG_MAX',
47 'HASHLOG_MIN', 49 'HASHLOG_MIN',
50 'SEARCHLOG_MIN', 52 'SEARCHLOG_MIN',
51 'SEARCHLOG_MAX', 53 'SEARCHLOG_MAX',
52 'SEARCHLENGTH_MIN', 54 'SEARCHLENGTH_MIN',
53 'SEARCHLENGTH_MAX', 55 'SEARCHLENGTH_MAX',
54 'TARGETLENGTH_MIN', 56 'TARGETLENGTH_MIN',
57 'TARGETLENGTH_MAX',
55 'LDM_MINMATCH_MIN', 58 'LDM_MINMATCH_MIN',
56 'LDM_MINMATCH_MAX', 59 'LDM_MINMATCH_MAX',
57 'LDM_BUCKETSIZELOG_MAX', 60 'LDM_BUCKETSIZELOG_MAX',
58 'STRATEGY_FAST', 61 'STRATEGY_FAST',
59 'STRATEGY_DFAST', 62 'STRATEGY_DFAST',
100 FRAME_HEADER = b'\x28\xb5\x2f\xfd' 103 FRAME_HEADER = b'\x28\xb5\x2f\xfd'
101 CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN 104 CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN
102 CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR 105 CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR
103 ZSTD_VERSION = (lib.ZSTD_VERSION_MAJOR, lib.ZSTD_VERSION_MINOR, lib.ZSTD_VERSION_RELEASE) 106 ZSTD_VERSION = (lib.ZSTD_VERSION_MAJOR, lib.ZSTD_VERSION_MINOR, lib.ZSTD_VERSION_RELEASE)
104 107
108 BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX
109 BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX
105 WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN 110 WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN
106 WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX 111 WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX
107 CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN 112 CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN
108 CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX 113 CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX
109 HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN 114 HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN
112 SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN 117 SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN
113 SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX 118 SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX
114 SEARCHLENGTH_MIN = lib.ZSTD_SEARCHLENGTH_MIN 119 SEARCHLENGTH_MIN = lib.ZSTD_SEARCHLENGTH_MIN
115 SEARCHLENGTH_MAX = lib.ZSTD_SEARCHLENGTH_MAX 120 SEARCHLENGTH_MAX = lib.ZSTD_SEARCHLENGTH_MAX
116 TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN 121 TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN
122 TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX
117 LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN 123 LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN
118 LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX 124 LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX
119 LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX 125 LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX
120 126
121 STRATEGY_FAST = lib.ZSTD_fast 127 STRATEGY_FAST = lib.ZSTD_fast
189 (lib.ZSTD_p_checksumFlag, params.write_checksum), 195 (lib.ZSTD_p_checksumFlag, params.write_checksum),
190 (lib.ZSTD_p_dictIDFlag, params.write_dict_id), 196 (lib.ZSTD_p_dictIDFlag, params.write_dict_id),
191 (lib.ZSTD_p_nbWorkers, params.threads), 197 (lib.ZSTD_p_nbWorkers, params.threads),
192 (lib.ZSTD_p_jobSize, params.job_size), 198 (lib.ZSTD_p_jobSize, params.job_size),
193 (lib.ZSTD_p_overlapSizeLog, params.overlap_size_log), 199 (lib.ZSTD_p_overlapSizeLog, params.overlap_size_log),
194 (lib.ZSTD_p_compressLiterals, params.compress_literals),
195 (lib.ZSTD_p_forceMaxWindow, params.force_max_window), 200 (lib.ZSTD_p_forceMaxWindow, params.force_max_window),
196 (lib.ZSTD_p_enableLongDistanceMatching, params.enable_ldm), 201 (lib.ZSTD_p_enableLongDistanceMatching, params.enable_ldm),
197 (lib.ZSTD_p_ldmHashLog, params.ldm_hash_log), 202 (lib.ZSTD_p_ldmHashLog, params.ldm_hash_log),
198 (lib.ZSTD_p_ldmMinMatch, params.ldm_min_match), 203 (lib.ZSTD_p_ldmMinMatch, params.ldm_min_match),
199 (lib.ZSTD_p_ldmBucketSizeLog, params.ldm_bucket_size_log), 204 (lib.ZSTD_p_ldmBucketSizeLog, params.ldm_bucket_size_log),
222 227
223 for arg, attr in args.items(): 228 for arg, attr in args.items():
224 if arg not in kwargs: 229 if arg not in kwargs:
225 kwargs[arg] = getattr(params, attr) 230 kwargs[arg] = getattr(params, attr)
226 231
227 if 'compress_literals' not in kwargs:
228 kwargs['compress_literals'] = 1 if level >= 0 else 0
229
230 return ZstdCompressionParameters(**kwargs) 232 return ZstdCompressionParameters(**kwargs)
231 233
232 def __init__(self, format=0, compression_level=0, window_log=0, hash_log=0, 234 def __init__(self, format=0, compression_level=0, window_log=0, hash_log=0,
233 chain_log=0, search_log=0, min_match=0, target_length=0, 235 chain_log=0, search_log=0, min_match=0, target_length=0,
234 compression_strategy=0, write_content_size=1, write_checksum=0, 236 compression_strategy=0, write_content_size=1, write_checksum=0,
235 write_dict_id=0, job_size=0, overlap_size_log=0, 237 write_dict_id=0, job_size=0, overlap_size_log=0,
236 force_max_window=0, enable_ldm=0, ldm_hash_log=0, 238 force_max_window=0, enable_ldm=0, ldm_hash_log=0,
237 ldm_min_match=0, ldm_bucket_size_log=0, ldm_hash_every_log=0, 239 ldm_min_match=0, ldm_bucket_size_log=0, ldm_hash_every_log=0,
238 threads=0, compress_literals=None): 240 threads=0):
239 241
240 if threads < 0: 242 if threads < 0:
241 threads = _cpu_count() 243 threads = _cpu_count()
242
243 if compress_literals is None:
244 compress_literals = compression_level >= 0
245 244
246 self.format = format 245 self.format = format
247 self.compression_level = compression_level 246 self.compression_level = compression_level
248 self.window_log = window_log 247 self.window_log = window_log
249 self.hash_log = hash_log 248 self.hash_log = hash_log
255 self.write_content_size = write_content_size 254 self.write_content_size = write_content_size
256 self.write_checksum = write_checksum 255 self.write_checksum = write_checksum
257 self.write_dict_id = write_dict_id 256 self.write_dict_id = write_dict_id
258 self.job_size = job_size 257 self.job_size = job_size
259 self.overlap_size_log = overlap_size_log 258 self.overlap_size_log = overlap_size_log
260 self.compress_literals = compress_literals
261 self.force_max_window = force_max_window 259 self.force_max_window = force_max_window
262 self.enable_ldm = enable_ldm 260 self.enable_ldm = enable_ldm
263 self.ldm_hash_log = ldm_hash_log 261 self.ldm_hash_log = ldm_hash_log
264 self.ldm_min_match = ldm_min_match 262 self.ldm_min_match = ldm_min_match
265 self.ldm_bucket_size_log = ldm_bucket_size_log 263 self.ldm_bucket_size_log = ldm_bucket_size_log
409 lib.ZSTD_e_flush) 407 lib.ZSTD_e_flush)
410 if lib.ZSTD_isError(zresult): 408 if lib.ZSTD_isError(zresult):
411 raise ZstdError('zstd compress error: %s' % 409 raise ZstdError('zstd compress error: %s' %
412 _zstd_error(zresult)) 410 _zstd_error(zresult))
413 411
414 if not out_buffer.pos: 412 if out_buffer.pos:
413 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
414 total_write += out_buffer.pos
415 self._bytes_compressed += out_buffer.pos
416 out_buffer.pos = 0
417
418 if not zresult:
415 break 419 break
416
417 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
418 total_write += out_buffer.pos
419 self._bytes_compressed += out_buffer.pos
420 out_buffer.pos = 0
421 420
422 return total_write 421 return total_write
423 422
424 def tell(self): 423 def tell(self):
425 return self._bytes_compressed 424 return self._bytes_compressed
458 raise ValueError('flush mode not recognized') 457 raise ValueError('flush mode not recognized')
459 458
460 if self._finished: 459 if self._finished:
461 raise ZstdError('compressor object already finished') 460 raise ZstdError('compressor object already finished')
462 461
462 if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
463 z_flush_mode = lib.ZSTD_e_flush
464 elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
465 z_flush_mode = lib.ZSTD_e_end
466 self._finished = True
467 else:
468 raise ZstdError('unhandled flush mode')
469
463 assert self._out.pos == 0 470 assert self._out.pos == 0
464 471
465 in_buffer = ffi.new('ZSTD_inBuffer *') 472 in_buffer = ffi.new('ZSTD_inBuffer *')
466 in_buffer.src = ffi.NULL 473 in_buffer.src = ffi.NULL
467 in_buffer.size = 0 474 in_buffer.size = 0
468 in_buffer.pos = 0 475 in_buffer.pos = 0
469 476
470 if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
471 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
472 self._out,
473 in_buffer,
474 lib.ZSTD_e_flush)
475 if lib.ZSTD_isError(zresult):
476 raise ZstdError('zstd compress error: %s' %
477 _zstd_error(zresult))
478
479 # Output buffer is guaranteed to hold full block.
480 assert zresult == 0
481
482 if self._out.pos:
483 result = ffi.buffer(self._out.dst, self._out.pos)[:]
484 self._out.pos = 0
485 return result
486 else:
487 return b''
488
489 assert flush_mode == COMPRESSOBJ_FLUSH_FINISH
490 self._finished = True
491
492 chunks = [] 477 chunks = []
493 478
494 while True: 479 while True:
495 zresult = lib.ZSTD_compress_generic(self._compressor._cctx, 480 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
496 self._out, 481 self._out,
497 in_buffer, 482 in_buffer,
498 lib.ZSTD_e_end) 483 z_flush_mode)
499 if lib.ZSTD_isError(zresult): 484 if lib.ZSTD_isError(zresult):
500 raise ZstdError('error ending compression stream: %s' % 485 raise ZstdError('error ending compression stream: %s' %
501 _zstd_error(zresult)) 486 _zstd_error(zresult))
502 487
503 if self._out.pos: 488 if self._out.pos:
508 break 493 break
509 494
510 return b''.join(chunks) 495 return b''.join(chunks)
511 496
512 497
498 class ZstdCompressionChunker(object):
499 def __init__(self, compressor, chunk_size):
500 self._compressor = compressor
501 self._out = ffi.new('ZSTD_outBuffer *')
502 self._dst_buffer = ffi.new('char[]', chunk_size)
503 self._out.dst = self._dst_buffer
504 self._out.size = chunk_size
505 self._out.pos = 0
506
507 self._in = ffi.new('ZSTD_inBuffer *')
508 self._in.src = ffi.NULL
509 self._in.size = 0
510 self._in.pos = 0
511 self._finished = False
512
513 def compress(self, data):
514 if self._finished:
515 raise ZstdError('cannot call compress() after compression finished')
516
517 if self._in.src != ffi.NULL:
518 raise ZstdError('cannot perform operation before consuming output '
519 'from previous operation')
520
521 data_buffer = ffi.from_buffer(data)
522
523 if not len(data_buffer):
524 return
525
526 self._in.src = data_buffer
527 self._in.size = len(data_buffer)
528 self._in.pos = 0
529
530 while self._in.pos < self._in.size:
531 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
532 self._out,
533 self._in,
534 lib.ZSTD_e_continue)
535
536 if self._in.pos == self._in.size:
537 self._in.src = ffi.NULL
538 self._in.size = 0
539 self._in.pos = 0
540
541 if lib.ZSTD_isError(zresult):
542 raise ZstdError('zstd compress error: %s' %
543 _zstd_error(zresult))
544
545 if self._out.pos == self._out.size:
546 yield ffi.buffer(self._out.dst, self._out.pos)[:]
547 self._out.pos = 0
548
549 def flush(self):
550 if self._finished:
551 raise ZstdError('cannot call flush() after compression finished')
552
553 if self._in.src != ffi.NULL:
554 raise ZstdError('cannot call flush() before consuming output from '
555 'previous operation')
556
557 while True:
558 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
559 self._out, self._in,
560 lib.ZSTD_e_flush)
561 if lib.ZSTD_isError(zresult):
562 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
563
564 if self._out.pos:
565 yield ffi.buffer(self._out.dst, self._out.pos)[:]
566 self._out.pos = 0
567
568 if not zresult:
569 return
570
571 def finish(self):
572 if self._finished:
573 raise ZstdError('cannot call finish() after compression finished')
574
575 if self._in.src != ffi.NULL:
576 raise ZstdError('cannot call finish() before consuming output from '
577 'previous operation')
578
579 while True:
580 zresult = lib.ZSTD_compress_generic(self._compressor._cctx,
581 self._out, self._in,
582 lib.ZSTD_e_end)
583 if lib.ZSTD_isError(zresult):
584 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
585
586 if self._out.pos:
587 yield ffi.buffer(self._out.dst, self._out.pos)[:]
588 self._out.pos = 0
589
590 if not zresult:
591 self._finished = True
592 return
593
594
513 class CompressionReader(object): 595 class CompressionReader(object):
514 def __init__(self, compressor, source, size, read_size): 596 def __init__(self, compressor, source, read_size):
515 self._compressor = compressor 597 self._compressor = compressor
516 self._source = source 598 self._source = source
517 self._source_size = size
518 self._read_size = read_size 599 self._read_size = read_size
519 self._entered = False 600 self._entered = False
520 self._closed = False 601 self._closed = False
521 self._bytes_compressed = 0 602 self._bytes_compressed = 0
522 self._finished_input = False 603 self._finished_input = False
528 609
529 def __enter__(self): 610 def __enter__(self):
530 if self._entered: 611 if self._entered:
531 raise ValueError('cannot __enter__ multiple times') 612 raise ValueError('cannot __enter__ multiple times')
532 613
533 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._compressor._cctx,
534 self._source_size)
535 if lib.ZSTD_isError(zresult):
536 raise ZstdError('error setting source size: %s' %
537 _zstd_error(zresult))
538
539 self._entered = True 614 self._entered = True
540 return self 615 return self
541 616
542 def __exit__(self, exc_type, exc_value, exc_tb): 617 def __exit__(self, exc_type, exc_value, exc_tb):
543 self._entered = False 618 self._entered = False
576 651
577 def close(self): 652 def close(self):
578 self._closed = True 653 self._closed = True
579 return None 654 return None
580 655
656 @property
581 def closed(self): 657 def closed(self):
582 return self._closed 658 return self._closed
583 659
584 def tell(self): 660 def tell(self):
585 return self._bytes_compressed 661 return self._bytes_compressed
594 raise io.UnsupportedOperation() 670 raise io.UnsupportedOperation()
595 671
596 next = __next__ 672 next = __next__
597 673
598 def read(self, size=-1): 674 def read(self, size=-1):
599 if not self._entered:
600 raise ZstdError('read() must be called from an active context manager')
601
602 if self._closed: 675 if self._closed:
603 raise ValueError('stream is closed') 676 raise ValueError('stream is closed')
604 677
605 if self._finished_output: 678 if self._finished_output:
606 return b'' 679 return b''
757 830
758 self._cctx = cctx 831 self._cctx = cctx
759 self._dict_data = dict_data 832 self._dict_data = dict_data
760 833
761 # We defer setting up garbage collection until after calling 834 # We defer setting up garbage collection until after calling
762 # _ensure_cctx() to ensure the memory size estimate is more accurate. 835 # _setup_cctx() to ensure the memory size estimate is more accurate.
763 try: 836 try:
764 self._ensure_cctx() 837 self._setup_cctx()
765 finally: 838 finally:
766 self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx, 839 self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx,
767 size=lib.ZSTD_sizeof_CCtx(cctx)) 840 size=lib.ZSTD_sizeof_CCtx(cctx))
768 841
769 def _ensure_cctx(self): 842 def _setup_cctx(self):
770 lib.ZSTD_CCtx_reset(self._cctx)
771
772 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx, 843 zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx,
773 self._params) 844 self._params)
774 if lib.ZSTD_isError(zresult): 845 if lib.ZSTD_isError(zresult):
775 raise ZstdError('could not set compression parameters: %s' % 846 raise ZstdError('could not set compression parameters: %s' %
776 _zstd_error(zresult)) 847 _zstd_error(zresult))
791 862
792 def memory_size(self): 863 def memory_size(self):
793 return lib.ZSTD_sizeof_CCtx(self._cctx) 864 return lib.ZSTD_sizeof_CCtx(self._cctx)
794 865
795 def compress(self, data): 866 def compress(self, data):
796 self._ensure_cctx() 867 lib.ZSTD_CCtx_reset(self._cctx)
797 868
798 data_buffer = ffi.from_buffer(data) 869 data_buffer = ffi.from_buffer(data)
799 870
800 dest_size = lib.ZSTD_compressBound(len(data_buffer)) 871 dest_size = lib.ZSTD_compressBound(len(data_buffer))
801 out = new_nonzero('char[]', dest_size) 872 out = new_nonzero('char[]', dest_size)
828 raise ZstdError('unexpected partial frame flush') 899 raise ZstdError('unexpected partial frame flush')
829 900
830 return ffi.buffer(out, out_buffer.pos)[:] 901 return ffi.buffer(out, out_buffer.pos)[:]
831 902
832 def compressobj(self, size=-1): 903 def compressobj(self, size=-1):
833 self._ensure_cctx() 904 lib.ZSTD_CCtx_reset(self._cctx)
834 905
835 if size < 0: 906 if size < 0:
836 size = lib.ZSTD_CONTENTSIZE_UNKNOWN 907 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
837 908
838 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) 909 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
849 cobj._compressor = self 920 cobj._compressor = self
850 cobj._finished = False 921 cobj._finished = False
851 922
852 return cobj 923 return cobj
853 924
925 def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
926 lib.ZSTD_CCtx_reset(self._cctx)
927
928 if size < 0:
929 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
930
931 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
932 if lib.ZSTD_isError(zresult):
933 raise ZstdError('error setting source size: %s' %
934 _zstd_error(zresult))
935
936 return ZstdCompressionChunker(self, chunk_size=chunk_size)
937
854 def copy_stream(self, ifh, ofh, size=-1, 938 def copy_stream(self, ifh, ofh, size=-1,
855 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, 939 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
856 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): 940 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
857 941
858 if not hasattr(ifh, 'read'): 942 if not hasattr(ifh, 'read'):
859 raise ValueError('first argument must have a read() method') 943 raise ValueError('first argument must have a read() method')
860 if not hasattr(ofh, 'write'): 944 if not hasattr(ofh, 'write'):
861 raise ValueError('second argument must have a write() method') 945 raise ValueError('second argument must have a write() method')
862 946
863 self._ensure_cctx() 947 lib.ZSTD_CCtx_reset(self._cctx)
864 948
865 if size < 0: 949 if size < 0:
866 size = lib.ZSTD_CONTENTSIZE_UNKNOWN 950 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
867 951
868 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) 952 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
925 1009
926 return total_read, total_write 1010 return total_read, total_write
927 1011
928 def stream_reader(self, source, size=-1, 1012 def stream_reader(self, source, size=-1,
929 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE): 1013 read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE):
930 self._ensure_cctx() 1014 lib.ZSTD_CCtx_reset(self._cctx)
931 1015
932 try: 1016 try:
933 size = len(source) 1017 size = len(source)
934 except Exception: 1018 except Exception:
935 pass 1019 pass
936 1020
937 if size < 0: 1021 if size < 0:
938 size = lib.ZSTD_CONTENTSIZE_UNKNOWN 1022 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
939 1023
940 return CompressionReader(self, source, size, read_size) 1024 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1025 if lib.ZSTD_isError(zresult):
1026 raise ZstdError('error setting source size: %s' %
1027 _zstd_error(zresult))
1028
1029 return CompressionReader(self, source, read_size)
941 1030
942 def stream_writer(self, writer, size=-1, 1031 def stream_writer(self, writer, size=-1,
943 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): 1032 write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
944 1033
945 if not hasattr(writer, 'write'): 1034 if not hasattr(writer, 'write'):
946 raise ValueError('must pass an object with a write() method') 1035 raise ValueError('must pass an object with a write() method')
947 1036
948 self._ensure_cctx() 1037 lib.ZSTD_CCtx_reset(self._cctx)
949 1038
950 if size < 0: 1039 if size < 0:
951 size = lib.ZSTD_CONTENTSIZE_UNKNOWN 1040 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
952 1041
953 return ZstdCompressionWriter(self, writer, size, write_size) 1042 return ZstdCompressionWriter(self, writer, size, write_size)
965 size = len(reader) 1054 size = len(reader)
966 else: 1055 else:
967 raise ValueError('must pass an object with a read() method or ' 1056 raise ValueError('must pass an object with a read() method or '
968 'conforms to buffer protocol') 1057 'conforms to buffer protocol')
969 1058
970 self._ensure_cctx() 1059 lib.ZSTD_CCtx_reset(self._cctx)
971 1060
972 if size < 0: 1061 if size < 0:
973 size = lib.ZSTD_CONTENTSIZE_UNKNOWN 1062 size = lib.ZSTD_CONTENTSIZE_UNKNOWN
974 1063
975 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) 1064 zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
1265 out_buffer.size = len(dst_buffer) 1354 out_buffer.size = len(dst_buffer)
1266 out_buffer.pos = 0 1355 out_buffer.pos = 0
1267 1356
1268 chunks = [] 1357 chunks = []
1269 1358
1270 while in_buffer.pos < in_buffer.size: 1359 while True:
1271 zresult = lib.ZSTD_decompress_generic(self._decompressor._dctx, 1360 zresult = lib.ZSTD_decompress_generic(self._decompressor._dctx,
1272 out_buffer, in_buffer) 1361 out_buffer, in_buffer)
1273 if lib.ZSTD_isError(zresult): 1362 if lib.ZSTD_isError(zresult):
1274 raise ZstdError('zstd decompressor error: %s' % 1363 raise ZstdError('zstd decompressor error: %s' %
1275 _zstd_error(zresult)) 1364 _zstd_error(zresult))
1278 self._finished = True 1367 self._finished = True
1279 self._decompressor = None 1368 self._decompressor = None
1280 1369
1281 if out_buffer.pos: 1370 if out_buffer.pos:
1282 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) 1371 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
1283 out_buffer.pos = 0 1372
1373 if (zresult == 0 or
1374 (in_buffer.pos == in_buffer.size and out_buffer.pos == 0)):
1375 break
1376
1377 out_buffer.pos = 0
1284 1378
1285 return b''.join(chunks) 1379 return b''.join(chunks)
1286 1380
1287 1381
1288 class DecompressionReader(object): 1382 class DecompressionReader(object):
1301 1395
1302 def __enter__(self): 1396 def __enter__(self):
1303 if self._entered: 1397 if self._entered:
1304 raise ValueError('cannot __enter__ multiple times') 1398 raise ValueError('cannot __enter__ multiple times')
1305 1399
1306 self._decompressor._ensure_dctx()
1307
1308 self._entered = True 1400 self._entered = True
1309 return self 1401 return self
1310 1402
1311 def __exit__(self, exc_type, exc_value, exc_tb): 1403 def __exit__(self, exc_type, exc_value, exc_tb):
1312 self._entered = False 1404 self._entered = False
1345 1437
1346 def close(self): 1438 def close(self):
1347 self._closed = True 1439 self._closed = True
1348 return None 1440 return None
1349 1441
1442 @property
1350 def closed(self): 1443 def closed(self):
1351 return self._closed 1444 return self._closed
1352 1445
1353 def tell(self): 1446 def tell(self):
1354 return self._bytes_decompressed 1447 return self._bytes_decompressed
1362 def __next__(self): 1455 def __next__(self):
1363 raise NotImplementedError() 1456 raise NotImplementedError()
1364 1457
1365 next = __next__ 1458 next = __next__
1366 1459
1367 def read(self, size=-1): 1460 def read(self, size):
1368 if not self._entered:
1369 raise ZstdError('read() must be called from an active context manager')
1370
1371 if self._closed: 1461 if self._closed:
1372 raise ValueError('stream is closed') 1462 raise ValueError('stream is closed')
1373 1463
1374 if self._finished_output: 1464 if self._finished_output:
1375 return b'' 1465 return b''
1440 1530
1441 self._bytes_decompressed += out_buffer.pos 1531 self._bytes_decompressed += out_buffer.pos
1442 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] 1532 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
1443 1533
1444 def seek(self, pos, whence=os.SEEK_SET): 1534 def seek(self, pos, whence=os.SEEK_SET):
1445 if not self._entered:
1446 raise ZstdError('seek() must be called from an active context '
1447 'manager')
1448
1449 if self._closed: 1535 if self._closed:
1450 raise ValueError('stream is closed') 1536 raise ValueError('stream is closed')
1451 1537
1452 read_amount = 0 1538 read_amount = 0
1453 1539