Mercurial > public > mercurial-scm > hg
changeset 52698:1a612f9ec2c4
typing: add type annotations for wireprotoframing encoder/decoder classes
This was low hanging fruit, and will point out an apparent programming bug.
The arguments are left untyped for now, because the internal compressors and
decompressors generally take a `Buffer` object, and return bytes. I've had
issues around using `typing_extensions.Buffer` with pytype in the past, but also
the identity classes return the argument unmodified, so that will break the type
annotation. We'd probably need an instance check, and a cast if it's not bytes.
Having to roll up all of the encoder and decoder types into a union type is a
bit annoying, but the protocol class can't be used in the map because it then
assumes that type will be instantiated. I also couldn't get it to work with a
`TypeVar` bound to indicate that it's actually a subclass either. Oh well.
author | Matt Harbison <matt_harbison@yahoo.com> |
---|---|
date | Mon, 13 Jan 2025 15:37:08 -0500 |
parents | f3762eafed66 |
children | 068398a8c9cb |
files | mercurial/wireprotoframing.py |
diffstat | 1 files changed, 64 insertions(+), 28 deletions(-) [+] |
line wrap: on
line diff
--- a/mercurial/wireprotoframing.py Mon Jan 13 12:24:33 2025 -0500 +++ b/mercurial/wireprotoframing.py Mon Jan 13 15:37:08 2025 -0500 @@ -11,10 +11,16 @@ from __future__ import annotations +import abc import collections import struct import typing +from typing import ( + Protocol, + Type, +) + from .i18n import _ from .thirdparty import attr @@ -705,58 +711,82 @@ # mechanism. -class identityencoder: +class Encoder(Protocol): + """A protocol class for the various encoder implementations.""" + + @abc.abstractmethod + def encode(self, data) -> bytes: + raise NotImplementedError + + @abc.abstractmethod + def flush(self) -> bytes: + raise NotImplementedError + + @abc.abstractmethod + def finish(self) -> bytes: + raise NotImplementedError + + +class Decoder(Protocol): + """A protocol class for the various encoder implementations.""" + + @abc.abstractmethod + def decode(self, data) -> bytes: + raise NotImplementedError + + +class identityencoder(Encoder): """Encoder for the "identity" stream encoding profile.""" - def __init__(self, ui): + def __init__(self, ui) -> None: pass - def encode(self, data): + def encode(self, data) -> bytes: return data - def flush(self): + def flush(self) -> bytes: return b'' - def finish(self): + def finish(self) -> bytes: return b'' -class identitydecoder: +class identitydecoder(Decoder): """Decoder for the "identity" stream encoding profile.""" - def __init__(self, ui, extraobjs): + def __init__(self, ui, extraobjs) -> None: if extraobjs: raise error.Abort( _(b'identity decoder received unexpected additional values') ) - def decode(self, data): + def decode(self, data) -> bytes: return data -class zlibencoder: - def __init__(self, ui): +class zlibencoder(Encoder): + def __init__(self, ui) -> None: import zlib self._zlib = zlib self._compressor = zlib.compressobj() - def encode(self, data): + def encode(self, data) -> bytes: return self._compressor.compress(data) - def flush(self): + def flush(self) -> bytes: # Z_SYNC_FLUSH doesn't reset compression context, which is # what we want. return self._compressor.flush(self._zlib.Z_SYNC_FLUSH) - def finish(self): + def finish(self) -> bytes: res = self._compressor.flush(self._zlib.Z_FINISH) self._compressor = None return res -class zlibdecoder: - def __init__(self, ui, extraobjs): +class zlibdecoder(Decoder): + def __init__(self, ui, extraobjs) -> None: import zlib if extraobjs: @@ -766,51 +796,51 @@ self._decompressor = zlib.decompressobj() - def decode(self, data): + def decode(self, data) -> bytes: return self._decompressor.decompress(data) -class zstdbaseencoder: - def __init__(self, level): +class zstdbaseencoder(Encoder): + def __init__(self, level: int) -> None: from . import zstd # pytype: disable=import-error self._zstd = zstd cctx = zstd.ZstdCompressor(level=level) self._compressor = cctx.compressobj() - def encode(self, data): + def encode(self, data) -> bytes: return self._compressor.compress(data) - def flush(self): + def flush(self) -> bytes: # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the # compressor and allows a decompressor to access all encoded data # up to this point. return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK) - def finish(self): + def finish(self) -> bytes: res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH) self._compressor = None return res class zstd8mbencoder(zstdbaseencoder): - def __init__(self, ui): + def __init__(self, ui) -> None: super().__init__(3) -class zstdbasedecoder: - def __init__(self, maxwindowsize): +class zstdbasedecoder(Decoder): + def __init__(self, maxwindowsize: int) -> None: from . import zstd # pytype: disable=import-error dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize) self._decompressor = dctx.decompressobj() - def decode(self, data): + def decode(self, data) -> bytes: return self._decompressor.decompress(data) class zstd8mbdecoder(zstdbasedecoder): - def __init__(self, ui, extraobjs): + def __init__(self, ui, extraobjs) -> None: if extraobjs: raise error.Abort( _(b'zstd8mb decoder received unexpected additional values') @@ -819,10 +849,16 @@ super().__init__(maxwindowsize=8 * 1048576) +# TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used +# on the STREAM_ENCODERS dict below. +if typing.TYPE_CHECKING: + EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder] + DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder] + # We lazily populate this to avoid excessive module imports when importing # this module. -STREAM_ENCODERS = {} -STREAM_ENCODERS_ORDER = [] +STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {} +STREAM_ENCODERS_ORDER: list[bytes] = [] def populatestreamencoders():