Mercurial > public > mercurial-scm > hg-stable
diff contrib/python-zstandard/zstd_cffi.py @ 37495:b1fb341d8a61
zstandard: vendor python-zstandard 0.9.0
This was just released. It features a number of goodies. More info at
https://gregoryszorc.com/blog/2018/04/09/release-of-python-zstandard-0.9/.
The clang-format ignore list was updated to reflect the new source
files.
The project contains a vendored copy of zstandard 1.3.4. The old
version was 1.1.3. One of the changes between those versions is that
zstandard is now dual licensed BSD + GPLv2 and the patent rights grant
has been removed. Good riddance.
The API should be backwards compatible. So no changes in core
should be needed. However, there were a number of changes in the
library that we'll want to adapt to. Those will be addressed in
subsequent commits.
Differential Revision: https://phab.mercurial-scm.org/D3198
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Mon, 09 Apr 2018 10:13:29 -0700 |
parents | e0dc40530c5a |
children | 73fef626dae3 |
line wrap: on
line diff
--- a/contrib/python-zstandard/zstd_cffi.py Sun Apr 08 01:08:43 2018 +0200 +++ b/contrib/python-zstandard/zstd_cffi.py Mon Apr 09 10:13:29 2018 -0700 @@ -8,6 +8,69 @@ from __future__ import absolute_import, unicode_literals +# This should match what the C extension exports. +__all__ = [ + #'BufferSegment', + #'BufferSegments', + #'BufferWithSegments', + #'BufferWithSegmentsCollection', + 'CompressionParameters', + 'ZstdCompressionDict', + 'ZstdCompressionParameters', + 'ZstdCompressor', + 'ZstdError', + 'ZstdDecompressor', + 'FrameParameters', + 'estimate_decompression_context_size', + 'frame_content_size', + 'frame_header_size', + 'get_frame_parameters', + 'train_dictionary', + + # Constants. + 'COMPRESSOBJ_FLUSH_FINISH', + 'COMPRESSOBJ_FLUSH_BLOCK', + 'ZSTD_VERSION', + 'FRAME_HEADER', + 'CONTENTSIZE_UNKNOWN', + 'CONTENTSIZE_ERROR', + 'MAX_COMPRESSION_LEVEL', + 'COMPRESSION_RECOMMENDED_INPUT_SIZE', + 'COMPRESSION_RECOMMENDED_OUTPUT_SIZE', + 'DECOMPRESSION_RECOMMENDED_INPUT_SIZE', + 'DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE', + 'MAGIC_NUMBER', + 'WINDOWLOG_MIN', + 'WINDOWLOG_MAX', + 'CHAINLOG_MIN', + 'CHAINLOG_MAX', + 'HASHLOG_MIN', + 'HASHLOG_MAX', + 'HASHLOG3_MAX', + 'SEARCHLOG_MIN', + 'SEARCHLOG_MAX', + 'SEARCHLENGTH_MIN', + 'SEARCHLENGTH_MAX', + 'TARGETLENGTH_MIN', + 'LDM_MINMATCH_MIN', + 'LDM_MINMATCH_MAX', + 'LDM_BUCKETSIZELOG_MAX', + 'STRATEGY_FAST', + 'STRATEGY_DFAST', + 'STRATEGY_GREEDY', + 'STRATEGY_LAZY', + 'STRATEGY_LAZY2', + 'STRATEGY_BTLAZY2', + 'STRATEGY_BTOPT', + 'STRATEGY_BTULTRA', + 'DICT_TYPE_AUTO', + 'DICT_TYPE_RAWCONTENT', + 'DICT_TYPE_FULLDICT', + 'FORMAT_ZSTD1', + 'FORMAT_ZSTD1_MAGICLESS', +] + +import io import os import sys @@ -35,6 +98,8 @@ MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel() MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER FRAME_HEADER = b'\x28\xb5\x2f\xfd' +CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN +CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR ZSTD_VERSION = (lib.ZSTD_VERSION_MAJOR, lib.ZSTD_VERSION_MINOR, lib.ZSTD_VERSION_RELEASE) 
WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN @@ -49,7 +114,9 @@ SEARCHLENGTH_MIN = lib.ZSTD_SEARCHLENGTH_MIN SEARCHLENGTH_MAX = lib.ZSTD_SEARCHLENGTH_MAX TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN -TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX +LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN +LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX +LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX STRATEGY_FAST = lib.ZSTD_fast STRATEGY_DFAST = lib.ZSTD_dfast @@ -58,6 +125,14 @@ STRATEGY_LAZY2 = lib.ZSTD_lazy2 STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2 STRATEGY_BTOPT = lib.ZSTD_btopt +STRATEGY_BTULTRA = lib.ZSTD_btultra + +DICT_TYPE_AUTO = lib.ZSTD_dct_auto +DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent +DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict + +FORMAT_ZSTD1 = lib.ZSTD_f_zstd1 +FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless COMPRESSOBJ_FLUSH_FINISH = 0 COMPRESSOBJ_FLUSH_BLOCK = 1 @@ -87,81 +162,128 @@ pass -class CompressionParameters(object): - def __init__(self, window_log, chain_log, hash_log, search_log, - search_length, target_length, strategy): - if window_log < WINDOWLOG_MIN or window_log > WINDOWLOG_MAX: - raise ValueError('invalid window log value') +def _zstd_error(zresult): + # Resolves to bytes on Python 2 and 3. We use the string for formatting + # into error messages, which will be literal unicode. So convert it to + # unicode. 
+ return ffi.string(lib.ZSTD_getErrorName(zresult)).decode('utf-8') - if chain_log < CHAINLOG_MIN or chain_log > CHAINLOG_MAX: - raise ValueError('invalid chain log value') +def _make_cctx_params(params): + res = lib.ZSTD_createCCtxParams() + if res == ffi.NULL: + raise MemoryError() + + res = ffi.gc(res, lib.ZSTD_freeCCtxParams) - if hash_log < HASHLOG_MIN or hash_log > HASHLOG_MAX: - raise ValueError('invalid hash log value') + attrs = [ + (lib.ZSTD_p_format, params.format), + (lib.ZSTD_p_compressionLevel, params.compression_level), + (lib.ZSTD_p_windowLog, params.window_log), + (lib.ZSTD_p_hashLog, params.hash_log), + (lib.ZSTD_p_chainLog, params.chain_log), + (lib.ZSTD_p_searchLog, params.search_log), + (lib.ZSTD_p_minMatch, params.min_match), + (lib.ZSTD_p_targetLength, params.target_length), + (lib.ZSTD_p_compressionStrategy, params.compression_strategy), + (lib.ZSTD_p_contentSizeFlag, params.write_content_size), + (lib.ZSTD_p_checksumFlag, params.write_checksum), + (lib.ZSTD_p_dictIDFlag, params.write_dict_id), + (lib.ZSTD_p_nbWorkers, params.threads), + (lib.ZSTD_p_jobSize, params.job_size), + (lib.ZSTD_p_overlapSizeLog, params.overlap_size_log), + (lib.ZSTD_p_compressLiterals, params.compress_literals), + (lib.ZSTD_p_forceMaxWindow, params.force_max_window), + (lib.ZSTD_p_enableLongDistanceMatching, params.enable_ldm), + (lib.ZSTD_p_ldmHashLog, params.ldm_hash_log), + (lib.ZSTD_p_ldmMinMatch, params.ldm_min_match), + (lib.ZSTD_p_ldmBucketSizeLog, params.ldm_bucket_size_log), + (lib.ZSTD_p_ldmHashEveryLog, params.ldm_hash_every_log), + ] - if search_log < SEARCHLOG_MIN or search_log > SEARCHLOG_MAX: - raise ValueError('invalid search log value') + for param, value in attrs: + _set_compression_parameter(res, param, value) + + return res - if search_length < SEARCHLENGTH_MIN or search_length > SEARCHLENGTH_MAX: - raise ValueError('invalid search length value') +class ZstdCompressionParameters(object): + @staticmethod + def from_level(level, source_size=0, 
dict_size=0, **kwargs): + params = lib.ZSTD_getCParams(level, source_size, dict_size) - if target_length < TARGETLENGTH_MIN or target_length > TARGETLENGTH_MAX: - raise ValueError('invalid target length value') + args = { + 'window_log': 'windowLog', + 'chain_log': 'chainLog', + 'hash_log': 'hashLog', + 'search_log': 'searchLog', + 'min_match': 'searchLength', + 'target_length': 'targetLength', + 'compression_strategy': 'strategy', + } + + for arg, attr in args.items(): + if arg not in kwargs: + kwargs[arg] = getattr(params, attr) + + if 'compress_literals' not in kwargs: + kwargs['compress_literals'] = 1 if level >= 0 else 0 - if strategy < STRATEGY_FAST or strategy > STRATEGY_BTOPT: - raise ValueError('invalid strategy value') + return ZstdCompressionParameters(**kwargs) + + def __init__(self, format=0, compression_level=0, window_log=0, hash_log=0, + chain_log=0, search_log=0, min_match=0, target_length=0, + compression_strategy=0, write_content_size=1, write_checksum=0, + write_dict_id=0, job_size=0, overlap_size_log=0, + force_max_window=0, enable_ldm=0, ldm_hash_log=0, + ldm_min_match=0, ldm_bucket_size_log=0, ldm_hash_every_log=0, + threads=0, compress_literals=None): + if threads < 0: + threads = _cpu_count() + + if compress_literals is None: + compress_literals = compression_level >= 0 + + self.format = format + self.compression_level = compression_level self.window_log = window_log - self.chain_log = chain_log self.hash_log = hash_log + self.chain_log = chain_log self.search_log = search_log - self.search_length = search_length + self.min_match = min_match self.target_length = target_length - self.strategy = strategy + self.compression_strategy = compression_strategy + self.write_content_size = write_content_size + self.write_checksum = write_checksum + self.write_dict_id = write_dict_id + self.job_size = job_size + self.overlap_size_log = overlap_size_log + self.compress_literals = compress_literals + self.force_max_window = force_max_window + 
self.enable_ldm = enable_ldm + self.ldm_hash_log = ldm_hash_log + self.ldm_min_match = ldm_min_match + self.ldm_bucket_size_log = ldm_bucket_size_log + self.ldm_hash_every_log = ldm_hash_every_log + self.threads = threads - zresult = lib.ZSTD_checkCParams(self.as_compression_parameters()) - if lib.ZSTD_isError(zresult): - raise ValueError('invalid compression parameters: %s', - ffi.string(lib.ZSTD_getErrorName(zresult))) + self.params = _make_cctx_params(self) def estimated_compression_context_size(self): - return lib.ZSTD_estimateCCtxSize(self.as_compression_parameters()) - - def as_compression_parameters(self): - p = ffi.new('ZSTD_compressionParameters *')[0] - p.windowLog = self.window_log - p.chainLog = self.chain_log - p.hashLog = self.hash_log - p.searchLog = self.search_log - p.searchLength = self.search_length - p.targetLength = self.target_length - p.strategy = self.strategy - - return p + return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self.params) -def get_compression_parameters(level, source_size=0, dict_size=0): - params = lib.ZSTD_getCParams(level, source_size, dict_size) - return CompressionParameters(window_log=params.windowLog, - chain_log=params.chainLog, - hash_log=params.hashLog, - search_log=params.searchLog, - search_length=params.searchLength, - target_length=params.targetLength, - strategy=params.strategy) - - -def estimate_compression_context_size(params): - if not isinstance(params, CompressionParameters): - raise ValueError('argument must be a CompressionParameters') - - cparams = params.as_compression_parameters() - return lib.ZSTD_estimateCCtxSize(cparams) - +CompressionParameters = ZstdCompressionParameters def estimate_decompression_context_size(): return lib.ZSTD_estimateDCtxSize() +def _set_compression_parameter(params, param, value): + zresult = lib.ZSTD_CCtxParam_setParameter(params, param, + ffi.cast('unsigned', value)) + if lib.ZSTD_isError(zresult): + raise ZstdError('unable to set compression context parameter: %s' % + 
_zstd_error(zresult)) + class ZstdCompressionWriter(object): def __init__(self, compressor, writer, source_size, write_size): self._compressor = compressor @@ -169,16 +291,18 @@ self._source_size = source_size self._write_size = write_size self._entered = False - self._mtcctx = compressor._cctx if compressor._multithreaded else None + self._bytes_compressed = 0 def __enter__(self): if self._entered: raise ZstdError('cannot __enter__ multiple times') - if self._mtcctx: - self._compressor._init_mtcstream(self._source_size) - else: - self._compressor._ensure_cstream(self._source_size) + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._compressor._cctx, + self._source_size) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) + self._entered = True return self @@ -186,20 +310,27 @@ self._entered = False if not exc_type and not exc_value and not exc_tb: + dst_buffer = ffi.new('char[]', self._write_size) + out_buffer = ffi.new('ZSTD_outBuffer *') - dst_buffer = ffi.new('char[]', self._write_size) + in_buffer = ffi.new('ZSTD_inBuffer *') + out_buffer.dst = dst_buffer - out_buffer.size = self._write_size + out_buffer.size = len(dst_buffer) out_buffer.pos = 0 + in_buffer.src = ffi.NULL + in_buffer.size = 0 + in_buffer.pos = 0 + while True: - if self._mtcctx: - zresult = lib.ZSTDMT_endStream(self._mtcctx, out_buffer) - else: - zresult = lib.ZSTD_endStream(self._compressor._cstream, out_buffer) + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + out_buffer, in_buffer, + lib.ZSTD_e_end) + if lib.ZSTD_isError(zresult): raise ZstdError('error ending compression stream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) @@ -217,7 +348,7 @@ raise ZstdError('cannot determine size of an inactive compressor; ' 'call when a context manager is active') - return lib.ZSTD_sizeof_CStream(self._compressor._cstream) + 
return lib.ZSTD_sizeof_CCtx(self._compressor._cctx) def write(self, data): if not self._entered: @@ -240,19 +371,17 @@ out_buffer.pos = 0 while in_buffer.pos < in_buffer.size: - if self._mtcctx: - zresult = lib.ZSTDMT_compressStream(self._mtcctx, out_buffer, - in_buffer) - else: - zresult = lib.ZSTD_compressStream(self._compressor._cstream, out_buffer, - in_buffer) + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + out_buffer, in_buffer, + lib.ZSTD_e_continue) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) total_write += out_buffer.pos + self._bytes_compressed += out_buffer.pos out_buffer.pos = 0 return total_write @@ -269,24 +398,32 @@ out_buffer.size = self._write_size out_buffer.pos = 0 + in_buffer = ffi.new('ZSTD_inBuffer *') + in_buffer.src = ffi.NULL + in_buffer.size = 0 + in_buffer.pos = 0 + while True: - if self._mtcctx: - zresult = lib.ZSTDMT_flushStream(self._mtcctx, out_buffer) - else: - zresult = lib.ZSTD_flushStream(self._compressor._cstream, out_buffer) + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + out_buffer, in_buffer, + lib.ZSTD_e_flush) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if not out_buffer.pos: break self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) total_write += out_buffer.pos + self._bytes_compressed += out_buffer.pos out_buffer.pos = 0 return total_write + def tell(self): + return self._bytes_compressed + class ZstdCompressionObj(object): def compress(self, data): @@ -302,15 +439,13 @@ chunks = [] while source.pos < len(data): - if self._mtcctx: - zresult = lib.ZSTDMT_compressStream(self._mtcctx, - self._out, source) - else: - zresult = lib.ZSTD_compressStream(self._compressor._cstream, self._out, - source) + 
zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + self._out, + source, + lib.ZSTD_e_continue) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if self._out.pos: chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) @@ -327,14 +462,19 @@ assert self._out.pos == 0 + in_buffer = ffi.new('ZSTD_inBuffer *') + in_buffer.src = ffi.NULL + in_buffer.size = 0 + in_buffer.pos = 0 + if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: - if self._mtcctx: - zresult = lib.ZSTDMT_flushStream(self._mtcctx, self._out) - else: - zresult = lib.ZSTD_flushStream(self._compressor._cstream, self._out) + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + self._out, + in_buffer, + lib.ZSTD_e_flush) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) # Output buffer is guaranteed to hold full block. assert zresult == 0 @@ -352,13 +492,13 @@ chunks = [] while True: - if self._mtcctx: - zresult = lib.ZSTDMT_endStream(self._mtcctx, self._out) - else: - zresult = lib.ZSTD_endStream(self._compressor._cstream, self._out) + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + self._out, + in_buffer, + lib.ZSTD_e_end) if lib.ZSTD_isError(zresult): raise ZstdError('error ending compression stream: %s' % - ffi.string(lib.ZSTD_getErroName(zresult))) + _zstd_error(zresult)) if self._out.pos: chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) @@ -370,95 +510,335 @@ return b''.join(chunks) +class CompressionReader(object): + def __init__(self, compressor, source, size, read_size): + self._compressor = compressor + self._source = source + self._source_size = size + self._read_size = read_size + self._entered = False + self._closed = False + self._bytes_compressed = 0 + self._finished_input = False + self._finished_output = False + + self._in_buffer = ffi.new('ZSTD_inBuffer *') + # Holds a 
ref so backing bytes in self._in_buffer stay alive. + self._source_buffer = None + + def __enter__(self): + if self._entered: + raise ValueError('cannot __enter__ multiple times') + + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._compressor._cctx, + self._source_size) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) + + self._entered = True + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self._entered = False + self._closed = True + self._source = None + self._compressor = None + + return False + + def readable(self): + return True + + def writable(self): + return False + + def seekable(self): + return False + + def readline(self): + raise io.UnsupportedOperation() + + def readlines(self): + raise io.UnsupportedOperation() + + def write(self, data): + raise OSError('stream is not writable') + + def writelines(self, ignored): + raise OSError('stream is not writable') + + def isatty(self): + return False + + def flush(self): + return None + + def close(self): + self._closed = True + return None + + def closed(self): + return self._closed + + def tell(self): + return self._bytes_compressed + + def readall(self): + raise NotImplementedError() + + def __iter__(self): + raise io.UnsupportedOperation() + + def __next__(self): + raise io.UnsupportedOperation() + + next = __next__ + + def read(self, size=-1): + if not self._entered: + raise ZstdError('read() must be called from an active context manager') + + if self._closed: + raise ValueError('stream is closed') + + if self._finished_output: + return b'' + + if size < 1: + raise ValueError('cannot read negative or size 0 amounts') + + # Need a dedicated ref to dest buffer otherwise it gets collected. 
+ dst_buffer = ffi.new('char[]', size) + out_buffer = ffi.new('ZSTD_outBuffer *') + out_buffer.dst = dst_buffer + out_buffer.size = size + out_buffer.pos = 0 + + def compress_input(): + if self._in_buffer.pos >= self._in_buffer.size: + return + + old_pos = out_buffer.pos + + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + out_buffer, self._in_buffer, + lib.ZSTD_e_continue) + + self._bytes_compressed += out_buffer.pos - old_pos + + if self._in_buffer.pos == self._in_buffer.size: + self._in_buffer.src = ffi.NULL + self._in_buffer.pos = 0 + self._in_buffer.size = 0 + self._source_buffer = None + + if not hasattr(self._source, 'read'): + self._finished_input = True + + if lib.ZSTD_isError(zresult): + raise ZstdError('zstd compress error: %s', + _zstd_error(zresult)) + + if out_buffer.pos and out_buffer.pos == out_buffer.size: + return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + + def get_input(): + if self._finished_input: + return + + if hasattr(self._source, 'read'): + data = self._source.read(self._read_size) + + if not data: + self._finished_input = True + return + + self._source_buffer = ffi.from_buffer(data) + self._in_buffer.src = self._source_buffer + self._in_buffer.size = len(self._source_buffer) + self._in_buffer.pos = 0 + else: + self._source_buffer = ffi.from_buffer(self._source) + self._in_buffer.src = self._source_buffer + self._in_buffer.size = len(self._source_buffer) + self._in_buffer.pos = 0 + + result = compress_input() + if result: + return result + + while not self._finished_input: + get_input() + result = compress_input() + if result: + return result + + # EOF + old_pos = out_buffer.pos + + zresult = lib.ZSTD_compress_generic(self._compressor._cctx, + out_buffer, self._in_buffer, + lib.ZSTD_e_end) + + self._bytes_compressed += out_buffer.pos - old_pos + + if lib.ZSTD_isError(zresult): + raise ZstdError('error ending compression stream: %s', + _zstd_error(zresult)) + + if zresult == 0: + self._finished_output = True + + return 
ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + class ZstdCompressor(object): def __init__(self, level=3, dict_data=None, compression_params=None, - write_checksum=False, write_content_size=False, - write_dict_id=True, threads=0): - if level < 1: - raise ValueError('level must be greater than 0') - elif level > lib.ZSTD_maxCLevel(): + write_checksum=None, write_content_size=None, + write_dict_id=None, threads=0): + if level > lib.ZSTD_maxCLevel(): raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel()) if threads < 0: threads = _cpu_count() - self._compression_level = level - self._dict_data = dict_data - self._cparams = compression_params - self._fparams = ffi.new('ZSTD_frameParameters *')[0] - self._fparams.checksumFlag = write_checksum - self._fparams.contentSizeFlag = write_content_size - self._fparams.noDictIDFlag = not write_dict_id + if compression_params and write_checksum is not None: + raise ValueError('cannot define compression_params and ' + 'write_checksum') + + if compression_params and write_content_size is not None: + raise ValueError('cannot define compression_params and ' + 'write_content_size') + + if compression_params and write_dict_id is not None: + raise ValueError('cannot define compression_params and ' + 'write_dict_id') - if threads: - cctx = lib.ZSTDMT_createCCtx(threads) - if cctx == ffi.NULL: - raise MemoryError() + if compression_params and threads: + raise ValueError('cannot define compression_params and threads') - self._cctx = ffi.gc(cctx, lib.ZSTDMT_freeCCtx) - self._multithreaded = True + if compression_params: + self._params = _make_cctx_params(compression_params) else: - cctx = lib.ZSTD_createCCtx() - if cctx == ffi.NULL: + if write_dict_id is None: + write_dict_id = True + + params = lib.ZSTD_createCCtxParams() + if params == ffi.NULL: raise MemoryError() - self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx) - self._multithreaded = False + self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) + + 
_set_compression_parameter(self._params, + lib.ZSTD_p_compressionLevel, + level) - self._cstream = None + _set_compression_parameter( + self._params, + lib.ZSTD_p_contentSizeFlag, + write_content_size if write_content_size is not None else 1) + + _set_compression_parameter(self._params, + lib.ZSTD_p_checksumFlag, + 1 if write_checksum else 0) - def compress(self, data, allow_empty=False): - if len(data) == 0 and self._fparams.contentSizeFlag and not allow_empty: - raise ValueError('cannot write empty inputs when writing content sizes') + _set_compression_parameter(self._params, + lib.ZSTD_p_dictIDFlag, + 1 if write_dict_id else 0) - if self._multithreaded and self._dict_data: - raise ZstdError('compress() cannot be used with both dictionaries and multi-threaded compression') + if threads: + _set_compression_parameter(self._params, + lib.ZSTD_p_nbWorkers, + threads) - if self._multithreaded and self._cparams: - raise ZstdError('compress() cannot be used with both compression parameters and multi-threaded compression') + cctx = lib.ZSTD_createCCtx() + if cctx == ffi.NULL: + raise MemoryError() + + self._cctx = cctx + self._dict_data = dict_data - # TODO use a CDict for performance. - dict_data = ffi.NULL - dict_size = 0 + # We defer setting up garbage collection until after calling + # _ensure_cctx() to ensure the memory size estimate is more accurate. 
+ try: + self._ensure_cctx() + finally: + self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx, + size=lib.ZSTD_sizeof_CCtx(cctx)) - if self._dict_data: - dict_data = self._dict_data.as_bytes() - dict_size = len(self._dict_data) + def _ensure_cctx(self): + lib.ZSTD_CCtx_reset(self._cctx) + + zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx, + self._params) + if lib.ZSTD_isError(zresult): + raise ZstdError('could not set compression parameters: %s' % + _zstd_error(zresult)) + + dict_data = self._dict_data - params = ffi.new('ZSTD_parameters *')[0] - if self._cparams: - params.cParams = self._cparams.as_compression_parameters() - else: - params.cParams = lib.ZSTD_getCParams(self._compression_level, len(data), - dict_size) - params.fParams = self._fparams + if dict_data: + if dict_data._cdict: + zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict) + else: + zresult = lib.ZSTD_CCtx_loadDictionary_advanced( + self._cctx, dict_data.as_bytes(), len(dict_data), + lib.ZSTD_dlm_byRef, dict_data._dict_type) - dest_size = lib.ZSTD_compressBound(len(data)) + if lib.ZSTD_isError(zresult): + raise ZstdError('could not load compression dictionary: %s' % + _zstd_error(zresult)) + + def memory_size(self): + return lib.ZSTD_sizeof_CCtx(self._cctx) + + def compress(self, data): + self._ensure_cctx() + + data_buffer = ffi.from_buffer(data) + + dest_size = lib.ZSTD_compressBound(len(data_buffer)) out = new_nonzero('char[]', dest_size) - if self._multithreaded: - zresult = lib.ZSTDMT_compressCCtx(self._cctx, - ffi.addressof(out), dest_size, - data, len(data), - self._compression_level) - else: - zresult = lib.ZSTD_compress_advanced(self._cctx, - ffi.addressof(out), dest_size, - data, len(data), - dict_data, dict_size, - params) + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) + + out_buffer = ffi.new('ZSTD_outBuffer *') + in_buffer = 
ffi.new('ZSTD_inBuffer *') + + out_buffer.dst = out + out_buffer.size = dest_size + out_buffer.pos = 0 + + in_buffer.src = data_buffer + in_buffer.size = len(data_buffer) + in_buffer.pos = 0 + + zresult = lib.ZSTD_compress_generic(self._cctx, + out_buffer, + in_buffer, + lib.ZSTD_e_end) if lib.ZSTD_isError(zresult): raise ZstdError('cannot compress: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) + elif zresult: + raise ZstdError('unexpected partial frame flush') - return ffi.buffer(out, zresult)[:] + return ffi.buffer(out, out_buffer.pos)[:] - def compressobj(self, size=0): - if self._multithreaded: - self._init_mtcstream(size) - else: - self._ensure_cstream(size) + def compressobj(self, size=-1): + self._ensure_cctx() + + if size < 0: + size = lib.ZSTD_CONTENTSIZE_UNKNOWN + + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) cobj = ZstdCompressionObj() cobj._out = ffi.new('ZSTD_outBuffer *') @@ -469,14 +849,9 @@ cobj._compressor = self cobj._finished = False - if self._multithreaded: - cobj._mtcctx = self._cctx - else: - cobj._mtcctx = None - return cobj - def copy_stream(self, ifh, ofh, size=0, + def copy_stream(self, ifh, ofh, size=-1, read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): @@ -485,11 +860,15 @@ if not hasattr(ofh, 'write'): raise ValueError('second argument must have a write() method') - mt = self._multithreaded - if mt: - self._init_mtcstream(size) - else: - self._ensure_cstream(size) + self._ensure_cctx() + + if size < 0: + size = lib.ZSTD_CONTENTSIZE_UNKNOWN + + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) in_buffer = ffi.new('ZSTD_inBuffer *') out_buffer = ffi.new('ZSTD_outBuffer *') @@ -513,14 +892,13 @@ in_buffer.pos = 0 while 
in_buffer.pos < in_buffer.size: - if mt: - zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer) - else: - zresult = lib.ZSTD_compressStream(self._cstream, - out_buffer, in_buffer) + zresult = lib.ZSTD_compress_generic(self._cctx, + out_buffer, + in_buffer, + lib.ZSTD_e_continue) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) @@ -529,13 +907,13 @@ # We've finished reading. Flush the compressor. while True: - if mt: - zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer) - else: - zresult = lib.ZSTD_endStream(self._cstream, out_buffer) + zresult = lib.ZSTD_compress_generic(self._cctx, + out_buffer, + in_buffer, + lib.ZSTD_e_end) if lib.ZSTD_isError(zresult): raise ZstdError('error ending compression stream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) @@ -547,17 +925,38 @@ return total_read, total_write - def write_to(self, writer, size=0, + def stream_reader(self, source, size=-1, + read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE): + self._ensure_cctx() + + try: + size = len(source) + except Exception: + pass + + if size < 0: + size = lib.ZSTD_CONTENTSIZE_UNKNOWN + + return CompressionReader(self, source, size, read_size) + + def stream_writer(self, writer, size=-1, write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): if not hasattr(writer, 'write'): raise ValueError('must pass an object with a write() method') + self._ensure_cctx() + + if size < 0: + size = lib.ZSTD_CONTENTSIZE_UNKNOWN + return ZstdCompressionWriter(self, writer, size, write_size) - def read_from(self, reader, size=0, - read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, - write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): + write_to = stream_writer + + def read_to_iter(self, reader, size=-1, + 
read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, + write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): if hasattr(reader, 'read'): have_read = True elif hasattr(reader, '__getitem__'): @@ -568,10 +967,15 @@ raise ValueError('must pass an object with a read() method or ' 'conforms to buffer protocol') - if self._multithreaded: - self._init_mtcstream(size) - else: - self._ensure_cstream(size) + self._ensure_cctx() + + if size < 0: + size = lib.ZSTD_CONTENTSIZE_UNKNOWN + + zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) + if lib.ZSTD_isError(zresult): + raise ZstdError('error setting source size: %s' % + _zstd_error(zresult)) in_buffer = ffi.new('ZSTD_inBuffer *') out_buffer = ffi.new('ZSTD_outBuffer *') @@ -611,13 +1015,11 @@ in_buffer.pos = 0 while in_buffer.pos < in_buffer.size: - if self._multithreaded: - zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer) - else: - zresult = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer) + zresult = lib.ZSTD_compress_generic(self._cctx, out_buffer, in_buffer, + lib.ZSTD_e_continue) if lib.ZSTD_isError(zresult): raise ZstdError('zstd compress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] @@ -633,13 +1035,13 @@ # remains. 
while True: assert out_buffer.pos == 0 - if self._multithreaded: - zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer) - else: - zresult = lib.ZSTD_endStream(self._cstream, out_buffer) + zresult = lib.ZSTD_compress_generic(self._cctx, + out_buffer, + in_buffer, + lib.ZSTD_e_end) if lib.ZSTD_isError(zresult): raise ZstdError('error ending compression stream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] @@ -649,67 +1051,12 @@ if zresult == 0: break - def _ensure_cstream(self, size): - if self._cstream: - zresult = lib.ZSTD_resetCStream(self._cstream, size) - if lib.ZSTD_isError(zresult): - raise ZstdError('could not reset CStream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) - - return - - cstream = lib.ZSTD_createCStream() - if cstream == ffi.NULL: - raise MemoryError() - - cstream = ffi.gc(cstream, lib.ZSTD_freeCStream) - - dict_data = ffi.NULL - dict_size = 0 - if self._dict_data: - dict_data = self._dict_data.as_bytes() - dict_size = len(self._dict_data) - - zparams = ffi.new('ZSTD_parameters *')[0] - if self._cparams: - zparams.cParams = self._cparams.as_compression_parameters() - else: - zparams.cParams = lib.ZSTD_getCParams(self._compression_level, - size, dict_size) - zparams.fParams = self._fparams + read_from = read_to_iter - zresult = lib.ZSTD_initCStream_advanced(cstream, dict_data, dict_size, - zparams, size) - if lib.ZSTD_isError(zresult): - raise Exception('cannot init CStream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) - - self._cstream = cstream - - def _init_mtcstream(self, size): - assert self._multithreaded + def frame_progression(self): + progression = lib.ZSTD_getFrameProgression(self._cctx) - dict_data = ffi.NULL - dict_size = 0 - if self._dict_data: - dict_data = self._dict_data.as_bytes() - dict_size = len(self._dict_data) - - zparams = ffi.new('ZSTD_parameters *')[0] - if self._cparams: - zparams.cParams = 
self._cparams.as_compression_parameters() - else: - zparams.cParams = lib.ZSTD_getCParams(self._compression_level, - size, dict_size) - - zparams.fParams = self._fparams - - zresult = lib.ZSTDMT_initCStream_advanced(self._cctx, dict_data, dict_size, - zparams, size) - - if lib.ZSTD_isError(zresult): - raise ZstdError('cannot init CStream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + return progression.ingested, progression.consumed, progression.produced class FrameParameters(object): @@ -720,16 +1067,38 @@ self.has_checksum = bool(fparams.checksumFlag) -def get_frame_parameters(data): - if not isinstance(data, bytes_type): - raise TypeError('argument must be bytes') +def frame_content_size(data): + data_buffer = ffi.from_buffer(data) + + size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer)) + + if size == lib.ZSTD_CONTENTSIZE_ERROR: + raise ZstdError('error when determining content size') + elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN: + return -1 + else: + return size + - params = ffi.new('ZSTD_frameParams *') +def frame_header_size(data): + data_buffer = ffi.from_buffer(data) + + zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer)) + if lib.ZSTD_isError(zresult): + raise ZstdError('could not determine frame header size: %s' % + _zstd_error(zresult)) - zresult = lib.ZSTD_getFrameParams(params, data, len(data)) + return zresult + + +def get_frame_parameters(data): + params = ffi.new('ZSTD_frameHeader *') + + data_buffer = ffi.from_buffer(data) + zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) if lib.ZSTD_isError(zresult): raise ZstdError('cannot get frame parameters: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if zresult: raise ZstdError('not enough data for frame parameters; need %d bytes' % @@ -739,12 +1108,20 @@ class ZstdCompressionDict(object): - def __init__(self, data, k=0, d=0): + def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0): assert isinstance(data, 
bytes_type) self._data = data self.k = k self.d = d + if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT, + DICT_TYPE_FULLDICT): + raise ValueError('invalid dictionary load mode: %d; must use ' + 'DICT_TYPE_* constants') + + self._dict_type = dict_type + self._cdict = None + def __len__(self): return len(self._data) @@ -754,51 +1131,55 @@ def as_bytes(self): return self._data + def precompute_compress(self, level=0, compression_params=None): + if level and compression_params: + raise ValueError('must only specify one of level or ' + 'compression_params') -def train_dictionary(dict_size, samples, selectivity=0, level=0, - notifications=0, dict_id=0): - if not isinstance(samples, list): - raise TypeError('samples must be a list') - - total_size = sum(map(len, samples)) - - samples_buffer = new_nonzero('char[]', total_size) - sample_sizes = new_nonzero('size_t[]', len(samples)) + if not level and not compression_params: + raise ValueError('must specify one of level or compression_params') - offset = 0 - for i, sample in enumerate(samples): - if not isinstance(sample, bytes_type): - raise ValueError('samples must be bytes') - - l = len(sample) - ffi.memmove(samples_buffer + offset, sample, l) - offset += l - sample_sizes[i] = l - - dict_data = new_nonzero('char[]', dict_size) + if level: + cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) + else: + cparams = ffi.new('ZSTD_compressionParameters') + cparams.chainLog = compression_params.chain_log + cparams.hashLog = compression_params.hash_log + cparams.searchLength = compression_params.min_match + cparams.searchLog = compression_params.search_log + cparams.strategy = compression_params.compression_strategy + cparams.targetLength = compression_params.target_length + cparams.windowLog = compression_params.window_log - dparams = ffi.new('ZDICT_params_t *')[0] - dparams.selectivityLevel = selectivity - dparams.compressionLevel = level - dparams.notificationLevel = notifications - dparams.dictID = dict_id + 
cdict = lib.ZSTD_createCDict_advanced(self._data, len(self._data), + lib.ZSTD_dlm_byRef, + self._dict_type, + cparams, + lib.ZSTD_defaultCMem) + if cdict == ffi.NULL: + raise ZstdError('unable to precompute dictionary') + + self._cdict = ffi.gc(cdict, lib.ZSTD_freeCDict, + size=lib.ZSTD_sizeof_CDict(cdict)) - zresult = lib.ZDICT_trainFromBuffer_advanced( - ffi.addressof(dict_data), dict_size, - ffi.addressof(samples_buffer), - ffi.addressof(sample_sizes, 0), len(samples), - dparams) + @property + def _ddict(self): + ddict = lib.ZSTD_createDDict_advanced(self._data, len(self._data), + lib.ZSTD_dlm_byRef, + self._dict_type, + lib.ZSTD_defaultCMem) - if lib.ZDICT_isError(zresult): - raise ZstdError('Cannot train dict: %s' % - ffi.string(lib.ZDICT_getErrorName(zresult))) + if ddict == ffi.NULL: + raise ZstdError('could not create decompression dict') - return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:]) - + ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, + size=lib.ZSTD_sizeof_DDict(ddict)) + self.__dict__['_ddict'] = ddict -def train_cover_dictionary(dict_size, samples, k=0, d=0, - notifications=0, dict_id=0, level=0, optimize=False, - steps=0, threads=0): + return ddict + +def train_dictionary(dict_size, samples, k=0, d=0, notifications=0, dict_id=0, + level=0, steps=0, threads=0): if not isinstance(samples, list): raise TypeError('samples must be a list') @@ -822,47 +1203,55 @@ dict_data = new_nonzero('char[]', dict_size) - dparams = ffi.new('COVER_params_t *')[0] + dparams = ffi.new('ZDICT_cover_params_t *')[0] dparams.k = k dparams.d = d dparams.steps = steps dparams.nbThreads = threads - dparams.notificationLevel = notifications - dparams.dictID = dict_id - dparams.compressionLevel = level + dparams.zParams.notificationLevel = notifications + dparams.zParams.dictID = dict_id + dparams.zParams.compressionLevel = level - if optimize: - zresult = lib.COVER_optimizeTrainFromBuffer( + if (not dparams.k and not dparams.d and not dparams.steps + and not 
dparams.nbThreads and not dparams.zParams.notificationLevel + and not dparams.zParams.dictID + and not dparams.zParams.compressionLevel): + zresult = lib.ZDICT_trainFromBuffer( + ffi.addressof(dict_data), dict_size, + ffi.addressof(samples_buffer), + ffi.addressof(sample_sizes, 0), len(samples)) + elif dparams.steps or dparams.nbThreads: + zresult = lib.ZDICT_optimizeTrainFromBuffer_cover( ffi.addressof(dict_data), dict_size, ffi.addressof(samples_buffer), ffi.addressof(sample_sizes, 0), len(samples), ffi.addressof(dparams)) else: - zresult = lib.COVER_trainFromBuffer( + zresult = lib.ZDICT_trainFromBuffer_cover( ffi.addressof(dict_data), dict_size, ffi.addressof(samples_buffer), ffi.addressof(sample_sizes, 0), len(samples), dparams) if lib.ZDICT_isError(zresult): - raise ZstdError('cannot train dict: %s' % - ffi.string(lib.ZDICT_getErrorName(zresult))) + msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode('utf-8') + raise ZstdError('cannot train dict: %s' % msg) return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:], + dict_type=DICT_TYPE_FULLDICT, k=dparams.k, d=dparams.d) class ZstdDecompressionObj(object): - def __init__(self, decompressor): + def __init__(self, decompressor, write_size): self._decompressor = decompressor + self._write_size = write_size self._finished = False def decompress(self, data): if self._finished: raise ZstdError('cannot use a decompressobj multiple times') - assert(self._decompressor._dstream) - in_buffer = ffi.new('ZSTD_inBuffer *') out_buffer = ffi.new('ZSTD_outBuffer *') @@ -871,7 +1260,7 @@ in_buffer.size = len(data_buffer) in_buffer.pos = 0 - dst_buffer = ffi.new('char[]', DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) + dst_buffer = ffi.new('char[]', self._write_size) out_buffer.dst = dst_buffer out_buffer.size = len(dst_buffer) out_buffer.pos = 0 @@ -879,11 +1268,11 @@ chunks = [] while in_buffer.pos < in_buffer.size: - zresult = lib.ZSTD_decompressStream(self._decompressor._dstream, - out_buffer, in_buffer) + zresult = 
lib.ZSTD_decompress_generic(self._decompressor._dctx, + out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('zstd decompressor error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if zresult == 0: self._finished = True @@ -896,6 +1285,203 @@ return b''.join(chunks) +class DecompressionReader(object): + def __init__(self, decompressor, source, read_size): + self._decompressor = decompressor + self._source = source + self._read_size = read_size + self._entered = False + self._closed = False + self._bytes_decompressed = 0 + self._finished_input = False + self._finished_output = False + self._in_buffer = ffi.new('ZSTD_inBuffer *') + # Holds a ref to self._in_buffer.src. + self._source_buffer = None + + def __enter__(self): + if self._entered: + raise ValueError('cannot __enter__ multiple times') + + self._decompressor._ensure_dctx() + + self._entered = True + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self._entered = False + self._closed = True + self._source = None + self._decompressor = None + + return False + + def readable(self): + return True + + def writable(self): + return False + + def seekable(self): + return True + + def readline(self): + raise NotImplementedError() + + def readlines(self): + raise NotImplementedError() + + def write(self, data): + raise io.UnsupportedOperation() + + def writelines(self, lines): + raise io.UnsupportedOperation() + + def isatty(self): + return False + + def flush(self): + return None + + def close(self): + self._closed = True + return None + + def closed(self): + return self._closed + + def tell(self): + return self._bytes_decompressed + + def readall(self): + raise NotImplementedError() + + def __iter__(self): + raise NotImplementedError() + + def __next__(self): + raise NotImplementedError() + + next = __next__ + + def read(self, size=-1): + if not self._entered: + raise ZstdError('read() must be called from an active context manager') + + if 
self._closed: + raise ValueError('stream is closed') + + if self._finished_output: + return b'' + + if size < 1: + raise ValueError('cannot read negative or size 0 amounts') + + dst_buffer = ffi.new('char[]', size) + out_buffer = ffi.new('ZSTD_outBuffer *') + out_buffer.dst = dst_buffer + out_buffer.size = size + out_buffer.pos = 0 + + def decompress(): + zresult = lib.ZSTD_decompress_generic(self._decompressor._dctx, + out_buffer, self._in_buffer) + + if self._in_buffer.pos == self._in_buffer.size: + self._in_buffer.src = ffi.NULL + self._in_buffer.pos = 0 + self._in_buffer.size = 0 + self._source_buffer = None + + if not hasattr(self._source, 'read'): + self._finished_input = True + + if lib.ZSTD_isError(zresult): + raise ZstdError('zstd decompress error: %s', + _zstd_error(zresult)) + elif zresult == 0: + self._finished_output = True + + if out_buffer.pos and out_buffer.pos == out_buffer.size: + self._bytes_decompressed += out_buffer.size + return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + + def get_input(): + if self._finished_input: + return + + if hasattr(self._source, 'read'): + data = self._source.read(self._read_size) + + if not data: + self._finished_input = True + return + + self._source_buffer = ffi.from_buffer(data) + self._in_buffer.src = self._source_buffer + self._in_buffer.size = len(self._source_buffer) + self._in_buffer.pos = 0 + else: + self._source_buffer = ffi.from_buffer(self._source) + self._in_buffer.src = self._source_buffer + self._in_buffer.size = len(self._source_buffer) + self._in_buffer.pos = 0 + + get_input() + result = decompress() + if result: + return result + + while not self._finished_input: + get_input() + result = decompress() + if result: + return result + + self._bytes_decompressed += out_buffer.pos + return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + + def seek(self, pos, whence=os.SEEK_SET): + if not self._entered: + raise ZstdError('seek() must be called from an active context ' + 'manager') + + if self._closed: 
+ raise ValueError('stream is closed') + + read_amount = 0 + + if whence == os.SEEK_SET: + if pos < 0: + raise ValueError('cannot seek to negative position with SEEK_SET') + + if pos < self._bytes_decompressed: + raise ValueError('cannot seek zstd decompression stream ' + 'backwards') + + read_amount = pos - self._bytes_decompressed + + elif whence == os.SEEK_CUR: + if pos < 0: + raise ValueError('cannot seek zstd decompression stream ' + 'backwards') + + read_amount = pos + elif whence == os.SEEK_END: + raise ValueError('zstd decompression streams cannot be seeked ' + 'with SEEK_END') + + while read_amount: + result = self.read(min(read_amount, + DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)) + + if not result: + break + + read_amount -= len(result) + + return self._bytes_decompressed + class ZstdDecompressionWriter(object): def __init__(self, decompressor, writer, write_size): self._decompressor = decompressor @@ -907,7 +1493,7 @@ if self._entered: raise ZstdError('cannot __enter__ multiple times') - self._decompressor._ensure_dstream() + self._decompressor._ensure_dctx() self._entered = True return self @@ -916,11 +1502,11 @@ self._entered = False def memory_size(self): - if not self._decompressor._dstream: + if not self._decompressor._dctx: raise ZstdError('cannot determine size of inactive decompressor ' 'call when context manager is active') - return lib.ZSTD_sizeof_DStream(self._decompressor._dstream) + return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx) def write(self, data): if not self._entered: @@ -941,13 +1527,13 @@ out_buffer.size = len(dst_buffer) out_buffer.pos = 0 - dstream = self._decompressor._dstream + dctx = self._decompressor._dctx while in_buffer.pos < in_buffer.size: - zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer) + zresult = lib.ZSTD_decompress_generic(dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('zstd decompress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) 
if out_buffer.pos: self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) @@ -958,77 +1544,86 @@ class ZstdDecompressor(object): - def __init__(self, dict_data=None): + def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1): self._dict_data = dict_data + self._max_window_size = max_window_size + self._format = format dctx = lib.ZSTD_createDCtx() if dctx == ffi.NULL: raise MemoryError() - self._refdctx = ffi.gc(dctx, lib.ZSTD_freeDCtx) - self._dstream = None - - @property - def _ddict(self): - if self._dict_data: - dict_data = self._dict_data.as_bytes() - dict_size = len(self._dict_data) + self._dctx = dctx - ddict = lib.ZSTD_createDDict(dict_data, dict_size) - if ddict == ffi.NULL: - raise ZstdError('could not create decompression dict') - else: - ddict = None + # Defer setting up garbage collection until full state is loaded so + # the memory size is more accurate. + try: + self._ensure_dctx() + finally: + self._dctx = ffi.gc(dctx, lib.ZSTD_freeDCtx, + size=lib.ZSTD_sizeof_DCtx(dctx)) - self.__dict__['_ddict'] = ddict - return ddict + def memory_size(self): + return lib.ZSTD_sizeof_DCtx(self._dctx) def decompress(self, data, max_output_size=0): + self._ensure_dctx() + data_buffer = ffi.from_buffer(data) - orig_dctx = new_nonzero('char[]', lib.ZSTD_sizeof_DCtx(self._refdctx)) - dctx = ffi.cast('ZSTD_DCtx *', orig_dctx) - lib.ZSTD_copyDCtx(dctx, self._refdctx) - - ddict = self._ddict + output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer)) - output_size = lib.ZSTD_getDecompressedSize(data_buffer, len(data_buffer)) - if output_size: - result_buffer = ffi.new('char[]', output_size) - result_size = output_size - else: + if output_size == lib.ZSTD_CONTENTSIZE_ERROR: + raise ZstdError('error determining content size from frame header') + elif output_size == 0: + return b'' + elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: if not max_output_size: - raise ZstdError('input data invalid or missing content size ' - 'in frame 
header') + raise ZstdError('could not determine content size in frame header') result_buffer = ffi.new('char[]', max_output_size) result_size = max_output_size + output_size = 0 + else: + result_buffer = ffi.new('char[]', output_size) + result_size = output_size - if ddict: - zresult = lib.ZSTD_decompress_usingDDict(dctx, - result_buffer, result_size, - data_buffer, len(data_buffer), - ddict) - else: - zresult = lib.ZSTD_decompressDCtx(dctx, - result_buffer, result_size, - data_buffer, len(data_buffer)) + out_buffer = ffi.new('ZSTD_outBuffer *') + out_buffer.dst = result_buffer + out_buffer.size = result_size + out_buffer.pos = 0 + + in_buffer = ffi.new('ZSTD_inBuffer *') + in_buffer.src = data_buffer + in_buffer.size = len(data_buffer) + in_buffer.pos = 0 + + zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('decompression error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) - elif output_size and zresult != output_size: + _zstd_error(zresult)) + elif zresult: + raise ZstdError('decompression error: did not decompress full frame') + elif output_size and out_buffer.pos != output_size: raise ZstdError('decompression error: decompressed %d bytes; expected %d' % (zresult, output_size)) - return ffi.buffer(result_buffer, zresult)[:] + return ffi.buffer(result_buffer, out_buffer.pos)[:] + + def stream_reader(self, source, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE): + self._ensure_dctx() + return DecompressionReader(self, source, read_size) - def decompressobj(self): - self._ensure_dstream() - return ZstdDecompressionObj(self) + def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): + if write_size < 1: + raise ValueError('write_size must be positive') - def read_from(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, - write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, - skip_bytes=0): + self._ensure_dctx() + return ZstdDecompressionObj(self, 
write_size=write_size) + + def read_to_iter(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, + write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, + skip_bytes=0): if skip_bytes >= read_size: raise ValueError('skip_bytes must be smaller than read_size') @@ -1051,7 +1646,7 @@ buffer_offset = skip_bytes - self._ensure_dstream() + self._ensure_dctx() in_buffer = ffi.new('ZSTD_inBuffer *') out_buffer = ffi.new('ZSTD_outBuffer *') @@ -1086,10 +1681,10 @@ while in_buffer.pos < in_buffer.size: assert out_buffer.pos == 0 - zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer) + zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('zstd decompress error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] @@ -1104,12 +1699,16 @@ # If we get here, input is exhausted. - def write_to(self, writer, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): + read_from = read_to_iter + + def stream_writer(self, writer, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): if not hasattr(writer, 'write'): raise ValueError('must pass an object with a write() method') return ZstdDecompressionWriter(self, writer, write_size) + write_to = stream_writer + def copy_stream(self, ifh, ofh, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): @@ -1118,7 +1717,7 @@ if not hasattr(ofh, 'write'): raise ValueError('second argument must have a write() method') - self._ensure_dstream() + self._ensure_dctx() in_buffer = ffi.new('ZSTD_inBuffer *') out_buffer = ffi.new('ZSTD_outBuffer *') @@ -1144,10 +1743,10 @@ # Flush all read data to output. 
while in_buffer.pos < in_buffer.size: - zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer) + zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('zstd decompressor error: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) if out_buffer.pos: ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) @@ -1172,29 +1771,36 @@ # All chunks should be zstd frames and should have content size set. chunk_buffer = ffi.from_buffer(chunk) - params = ffi.new('ZSTD_frameParams *') - zresult = lib.ZSTD_getFrameParams(params, chunk_buffer, len(chunk_buffer)) + params = ffi.new('ZSTD_frameHeader *') + zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) if lib.ZSTD_isError(zresult): raise ValueError('chunk 0 is not a valid zstd frame') elif zresult: raise ValueError('chunk 0 is too small to contain a zstd frame') - if not params.frameContentSize: + if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: raise ValueError('chunk 0 missing content size in frame') - dctx = lib.ZSTD_createDCtx() - if dctx == ffi.NULL: - raise MemoryError() - - dctx = ffi.gc(dctx, lib.ZSTD_freeDCtx) + self._ensure_dctx(load_dict=False) last_buffer = ffi.new('char[]', params.frameContentSize) - zresult = lib.ZSTD_decompressDCtx(dctx, last_buffer, len(last_buffer), - chunk_buffer, len(chunk_buffer)) + out_buffer = ffi.new('ZSTD_outBuffer *') + out_buffer.dst = last_buffer + out_buffer.size = len(last_buffer) + out_buffer.pos = 0 + + in_buffer = ffi.new('ZSTD_inBuffer *') + in_buffer.src = chunk_buffer + in_buffer.size = len(chunk_buffer) + in_buffer.pos = 0 + + zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): raise ZstdError('could not decompress chunk 0: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + _zstd_error(zresult)) + elif zresult: + raise ZstdError('chunk 0 did not decompress full frame') # Special case 
of chain length of 1 if len(frames) == 1: @@ -1207,51 +1813,54 @@ raise ValueError('chunk %d must be bytes' % i) chunk_buffer = ffi.from_buffer(chunk) - zresult = lib.ZSTD_getFrameParams(params, chunk_buffer, len(chunk_buffer)) + zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) if lib.ZSTD_isError(zresult): raise ValueError('chunk %d is not a valid zstd frame' % i) elif zresult: raise ValueError('chunk %d is too small to contain a zstd frame' % i) - if not params.frameContentSize: + if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: raise ValueError('chunk %d missing content size in frame' % i) dest_buffer = ffi.new('char[]', params.frameContentSize) - zresult = lib.ZSTD_decompress_usingDict(dctx, dest_buffer, len(dest_buffer), - chunk_buffer, len(chunk_buffer), - last_buffer, len(last_buffer)) + out_buffer.dst = dest_buffer + out_buffer.size = len(dest_buffer) + out_buffer.pos = 0 + + in_buffer.src = chunk_buffer + in_buffer.size = len(chunk_buffer) + in_buffer.pos = 0 + + zresult = lib.ZSTD_decompress_generic(self._dctx, out_buffer, in_buffer) if lib.ZSTD_isError(zresult): - raise ZstdError('could not decompress chunk %d' % i) + raise ZstdError('could not decompress chunk %d: %s' % + _zstd_error(zresult)) + elif zresult: + raise ZstdError('chunk %d did not decompress full frame' % i) last_buffer = dest_buffer i += 1 return ffi.buffer(last_buffer, len(last_buffer))[:] - def _ensure_dstream(self): - if self._dstream: - zresult = lib.ZSTD_resetDStream(self._dstream) + def _ensure_dctx(self, load_dict=True): + lib.ZSTD_DCtx_reset(self._dctx) + + if self._max_window_size: + zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx, + self._max_window_size) if lib.ZSTD_isError(zresult): - raise ZstdError('could not reset DStream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) - - return - - self._dstream = lib.ZSTD_createDStream() - if self._dstream == ffi.NULL: - raise MemoryError() + raise ZstdError('unable to set max window size: 
%s' % + _zstd_error(zresult)) - self._dstream = ffi.gc(self._dstream, lib.ZSTD_freeDStream) + zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) + if lib.ZSTD_isError(zresult): + raise ZstdError('unable to set decoding format: %s' % + _zstd_error(zresult)) - if self._dict_data: - zresult = lib.ZSTD_initDStream_usingDict(self._dstream, - self._dict_data.as_bytes(), - len(self._dict_data)) - else: - zresult = lib.ZSTD_initDStream(self._dstream) - - if lib.ZSTD_isError(zresult): - self._dstream = None - raise ZstdError('could not initialize DStream: %s' % - ffi.string(lib.ZSTD_getErrorName(zresult))) + if self._dict_data and load_dict: + zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) + if lib.ZSTD_isError(zresult): + raise ZstdError('unable to reference prepared dictionary: %s' % + _zstd_error(zresult))