mercurial/wireprotoframing.py
changeset 40132 e67522413ca8
parent 40130 5d44c4d1d516
child 40133 762ef19a07e3
--- a/mercurial/wireprotoframing.py	Thu Oct 04 17:39:16 2018 -0700
+++ b/mercurial/wireprotoframing.py	Mon Oct 08 17:10:59 2018 -0700
@@ -648,6 +648,140 @@
             flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
             payload=payload)
 
+# TODO consider defining encoders/decoders using the util.compressionengine
+# mechanism.
+
+class identityencoder(object):
+    """Encoder for the "identity" stream encoding profile."""
+    def __init__(self, ui):
+        pass
+
+    def encode(self, data):
+        return data
+
+    def flush(self):
+        return b''
+
+    def finish(self):
+        return b''
+
+class identitydecoder(object):
+    """Decoder for the "identity" stream encoding profile."""
+
+    def __init__(self, ui, extraobjs):
+        if extraobjs:
+            raise error.Abort(_('identity decoder received unexpected '
+                                'additional values'))
+
+    def decode(self, data):
+        return data
+
+class zlibencoder(object):
+    def __init__(self, ui):
+        import zlib
+        self._zlib = zlib
+        self._compressor = zlib.compressobj()
+
+    def encode(self, data):
+        return self._compressor.compress(data)
+
+    def flush(self):
+        # Z_SYNC_FLUSH doesn't reset compression context, which is
+        # what we want.
+        return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
+
+    def finish(self):
+        res = self._compressor.flush(self._zlib.Z_FINISH)
+        self._compressor = None
+        return res
+
+class zlibdecoder(object):
+    def __init__(self, ui, extraobjs):
+        import zlib
+
+        if extraobjs:
+            raise error.Abort(_('zlib decoder received unexpected '
+                                'additional values'))
+
+        self._decompressor = zlib.decompressobj()
+
+    def decode(self, data):
+        # Python 2's zlib module doesn't use the buffer protocol and can't
+        # handle all bytes-like types.
+        if not pycompat.ispy3 and isinstance(data, bytearray):
+            data = bytes(data)
+
+        return self._decompressor.decompress(data)
+
+class zstdbaseencoder(object):
+    def __init__(self, level):
+        from . import zstd
+
+        self._zstd = zstd
+        cctx = zstd.ZstdCompressor(level=level)
+        self._compressor = cctx.compressobj()
+
+    def encode(self, data):
+        return self._compressor.compress(data)
+
+    def flush(self):
+        # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
+        # compressor and allows a decompressor to access all encoded data
+        # up to this point.
+        return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
+
+    def finish(self):
+        res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
+        self._compressor = None
+        return res
+
+class zstd8mbencoder(zstdbaseencoder):
+    def __init__(self, ui):
+        super(zstd8mbencoder, self).__init__(3)
+
+class zstdbasedecoder(object):
+    def __init__(self, maxwindowsize):
+        from . import zstd
+        dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
+        self._decompressor = dctx.decompressobj()
+
+    def decode(self, data):
+        return self._decompressor.decompress(data)
+
+class zstd8mbdecoder(zstdbasedecoder):
+    def __init__(self, ui, extraobjs):
+        if extraobjs:
+            raise error.Abort(_('zstd8mb decoder received unexpected '
+                                'additional values'))
+
+        super(zstd8mbdecoder, self).__init__(maxwindowsize=8 * 1048576)
+
+# We lazily populate this to avoid excessive module imports when importing
+# this module.
+STREAM_ENCODERS = {}
+STREAM_ENCODERS_ORDER = []
+
+def populatestreamencoders():
+    if STREAM_ENCODERS:
+        return
+
+    try:
+        from . import zstd
+        zstd.__version__
+    except ImportError:
+        zstd = None
+
+    # zstandard is fastest and is preferred.
+    if zstd:
+        STREAM_ENCODERS[b'zstd-8mb'] = (zstd8mbencoder, zstd8mbdecoder)
+        STREAM_ENCODERS_ORDER.append(b'zstd-8mb')
+
+    STREAM_ENCODERS[b'zlib'] = (zlibencoder, zlibdecoder)
+    STREAM_ENCODERS_ORDER.append(b'zlib')
+
+    STREAM_ENCODERS[b'identity'] = (identityencoder, identitydecoder)
+    STREAM_ENCODERS_ORDER.append(b'identity')
+
 class stream(object):
     """Represents a logical unidirectional series of frames."""
 
@@ -671,16 +805,70 @@
 class inputstream(stream):
     """Represents a stream used for receiving data."""
 
-    def setdecoder(self, name, extraobjs):
+    def __init__(self, streamid, active=False):
+        super(inputstream, self).__init__(streamid, active=active)
+        self._decoder = None
+
+    def setdecoder(self, ui, name, extraobjs):
         """Set the decoder for this stream.
 
         Receives the stream profile name and any additional CBOR objects
         decoded from the stream encoding settings frame payloads.
         """
+        if name not in STREAM_ENCODERS:
+            raise error.Abort(_('unknown stream decoder: %s') % name)
+
+        self._decoder = STREAM_ENCODERS[name][1](ui, extraobjs)
+
+    def decode(self, data):
+        # Default is identity decoder. We don't bother instantiating one
+        # because it is trivial.
+        if not self._decoder:
+            return data
+
+        return self._decoder.decode(data)
+
+    def flush(self):
+        if not self._decoder:
+            return b''
+
+        return self._decoder.flush()
 
 class outputstream(stream):
     """Represents a stream used for sending data."""
 
+    def __init__(self, streamid, active=False):
+        super(outputstream, self).__init__(streamid, active=active)
+        self._encoder = None
+
+    def setencoder(self, ui, name):
+        """Set the encoder for this stream.
+
+        Receives the stream profile name.
+        """
+        if name not in STREAM_ENCODERS:
+            raise error.Abort(_('unknown stream encoder: %s') % name)
+
+        self._encoder = STREAM_ENCODERS[name][0](ui)
+
+    def encode(self, data):
+        if not self._encoder:
+            return data
+
+        return self._encoder.encode(data)
+
+    def flush(self):
+        if not self._encoder:
+            return b''
+
+        return self._encoder.flush()
+
+    def finish(self):
+        if not self._encoder:
+            return b''
+
+        self._encoder.finish()
+
 def ensureserverstream(stream):
     if stream.streamid % 2:
         raise error.ProgrammingError('server should only write to even '
@@ -786,6 +974,8 @@
         # Sender protocol settings are optional. Set implied default values.
         self._sendersettings = dict(DEFAULT_PROTOCOL_SETTINGS)
 
+        populatestreamencoders()
+
     def onframerecv(self, frame):
         """Process a frame that has been received off the wire.
 
@@ -1384,6 +1574,8 @@
         self._incomingstreams = {}
         self._streamsettingsdecoders = {}
 
+        populatestreamencoders()
+
     def callcommand(self, name, args, datafh=None, redirect=None):
         """Request that a command be executed.
 
@@ -1494,9 +1686,13 @@
             self._incomingstreams[frame.streamid] = inputstream(
                 frame.streamid)
 
+        stream = self._incomingstreams[frame.streamid]
+
+        # If the payload is encoded, ask the stream to decode it. We
+        # merely substitute the decoded result into the frame payload as
+        # if it had been transferred all along.
         if frame.streamflags & STREAM_FLAG_ENCODING_APPLIED:
-            raise error.ProgrammingError('support for decoding stream '
-                                         'payloads not yet implemneted')
+            frame.payload = stream.decode(frame.payload)
 
         if frame.streamflags & STREAM_FLAG_END_STREAM:
             del self._incomingstreams[frame.streamid]
@@ -1573,7 +1769,8 @@
             }
 
         try:
-            self._incomingstreams[frame.streamid].setdecoder(decoded[0],
+            self._incomingstreams[frame.streamid].setdecoder(self._ui,
+                                                             decoded[0],
                                                              decoded[1:])
         except Exception as e:
             return 'error', {