comparison mercurial/wireprotoframing.py @ 52698:1a612f9ec2c4

typing: add type annotations for wireprotoframing encoder/decoder classes

This was low-hanging fruit, and it will point out an apparent programming bug. The arguments are left untyped for now, because the internal compressors and decompressors generally take a `Buffer` object and return bytes. I've had issues using `typing_extensions.Buffer` with pytype in the past; also, the identity classes return the argument unmodified, so that would break the type annotation. We'd probably need an instance check, and a cast if it's not bytes. Having to roll all of the encoder and decoder types up into a union type is a bit annoying, but the protocol class can't be used in the map because the type checker then assumes that type will be instantiated. I also couldn't get it to work with a `TypeVar` bound to indicate that it's actually a subclass. Oh well.
author Matt Harbison <matt_harbison@yahoo.com>
date Mon, 13 Jan 2025 15:37:08 -0500
parents e627cc25b6f3
children 068398a8c9cb
comparison
equal deleted inserted replaced
52697:f3762eafed66 52698:1a612f9ec2c4
9 # protocol. For details about the protocol, see 9 # protocol. For details about the protocol, see
10 # `hg help internals.wireprotocol`. 10 # `hg help internals.wireprotocol`.
11 11
12 from __future__ import annotations 12 from __future__ import annotations
13 13
14 import abc
14 import collections 15 import collections
15 import struct 16 import struct
16 import typing 17 import typing
18
19 from typing import (
20 Protocol,
21 Type,
22 )
17 23
18 from .i18n import _ 24 from .i18n import _
19 from .thirdparty import attr 25 from .thirdparty import attr
20 26
21 # Force pytype to use the non-vendored package 27 # Force pytype to use the non-vendored package
703 709
704 # TODO consider defining encoders/decoders using the util.compressionengine 710 # TODO consider defining encoders/decoders using the util.compressionengine
705 # mechanism. 711 # mechanism.
706 712
707 713
708 class identityencoder: 714 class Encoder(Protocol):
715 """A protocol class for the various encoder implementations."""
716
717 @abc.abstractmethod
718 def encode(self, data) -> bytes:
719 raise NotImplementedError
720
721 @abc.abstractmethod
722 def flush(self) -> bytes:
723 raise NotImplementedError
724
725 @abc.abstractmethod
726 def finish(self) -> bytes:
727 raise NotImplementedError
728
729
730 class Decoder(Protocol):
731 """A protocol class for the various encoder implementations."""
732
733 @abc.abstractmethod
734 def decode(self, data) -> bytes:
735 raise NotImplementedError
736
737
738 class identityencoder(Encoder):
709 """Encoder for the "identity" stream encoding profile.""" 739 """Encoder for the "identity" stream encoding profile."""
710 740
711 def __init__(self, ui): 741 def __init__(self, ui) -> None:
712 pass 742 pass
713 743
714 def encode(self, data): 744 def encode(self, data) -> bytes:
715 return data 745 return data
716 746
717 def flush(self): 747 def flush(self) -> bytes:
718 return b'' 748 return b''
719 749
720 def finish(self): 750 def finish(self) -> bytes:
721 return b'' 751 return b''
722 752
723 753
724 class identitydecoder: 754 class identitydecoder(Decoder):
725 """Decoder for the "identity" stream encoding profile.""" 755 """Decoder for the "identity" stream encoding profile."""
726 756
727 def __init__(self, ui, extraobjs): 757 def __init__(self, ui, extraobjs) -> None:
728 if extraobjs: 758 if extraobjs:
729 raise error.Abort( 759 raise error.Abort(
730 _(b'identity decoder received unexpected additional values') 760 _(b'identity decoder received unexpected additional values')
731 ) 761 )
732 762
733 def decode(self, data): 763 def decode(self, data) -> bytes:
734 return data 764 return data
735 765
736 766
737 class zlibencoder: 767 class zlibencoder(Encoder):
738 def __init__(self, ui): 768 def __init__(self, ui) -> None:
739 import zlib 769 import zlib
740 770
741 self._zlib = zlib 771 self._zlib = zlib
742 self._compressor = zlib.compressobj() 772 self._compressor = zlib.compressobj()
743 773
744 def encode(self, data): 774 def encode(self, data) -> bytes:
745 return self._compressor.compress(data) 775 return self._compressor.compress(data)
746 776
747 def flush(self): 777 def flush(self) -> bytes:
748 # Z_SYNC_FLUSH doesn't reset compression context, which is 778 # Z_SYNC_FLUSH doesn't reset compression context, which is
749 # what we want. 779 # what we want.
750 return self._compressor.flush(self._zlib.Z_SYNC_FLUSH) 780 return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
751 781
752 def finish(self): 782 def finish(self) -> bytes:
753 res = self._compressor.flush(self._zlib.Z_FINISH) 783 res = self._compressor.flush(self._zlib.Z_FINISH)
754 self._compressor = None 784 self._compressor = None
755 return res 785 return res
756 786
757 787
758 class zlibdecoder: 788 class zlibdecoder(Decoder):
759 def __init__(self, ui, extraobjs): 789 def __init__(self, ui, extraobjs) -> None:
760 import zlib 790 import zlib
761 791
762 if extraobjs: 792 if extraobjs:
763 raise error.Abort( 793 raise error.Abort(
764 _(b'zlib decoder received unexpected additional values') 794 _(b'zlib decoder received unexpected additional values')
765 ) 795 )
766 796
767 self._decompressor = zlib.decompressobj() 797 self._decompressor = zlib.decompressobj()
768 798
769 def decode(self, data): 799 def decode(self, data) -> bytes:
770 return self._decompressor.decompress(data) 800 return self._decompressor.decompress(data)
771 801
772 802
773 class zstdbaseencoder: 803 class zstdbaseencoder(Encoder):
774 def __init__(self, level): 804 def __init__(self, level: int) -> None:
775 from . import zstd # pytype: disable=import-error 805 from . import zstd # pytype: disable=import-error
776 806
777 self._zstd = zstd 807 self._zstd = zstd
778 cctx = zstd.ZstdCompressor(level=level) 808 cctx = zstd.ZstdCompressor(level=level)
779 self._compressor = cctx.compressobj() 809 self._compressor = cctx.compressobj()
780 810
781 def encode(self, data): 811 def encode(self, data) -> bytes:
782 return self._compressor.compress(data) 812 return self._compressor.compress(data)
783 813
784 def flush(self): 814 def flush(self) -> bytes:
785 # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the 815 # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
786 # compressor and allows a decompressor to access all encoded data 816 # compressor and allows a decompressor to access all encoded data
787 # up to this point. 817 # up to this point.
788 return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK) 818 return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
789 819
790 def finish(self): 820 def finish(self) -> bytes:
791 res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH) 821 res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
792 self._compressor = None 822 self._compressor = None
793 return res 823 return res
794 824
795 825
796 class zstd8mbencoder(zstdbaseencoder): 826 class zstd8mbencoder(zstdbaseencoder):
797 def __init__(self, ui): 827 def __init__(self, ui) -> None:
798 super().__init__(3) 828 super().__init__(3)
799 829
800 830
801 class zstdbasedecoder: 831 class zstdbasedecoder(Decoder):
802 def __init__(self, maxwindowsize): 832 def __init__(self, maxwindowsize: int) -> None:
803 from . import zstd # pytype: disable=import-error 833 from . import zstd # pytype: disable=import-error
804 834
805 dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize) 835 dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
806 self._decompressor = dctx.decompressobj() 836 self._decompressor = dctx.decompressobj()
807 837
808 def decode(self, data): 838 def decode(self, data) -> bytes:
809 return self._decompressor.decompress(data) 839 return self._decompressor.decompress(data)
810 840
811 841
812 class zstd8mbdecoder(zstdbasedecoder): 842 class zstd8mbdecoder(zstdbasedecoder):
813 def __init__(self, ui, extraobjs): 843 def __init__(self, ui, extraobjs) -> None:
814 if extraobjs: 844 if extraobjs:
815 raise error.Abort( 845 raise error.Abort(
816 _(b'zstd8mb decoder received unexpected additional values') 846 _(b'zstd8mb decoder received unexpected additional values')
817 ) 847 )
818 848
819 super().__init__(maxwindowsize=8 * 1048576) 849 super().__init__(maxwindowsize=8 * 1048576)
820 850
821 851
852 # TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used
853 # on the STREAM_ENCODERS dict below.
854 if typing.TYPE_CHECKING:
855 EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder]
856 DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder]
857
822 # We lazily populate this to avoid excessive module imports when importing 858 # We lazily populate this to avoid excessive module imports when importing
823 # this module. 859 # this module.
824 STREAM_ENCODERS = {} 860 STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {}
825 STREAM_ENCODERS_ORDER = [] 861 STREAM_ENCODERS_ORDER: list[bytes] = []
826 862
827 863
828 def populatestreamencoders(): 864 def populatestreamencoders():
829 if STREAM_ENCODERS: 865 if STREAM_ENCODERS:
830 return 866 return