Mercurial > public > mercurial-scm > hg
comparison mercurial/wireprotoframing.py @ 52698:1a612f9ec2c4
typing: add type annotations for wireprotoframing encoder/decoder classes
This was low-hanging fruit, and the annotations will point out an apparent programming bug.
The arguments are left untyped for now, because the internal compressors and
decompressors generally take a `Buffer` object, and return bytes. I've had
issues around using `typing_extensions.Buffer` with pytype in the past, but also
the identity classes return the argument unmodified, so that will break the type
annotation. We'd probably need an instance check, and a cast if it's not bytes.
Having to roll up all of the encoder and decoder types into a union type is a
bit annoying, but the protocol class can't be used in the map because the type
checker then assumes that the protocol class itself will be instantiated. I also
couldn't get it to work with a `TypeVar` bound to the protocol class to indicate
that the value is actually a subclass. Oh well.
author | Matt Harbison <matt_harbison@yahoo.com> |
---|---|
date | Mon, 13 Jan 2025 15:37:08 -0500 |
parents | e627cc25b6f3 |
children | 068398a8c9cb |
comparison
equal
deleted
inserted
replaced
52697:f3762eafed66 | 52698:1a612f9ec2c4 |
---|---|
9 # protocol. For details about the protocol, see | 9 # protocol. For details about the protocol, see |
10 # `hg help internals.wireprotocol`. | 10 # `hg help internals.wireprotocol`. |
11 | 11 |
12 from __future__ import annotations | 12 from __future__ import annotations |
13 | 13 |
14 import abc | |
14 import collections | 15 import collections |
15 import struct | 16 import struct |
16 import typing | 17 import typing |
18 | |
19 from typing import ( | |
20 Protocol, | |
21 Type, | |
22 ) | |
17 | 23 |
18 from .i18n import _ | 24 from .i18n import _ |
19 from .thirdparty import attr | 25 from .thirdparty import attr |
20 | 26 |
21 # Force pytype to use the non-vendored package | 27 # Force pytype to use the non-vendored package |
703 | 709 |
704 # TODO consider defining encoders/decoders using the util.compressionengine | 710 # TODO consider defining encoders/decoders using the util.compressionengine |
705 # mechanism. | 711 # mechanism. |
706 | 712 |
707 | 713 |
708 class identityencoder: | 714 class Encoder(Protocol): |
715 """A protocol class for the various encoder implementations.""" | |
716 | |
717 @abc.abstractmethod | |
718 def encode(self, data) -> bytes: | |
719 raise NotImplementedError | |
720 | |
721 @abc.abstractmethod | |
722 def flush(self) -> bytes: | |
723 raise NotImplementedError | |
724 | |
725 @abc.abstractmethod | |
726 def finish(self) -> bytes: | |
727 raise NotImplementedError | |
728 | |
729 | |
730 class Decoder(Protocol): | |
731 """A protocol class for the various decoder implementations.""" | |
732 | |
733 @abc.abstractmethod | |
734 def decode(self, data) -> bytes: | |
735 raise NotImplementedError | |
736 | |
737 | |
738 class identityencoder(Encoder): | |
709 """Encoder for the "identity" stream encoding profile.""" | 739 """Encoder for the "identity" stream encoding profile.""" |
710 | 740 |
711 def __init__(self, ui): | 741 def __init__(self, ui) -> None: |
712 pass | 742 pass |
713 | 743 |
714 def encode(self, data): | 744 def encode(self, data) -> bytes: |
715 return data | 745 return data |
716 | 746 |
717 def flush(self): | 747 def flush(self) -> bytes: |
718 return b'' | 748 return b'' |
719 | 749 |
720 def finish(self): | 750 def finish(self) -> bytes: |
721 return b'' | 751 return b'' |
722 | 752 |
723 | 753 |
724 class identitydecoder: | 754 class identitydecoder(Decoder): |
725 """Decoder for the "identity" stream encoding profile.""" | 755 """Decoder for the "identity" stream encoding profile.""" |
726 | 756 |
727 def __init__(self, ui, extraobjs): | 757 def __init__(self, ui, extraobjs) -> None: |
728 if extraobjs: | 758 if extraobjs: |
729 raise error.Abort( | 759 raise error.Abort( |
730 _(b'identity decoder received unexpected additional values') | 760 _(b'identity decoder received unexpected additional values') |
731 ) | 761 ) |
732 | 762 |
733 def decode(self, data): | 763 def decode(self, data) -> bytes: |
734 return data | 764 return data |
735 | 765 |
736 | 766 |
737 class zlibencoder: | 767 class zlibencoder(Encoder): |
738 def __init__(self, ui): | 768 def __init__(self, ui) -> None: |
739 import zlib | 769 import zlib |
740 | 770 |
741 self._zlib = zlib | 771 self._zlib = zlib |
742 self._compressor = zlib.compressobj() | 772 self._compressor = zlib.compressobj() |
743 | 773 |
744 def encode(self, data): | 774 def encode(self, data) -> bytes: |
745 return self._compressor.compress(data) | 775 return self._compressor.compress(data) |
746 | 776 |
747 def flush(self): | 777 def flush(self) -> bytes: |
748 # Z_SYNC_FLUSH doesn't reset compression context, which is | 778 # Z_SYNC_FLUSH doesn't reset compression context, which is |
749 # what we want. | 779 # what we want. |
750 return self._compressor.flush(self._zlib.Z_SYNC_FLUSH) | 780 return self._compressor.flush(self._zlib.Z_SYNC_FLUSH) |
751 | 781 |
752 def finish(self): | 782 def finish(self) -> bytes: |
753 res = self._compressor.flush(self._zlib.Z_FINISH) | 783 res = self._compressor.flush(self._zlib.Z_FINISH) |
754 self._compressor = None | 784 self._compressor = None |
755 return res | 785 return res |
756 | 786 |
757 | 787 |
758 class zlibdecoder: | 788 class zlibdecoder(Decoder): |
759 def __init__(self, ui, extraobjs): | 789 def __init__(self, ui, extraobjs) -> None: |
760 import zlib | 790 import zlib |
761 | 791 |
762 if extraobjs: | 792 if extraobjs: |
763 raise error.Abort( | 793 raise error.Abort( |
764 _(b'zlib decoder received unexpected additional values') | 794 _(b'zlib decoder received unexpected additional values') |
765 ) | 795 ) |
766 | 796 |
767 self._decompressor = zlib.decompressobj() | 797 self._decompressor = zlib.decompressobj() |
768 | 798 |
769 def decode(self, data): | 799 def decode(self, data) -> bytes: |
770 return self._decompressor.decompress(data) | 800 return self._decompressor.decompress(data) |
771 | 801 |
772 | 802 |
773 class zstdbaseencoder: | 803 class zstdbaseencoder(Encoder): |
774 def __init__(self, level): | 804 def __init__(self, level: int) -> None: |
775 from . import zstd # pytype: disable=import-error | 805 from . import zstd # pytype: disable=import-error |
776 | 806 |
777 self._zstd = zstd | 807 self._zstd = zstd |
778 cctx = zstd.ZstdCompressor(level=level) | 808 cctx = zstd.ZstdCompressor(level=level) |
779 self._compressor = cctx.compressobj() | 809 self._compressor = cctx.compressobj() |
780 | 810 |
781 def encode(self, data): | 811 def encode(self, data) -> bytes: |
782 return self._compressor.compress(data) | 812 return self._compressor.compress(data) |
783 | 813 |
784 def flush(self): | 814 def flush(self) -> bytes: |
785 # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the | 815 # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the |
786 # compressor and allows a decompressor to access all encoded data | 816 # compressor and allows a decompressor to access all encoded data |
787 # up to this point. | 817 # up to this point. |
788 return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK) | 818 return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK) |
789 | 819 |
790 def finish(self): | 820 def finish(self) -> bytes: |
791 res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH) | 821 res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH) |
792 self._compressor = None | 822 self._compressor = None |
793 return res | 823 return res |
794 | 824 |
795 | 825 |
796 class zstd8mbencoder(zstdbaseencoder): | 826 class zstd8mbencoder(zstdbaseencoder): |
797 def __init__(self, ui): | 827 def __init__(self, ui) -> None: |
798 super().__init__(3) | 828 super().__init__(3) |
799 | 829 |
800 | 830 |
801 class zstdbasedecoder: | 831 class zstdbasedecoder(Decoder): |
802 def __init__(self, maxwindowsize): | 832 def __init__(self, maxwindowsize: int) -> None: |
803 from . import zstd # pytype: disable=import-error | 833 from . import zstd # pytype: disable=import-error |
804 | 834 |
805 dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize) | 835 dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize) |
806 self._decompressor = dctx.decompressobj() | 836 self._decompressor = dctx.decompressobj() |
807 | 837 |
808 def decode(self, data): | 838 def decode(self, data) -> bytes: |
809 return self._decompressor.decompress(data) | 839 return self._decompressor.decompress(data) |
810 | 840 |
811 | 841 |
812 class zstd8mbdecoder(zstdbasedecoder): | 842 class zstd8mbdecoder(zstdbasedecoder): |
813 def __init__(self, ui, extraobjs): | 843 def __init__(self, ui, extraobjs) -> None: |
814 if extraobjs: | 844 if extraobjs: |
815 raise error.Abort( | 845 raise error.Abort( |
816 _(b'zstd8mb decoder received unexpected additional values') | 846 _(b'zstd8mb decoder received unexpected additional values') |
817 ) | 847 ) |
818 | 848 |
819 super().__init__(maxwindowsize=8 * 1048576) | 849 super().__init__(maxwindowsize=8 * 1048576) |
820 | 850 |
821 | 851 |
852 # TypeVar('EncoderT', bound=Encoder) was flagged as "not in scope" when used | |
853 # on the STREAM_ENCODERS dict below. | |
854 if typing.TYPE_CHECKING: | |
855 EncoderT = Type[identityencoder | zlibencoder | zstd8mbencoder] | |
856 DecoderT = Type[identitydecoder | zlibdecoder | zstd8mbdecoder] | |
857 | |
822 # We lazily populate this to avoid excessive module imports when importing | 858 # We lazily populate this to avoid excessive module imports when importing |
823 # this module. | 859 # this module. |
824 STREAM_ENCODERS = {} | 860 STREAM_ENCODERS: dict[bytes, tuple[EncoderT, DecoderT]] = {} |
825 STREAM_ENCODERS_ORDER = [] | 861 STREAM_ENCODERS_ORDER: list[bytes] = [] |
826 | 862 |
827 | 863 |
828 def populatestreamencoders(): | 864 def populatestreamencoders(): |
829 if STREAM_ENCODERS: | 865 if STREAM_ENCODERS: |
830 return | 866 return |