--- a/mercurial/wireprotoframing.py Mon Jan 13 12:24:33 2025 -0500
+++ b/mercurial/wireprotoframing.py Mon Jan 13 15:37:08 2025 -0500
@@ -11,10 +11,16 @@
from __future__ import annotations
+import abc
import collections
import struct
import typing
+from typing import (
+ Protocol,
+ Type,
+)
+
from .i18n import _
from .thirdparty import attr
@@ -705,58 +711,82 @@
# mechanism.
-class identityencoder:
+class Encoder(Protocol):
+ """A protocol class for the various encoder implementations."""
+
+ @abc.abstractmethod
+ def encode(self, data) -> bytes:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def flush(self) -> bytes:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def finish(self) -> bytes:
+ raise NotImplementedError
+
+
+class Decoder(Protocol):
+ """A protocol class for the various encoder implementations."""
+
+ @abc.abstractmethod
+ def decode(self, data) -> bytes:
+ raise NotImplementedError
+
+
+class identityencoder(Encoder):
"""Encoder for the "identity" stream encoding profile."""
- def __init__(self, ui):
+ def __init__(self, ui) -> None:
pass
- def encode(self, data):
+ def encode(self, data) -> bytes:
return data
- def flush(self):
+ def flush(self) -> bytes:
return b''
- def finish(self):
+ def finish(self) -> bytes:
return b''
-class identitydecoder:
+class identitydecoder(Decoder):
"""Decoder for the "identity" stream encoding profile."""
- def __init__(self, ui, extraobjs):
+ def __init__(self, ui, extraobjs) -> None:
if extraobjs:
raise error.Abort(
_(b'identity decoder received unexpected additional values')
)
- def decode(self, data):
+ def decode(self, data) -> bytes:
return data
-class zlibencoder:
- def __init__(self, ui):
+class zlibencoder(Encoder):
+ def __init__(self, ui) -> None:
import zlib
self._zlib = zlib
self._compressor = zlib.compressobj()
- def encode(self, data):
+ def encode(self, data) -> bytes:
return self._compressor.compress(data)
- def flush(self):
+ def flush(self) -> bytes:
# Z_SYNC_FLUSH doesn't reset compression context, which is
# what we want.
return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
- def finish(self):
+ def finish(self) -> bytes:
res = self._compressor.flush(self._zlib.Z_FINISH)
self._compressor = None
return res
-class zlibdecoder:
- def __init__(self, ui, extraobjs):
+class zlibdecoder(Decoder):
+ def __init__(self, ui, extraobjs) -> None:
import zlib
if extraobjs:
@@ -766,51 +796,51 @@
self._decompressor = zlib.decompressobj()
- def decode(self, data):
+ def decode(self, data) -> bytes:
return self._decompressor.decompress(data)
-class zstdbaseencoder:
- def __init__(self, level):
+class zstdbaseencoder(Encoder):
+ def __init__(self, level: int) -> None:
from . import zstd # pytype: disable=import-error
self._zstd = zstd
cctx = zstd.ZstdCompressor(level=level)
self._compressor = cctx.compressobj()
- def encode(self, data):
+ def encode(self, data) -> bytes:
return self._compressor.compress(data)
- def flush(self):
+ def flush(self) -> bytes:
# COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
# compressor and allows a decompressor to access all encoded data
# up to this point.
return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
- def finish(self):
+ def finish(self) -> bytes:
res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
self._compressor = None
return res
class zstd8mbencoder(zstdbaseencoder):
- def __init__(self, ui):
+ def __init__(self, ui) -> None:
super().__init__(3)
-class zstdbasedecoder:
- def __init__(self, maxwindowsize):
+class zstdbasedecoder(Decoder):
+ def __init__(self, maxwindowsize: int) -> None:
from . import zstd # pytype: disable=import-error
dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
self._decompressor = dctx.decompressobj()
- def decode(self, data):
+ def decode(self, data) -> bytes:
return self._decompressor.decompress(data)
class zstd8mbdecoder(zstdbasedecoder):
- def __init__(self, ui, extraobjs):
+ def __init__(self, ui, extraobjs) -> None:
if extraobjs:
raise error.Abort(
_(b'zstd8mb decoder received unexpected additional values')
@@ -819,10 +849,16 @@
super().__init__(maxwindowsize=8 * 1048576)
+# TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used
+# on the STREAM_ENCODERS dict below.
+if typing.TYPE_CHECKING:
+ EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder]
+ DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder]
+
# We lazily populate this to avoid excessive module imports when importing
# this module.
-STREAM_ENCODERS = {}
-STREAM_ENCODERS_ORDER = []
+STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {}
+STREAM_ENCODERS_ORDER: list[bytes] = []
def populatestreamencoders():