mercurial/wireprotoframing.py
changeset 52698 1a612f9ec2c4
parent 52644 e627cc25b6f3
child 52699 068398a8c9cb
--- a/mercurial/wireprotoframing.py	Mon Jan 13 12:24:33 2025 -0500
+++ b/mercurial/wireprotoframing.py	Mon Jan 13 15:37:08 2025 -0500
@@ -11,10 +11,16 @@
 
 from __future__ import annotations
 
+import abc
 import collections
 import struct
 import typing
 
+from typing import (
+    Protocol,
+    Type,
+)
+
 from .i18n import _
 from .thirdparty import attr
 
@@ -705,58 +711,82 @@
 # mechanism.
 
 
-class identityencoder:
+class Encoder(Protocol):
+    """A protocol class for the various encoder implementations."""
+
+    @abc.abstractmethod
+    def encode(self, data) -> bytes:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def flush(self) -> bytes:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def finish(self) -> bytes:
+        raise NotImplementedError
+
+
+class Decoder(Protocol):
+    """A protocol class for the various encoder implementations."""
+
+    @abc.abstractmethod
+    def decode(self, data) -> bytes:
+        raise NotImplementedError
+
+
+class identityencoder(Encoder):
     """Encoder for the "identity" stream encoding profile."""
 
-    def __init__(self, ui):
+    def __init__(self, ui) -> None:
         pass
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return data
 
-    def flush(self):
+    def flush(self) -> bytes:
         return b''
 
-    def finish(self):
+    def finish(self) -> bytes:
         return b''
 
 
-class identitydecoder:
+class identitydecoder(Decoder):
     """Decoder for the "identity" stream encoding profile."""
 
-    def __init__(self, ui, extraobjs):
+    def __init__(self, ui, extraobjs) -> None:
         if extraobjs:
             raise error.Abort(
                 _(b'identity decoder received unexpected additional values')
             )
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return data
 
 
-class zlibencoder:
-    def __init__(self, ui):
+class zlibencoder(Encoder):
+    def __init__(self, ui) -> None:
         import zlib
 
         self._zlib = zlib
         self._compressor = zlib.compressobj()
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return self._compressor.compress(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         # Z_SYNC_FLUSH doesn't reset compression context, which is
         # what we want.
         return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
 
-    def finish(self):
+    def finish(self) -> bytes:
         res = self._compressor.flush(self._zlib.Z_FINISH)
         self._compressor = None
         return res
 
 
-class zlibdecoder:
-    def __init__(self, ui, extraobjs):
+class zlibdecoder(Decoder):
+    def __init__(self, ui, extraobjs) -> None:
         import zlib
 
         if extraobjs:
@@ -766,51 +796,51 @@
 
         self._decompressor = zlib.decompressobj()
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return self._decompressor.decompress(data)
 
 
-class zstdbaseencoder:
-    def __init__(self, level):
+class zstdbaseencoder(Encoder):
+    def __init__(self, level: int) -> None:
         from . import zstd  # pytype: disable=import-error
 
         self._zstd = zstd
         cctx = zstd.ZstdCompressor(level=level)
         self._compressor = cctx.compressobj()
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return self._compressor.compress(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
         # compressor and allows a decompressor to access all encoded data
         # up to this point.
         return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
 
-    def finish(self):
+    def finish(self) -> bytes:
         res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
         self._compressor = None
         return res
 
 
 class zstd8mbencoder(zstdbaseencoder):
-    def __init__(self, ui):
+    def __init__(self, ui) -> None:
         super().__init__(3)
 
 
-class zstdbasedecoder:
-    def __init__(self, maxwindowsize):
+class zstdbasedecoder(Decoder):
+    def __init__(self, maxwindowsize: int) -> None:
         from . import zstd  # pytype: disable=import-error
 
         dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
         self._decompressor = dctx.decompressobj()
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return self._decompressor.decompress(data)
 
 
 class zstd8mbdecoder(zstdbasedecoder):
-    def __init__(self, ui, extraobjs):
+    def __init__(self, ui, extraobjs) -> None:
         if extraobjs:
             raise error.Abort(
                 _(b'zstd8mb decoder received unexpected additional values')
@@ -819,10 +849,16 @@
         super().__init__(maxwindowsize=8 * 1048576)
 
 
+# TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used
+# on the STREAM_ENCODERS dict below.
+if typing.TYPE_CHECKING:
+    EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder]
+    DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder]
+
 # We lazily populate this to avoid excessive module imports when importing
 # this module.
-STREAM_ENCODERS = {}
-STREAM_ENCODERS_ORDER = []
+STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {}
+STREAM_ENCODERS_ORDER: list[bytes] = []
 
 
 def populatestreamencoders():