changeset 52698:1a612f9ec2c4

typing: add type annotations for wireprotoframing encoder/decoder classes This was low hanging fruit, and will point out an apparent programming bug. The arguments are left untyped for now, because the internal compressors and decompressors generally take a `Buffer` object, and return bytes. I've had issues around using `typing_extensions.Buffer` with pytype in the past, but also the identity classes return the argument unmodified, so that will break the type annotation. We'd probably need an instance check, and a cast if it's not bytes. Having to roll up all of the encoder and decoder types into a union type is a bit annoying, but the protocol class can't be used in the map because it then assumes that type will be instantiated. I also couldn't get it to work with a `TypeVar` bound to indicate that it's actually a subclass either. Oh well.
author Matt Harbison <matt_harbison@yahoo.com>
date Mon, 13 Jan 2025 15:37:08 -0500
parents f3762eafed66
children 068398a8c9cb
files mercurial/wireprotoframing.py
diffstat 1 files changed, 64 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/wireprotoframing.py	Mon Jan 13 12:24:33 2025 -0500
+++ b/mercurial/wireprotoframing.py	Mon Jan 13 15:37:08 2025 -0500
@@ -11,10 +11,16 @@
 
 from __future__ import annotations
 
+import abc
 import collections
 import struct
 import typing
 
+from typing import (
+    Protocol,
+    Type,
+)
+
 from .i18n import _
 from .thirdparty import attr
 
@@ -705,58 +711,82 @@
 # mechanism.
 
 
-class identityencoder:
+class Encoder(Protocol):
+    """A protocol class for the various encoder implementations."""
+
+    @abc.abstractmethod
+    def encode(self, data) -> bytes:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def flush(self) -> bytes:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def finish(self) -> bytes:
+        raise NotImplementedError
+
+
+class Decoder(Protocol):
+    """A protocol class for the various encoder implementations."""
+
+    @abc.abstractmethod
+    def decode(self, data) -> bytes:
+        raise NotImplementedError
+
+
+class identityencoder(Encoder):
     """Encoder for the "identity" stream encoding profile."""
 
-    def __init__(self, ui):
+    def __init__(self, ui) -> None:
         pass
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return data
 
-    def flush(self):
+    def flush(self) -> bytes:
         return b''
 
-    def finish(self):
+    def finish(self) -> bytes:
         return b''
 
 
-class identitydecoder:
+class identitydecoder(Decoder):
     """Decoder for the "identity" stream encoding profile."""
 
-    def __init__(self, ui, extraobjs):
+    def __init__(self, ui, extraobjs) -> None:
         if extraobjs:
             raise error.Abort(
                 _(b'identity decoder received unexpected additional values')
             )
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return data
 
 
-class zlibencoder:
-    def __init__(self, ui):
+class zlibencoder(Encoder):
+    def __init__(self, ui) -> None:
         import zlib
 
         self._zlib = zlib
         self._compressor = zlib.compressobj()
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return self._compressor.compress(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         # Z_SYNC_FLUSH doesn't reset compression context, which is
         # what we want.
         return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
 
-    def finish(self):
+    def finish(self) -> bytes:
         res = self._compressor.flush(self._zlib.Z_FINISH)
         self._compressor = None
         return res
 
 
-class zlibdecoder:
-    def __init__(self, ui, extraobjs):
+class zlibdecoder(Decoder):
+    def __init__(self, ui, extraobjs) -> None:
         import zlib
 
         if extraobjs:
@@ -766,51 +796,51 @@
 
         self._decompressor = zlib.decompressobj()
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return self._decompressor.decompress(data)
 
 
-class zstdbaseencoder:
-    def __init__(self, level):
+class zstdbaseencoder(Encoder):
+    def __init__(self, level: int) -> None:
         from . import zstd  # pytype: disable=import-error
 
         self._zstd = zstd
         cctx = zstd.ZstdCompressor(level=level)
         self._compressor = cctx.compressobj()
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         return self._compressor.compress(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
         # compressor and allows a decompressor to access all encoded data
         # up to this point.
         return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
 
-    def finish(self):
+    def finish(self) -> bytes:
         res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
         self._compressor = None
         return res
 
 
 class zstd8mbencoder(zstdbaseencoder):
-    def __init__(self, ui):
+    def __init__(self, ui) -> None:
         super().__init__(3)
 
 
-class zstdbasedecoder:
-    def __init__(self, maxwindowsize):
+class zstdbasedecoder(Decoder):
+    def __init__(self, maxwindowsize: int) -> None:
         from . import zstd  # pytype: disable=import-error
 
         dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
         self._decompressor = dctx.decompressobj()
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         return self._decompressor.decompress(data)
 
 
 class zstd8mbdecoder(zstdbasedecoder):
-    def __init__(self, ui, extraobjs):
+    def __init__(self, ui, extraobjs) -> None:
         if extraobjs:
             raise error.Abort(
                 _(b'zstd8mb decoder received unexpected additional values')
@@ -819,10 +849,16 @@
         super().__init__(maxwindowsize=8 * 1048576)
 
 
+# TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used
+# on the STREAM_ENCODERS dict below.
+if typing.TYPE_CHECKING:
+    EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder]
+    DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder]
+
 # We lazily populate this to avoid excessive module imports when importing
 # this module.
-STREAM_ENCODERS = {}
-STREAM_ENCODERS_ORDER = []
+STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {}
+STREAM_ENCODERS_ORDER: list[bytes] = []
 
 
 def populatestreamencoders():