mercurial/wireprotoframing.py
changeset 52699 068398a8c9cb
parent 52698 1a612f9ec2c4
--- a/mercurial/wireprotoframing.py	Mon Jan 13 15:37:08 2025 -0500
+++ b/mercurial/wireprotoframing.py	Mon Jan 06 20:32:31 2025 -0500
@@ -41,6 +41,13 @@
     stringutil,
 )
 
+if typing.TYPE_CHECKING:
+    from typing import (
+        Iterator,
+    )
+
+    HandleSendFramesReturnT = tuple[bytes, dict[bytes, Iterator[bytearray]]]
+
 FRAME_HEADER_SIZE = 8
 DEFAULT_MAX_FRAME_SIZE = 32768
 
@@ -48,7 +55,7 @@
 STREAM_FLAG_END_STREAM = 0x02
 STREAM_FLAG_ENCODING_APPLIED = 0x04
 
-STREAM_FLAGS = {
+STREAM_FLAGS: dict[bytes, int] = {
     b'stream-begin': STREAM_FLAG_BEGIN_STREAM,
     b'stream-end': STREAM_FLAG_END_STREAM,
     b'encoded': STREAM_FLAG_ENCODING_APPLIED,
@@ -63,7 +70,7 @@
 FRAME_TYPE_SENDER_PROTOCOL_SETTINGS = 0x08
 FRAME_TYPE_STREAM_SETTINGS = 0x09
 
-FRAME_TYPES = {
+FRAME_TYPES: dict[bytes, int] = {
     b'command-request': FRAME_TYPE_COMMAND_REQUEST,
     b'command-data': FRAME_TYPE_COMMAND_DATA,
     b'command-response': FRAME_TYPE_COMMAND_RESPONSE,
@@ -79,7 +86,7 @@
 FLAG_COMMAND_REQUEST_MORE_FRAMES = 0x04
 FLAG_COMMAND_REQUEST_EXPECT_DATA = 0x08
 
-FLAGS_COMMAND_REQUEST = {
+FLAGS_COMMAND_REQUEST: dict[bytes, int] = {
     b'new': FLAG_COMMAND_REQUEST_NEW,
     b'continuation': FLAG_COMMAND_REQUEST_CONTINUATION,
     b'more': FLAG_COMMAND_REQUEST_MORE_FRAMES,
@@ -89,7 +96,7 @@
 FLAG_COMMAND_DATA_CONTINUATION = 0x01
 FLAG_COMMAND_DATA_EOS = 0x02
 
-FLAGS_COMMAND_DATA = {
+FLAGS_COMMAND_DATA: dict[bytes, int] = {
     b'continuation': FLAG_COMMAND_DATA_CONTINUATION,
     b'eos': FLAG_COMMAND_DATA_EOS,
 }
@@ -97,7 +104,7 @@
 FLAG_COMMAND_RESPONSE_CONTINUATION = 0x01
 FLAG_COMMAND_RESPONSE_EOS = 0x02
 
-FLAGS_COMMAND_RESPONSE = {
+FLAGS_COMMAND_RESPONSE: dict[bytes, int] = {
     b'continuation': FLAG_COMMAND_RESPONSE_CONTINUATION,
     b'eos': FLAG_COMMAND_RESPONSE_EOS,
 }
@@ -105,7 +112,7 @@
 FLAG_SENDER_PROTOCOL_SETTINGS_CONTINUATION = 0x01
 FLAG_SENDER_PROTOCOL_SETTINGS_EOS = 0x02
 
-FLAGS_SENDER_PROTOCOL_SETTINGS = {
+FLAGS_SENDER_PROTOCOL_SETTINGS: dict[bytes, int] = {
     b'continuation': FLAG_SENDER_PROTOCOL_SETTINGS_CONTINUATION,
     b'eos': FLAG_SENDER_PROTOCOL_SETTINGS_EOS,
 }
@@ -113,13 +120,13 @@
 FLAG_STREAM_ENCODING_SETTINGS_CONTINUATION = 0x01
 FLAG_STREAM_ENCODING_SETTINGS_EOS = 0x02
 
-FLAGS_STREAM_ENCODING_SETTINGS = {
+FLAGS_STREAM_ENCODING_SETTINGS: dict[bytes, int] = {
     b'continuation': FLAG_STREAM_ENCODING_SETTINGS_CONTINUATION,
     b'eos': FLAG_STREAM_ENCODING_SETTINGS_EOS,
 }
 
 # Maps frame types to their available flags.
-FRAME_TYPE_FLAGS = {
+FRAME_TYPE_FLAGS: dict[int, dict[bytes, int]] = {
     FRAME_TYPE_COMMAND_REQUEST: FLAGS_COMMAND_REQUEST,
     FRAME_TYPE_COMMAND_DATA: FLAGS_COMMAND_DATA,
     FRAME_TYPE_COMMAND_RESPONSE: FLAGS_COMMAND_RESPONSE,
@@ -133,7 +140,7 @@
 ARGUMENT_RECORD_HEADER = struct.Struct('<HH')
 
 
-def humanflags(mapping, value):
+def humanflags(mapping: dict[bytes, int], value: int) -> bytes:
     """Convert a numeric flags value to a human value, using a mapping table."""
     namemap = {v: k for k, v in mapping.items()}
     flags = []
@@ -150,24 +157,24 @@
 class frameheader:
     """Represents the data in a frame header."""
 
-    length = attr.ib()
-    requestid = attr.ib()
-    streamid = attr.ib()
-    streamflags = attr.ib()
-    typeid = attr.ib()
-    flags = attr.ib()
+    length = attr.ib(type=int)
+    requestid = attr.ib(type=int)
+    streamid = attr.ib(type=int)
+    streamflags = attr.ib(type=int)
+    typeid = attr.ib(type=int)
+    flags = attr.ib(type=int)
 
 
 @attr.s(slots=True, repr=False)
 class frame:
     """Represents a parsed frame."""
 
-    requestid = attr.ib()
-    streamid = attr.ib()
-    streamflags = attr.ib()
-    typeid = attr.ib()
-    flags = attr.ib()
-    payload = attr.ib()
+    requestid = attr.ib(type=int)
+    streamid = attr.ib(type=int)
+    streamflags = attr.ib(type=int)
+    typeid = attr.ib(type=int)
+    flags = attr.ib(type=int)
+    payload = attr.ib(type=bytes)
 
     @encoding.strmethod
     def __repr__(self):
@@ -191,7 +198,14 @@
         )
 
 
-def makeframe(requestid, streamid, streamflags, typeid, flags, payload):
+def makeframe(
+    requestid: int,
+    streamid: int,
+    streamflags: int,
+    typeid: int,
+    flags: int,
+    payload: bytes,
+) -> bytearray:
     """Assemble a frame into a byte array."""
     # TODO assert size of payload.
     frame = bytearray(FRAME_HEADER_SIZE + len(payload))
@@ -212,7 +226,7 @@
     return frame
 
 
-def makeframefromhumanstring(s):
+def makeframefromhumanstring(s: bytes) -> bytearray:
     """Create a frame from a human readable string
 
     Strings have the form:
@@ -278,7 +292,7 @@
     )
 
 
-def parseheader(data):
+def parseheader(data: bytes) -> frameheader:
     """Parse a unified framing protocol frame header from a buffer.
 
     The header is expected to be in the buffer at offset 0 and the
@@ -303,7 +317,7 @@
     )
 
 
-def readframe(fh):
+def readframe(fh) -> frame | None:
     """Read a unified framing protocol frame from a file object.
 
     Returns a 3-tuple of (type, flags, payload) for the decoded frame or
@@ -338,14 +352,14 @@
 
 
 def createcommandframes(
-    stream,
-    requestid,
+    stream: stream,
+    requestid: int,
     cmd,
     args,
     datafh=None,
-    maxframesize=DEFAULT_MAX_FRAME_SIZE,
+    maxframesize: int = DEFAULT_MAX_FRAME_SIZE,
     redirect=None,
-):
+) -> Iterator[bytearray]:
     """Create frames necessary to transmit a request to run a command.
 
     This is a generator of bytearrays. Each item represents a frame
@@ -414,7 +428,9 @@
                 break
 
 
-def createcommandresponseokframe(stream, requestid):
+def createcommandresponseokframe(
+    stream: outputstream, requestid: int
+) -> bytearray | None:
     overall = b''.join(cborutil.streamencode({b'status': b'ok'}))
 
     if stream.streamsettingssent:
@@ -436,8 +452,10 @@
 
 
 def createcommandresponseeosframes(
-    stream, requestid, maxframesize=DEFAULT_MAX_FRAME_SIZE
-):
+    stream: outputstream,
+    requestid: int,
+    maxframesize: int = DEFAULT_MAX_FRAME_SIZE,
+) -> Iterator[bytearray]:
     """Create an empty payload frame representing command end-of-stream."""
     payload = stream.flush()
 
@@ -465,7 +483,9 @@
             break
 
 
-def createalternatelocationresponseframe(stream, requestid, location):
+def createalternatelocationresponseframe(
+    stream: outputstream, requestid: int, location
+) -> bytearray:
     data = {
         b'status': b'redirect',
         b'location': {
@@ -504,7 +524,9 @@
     )
 
 
-def createcommanderrorresponse(stream, requestid, message, args=None):
+def createcommanderrorresponse(
+    stream: stream, requestid: int, message: bytes, args=None
+) -> Iterator[bytearray]:
     # TODO should this be using a list of {'msg': ..., 'args': {}} so atom
     # formatting works consistently?
     m = {
@@ -527,7 +549,9 @@
     )
 
 
-def createerrorframe(stream, requestid, msg, errtype):
+def createerrorframe(
+    stream: stream, requestid: int, msg: bytes, errtype: bytes
+) -> Iterator[bytearray]:
     # TODO properly handle frame size limits.
     assert len(msg) <= DEFAULT_MAX_FRAME_SIZE
 
@@ -549,8 +573,11 @@
 
 
 def createtextoutputframe(
-    stream, requestid, atoms, maxframesize=DEFAULT_MAX_FRAME_SIZE
-):
+    stream: stream,
+    requestid: int,
+    atoms,
+    maxframesize: int = DEFAULT_MAX_FRAME_SIZE,
+) -> Iterator[bytearray]:
     """Create a text output frame to render text to people.
 
     ``atoms`` is a 3-tuple of (formatting string, args, labels).
@@ -616,7 +643,9 @@
     level.
     """
 
-    def __init__(self, stream, requestid, maxframesize=DEFAULT_MAX_FRAME_SIZE):
+    def __init__(
+        self, stream, requestid: int, maxframesize: int = DEFAULT_MAX_FRAME_SIZE
+    ) -> None:
         self._stream = stream
         self._requestid = requestid
         self._maxsize = maxframesize
@@ -861,7 +890,7 @@
 STREAM_ENCODERS_ORDER: list[bytes] = []
 
 
-def populatestreamencoders():
+def populatestreamencoders() -> None:
     if STREAM_ENCODERS:
         return
 
@@ -887,11 +916,16 @@
 class stream:
     """Represents a logical unidirectional series of frames."""
 
-    def __init__(self, streamid, active=False):
+    streamid: int
+    _active: bool
+
+    def __init__(self, streamid: int, active: bool = False) -> None:
         self.streamid = streamid
         self._active = active
 
-    def makeframe(self, requestid, typeid, flags, payload):
+    def makeframe(
+        self, requestid: int, typeid: int, flags: int, payload: bytes
+    ) -> bytearray:
         """Create a frame to be sent out over this stream.
 
         Only returns the frame instance. Does not actually send it.
@@ -909,11 +943,13 @@
 class inputstream(stream):
     """Represents a stream used for receiving data."""
 
-    def __init__(self, streamid, active=False):
+    _decoder: Decoder | None
+
+    def __init__(self, streamid: int, active: bool = False) -> None:
         super().__init__(streamid, active=active)
         self._decoder = None
 
-    def setdecoder(self, ui, name, extraobjs):
+    def setdecoder(self, ui, name: bytes, extraobjs) -> None:
         """Set the decoder for this stream.
 
         Receives the stream profile name and any additional CBOR objects
@@ -924,7 +960,7 @@
 
         self._decoder = STREAM_ENCODERS[name][1](ui, extraobjs)
 
-    def decode(self, data):
+    def decode(self, data) -> bytes:
         # Default is identity decoder. We don't bother instantiating one
         # because it is trivial.
         if not self._decoder:
@@ -932,23 +968,29 @@
 
         return self._decoder.decode(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         if not self._decoder:
             return b''
 
+        # TODO: this looks like a bug- no decoder class defines flush(), so
+        #  either no decoders are used, or no inputstream is flushed.
         return self._decoder.flush()
 
 
 class outputstream(stream):
     """Represents a stream used for sending data."""
 
-    def __init__(self, streamid, active=False):
+    streamsettingssent: bool
+    _encoder: Encoder | None
+    _encodername: bytes | None
+
+    def __init__(self, streamid: int, active: bool = False) -> None:
         super().__init__(streamid, active=active)
         self.streamsettingssent = False
         self._encoder = None
         self._encodername = None
 
-    def setencoder(self, ui, name):
+    def setencoder(self, ui, name: bytes) -> None:
         """Set the encoder for this stream.
 
         Receives the stream profile name.
@@ -959,25 +1001,33 @@
         self._encoder = STREAM_ENCODERS[name][0](ui)
         self._encodername = name
 
-    def encode(self, data):
+    def encode(self, data) -> bytes:
         if not self._encoder:
             return data
 
         return self._encoder.encode(data)
 
-    def flush(self):
+    def flush(self) -> bytes:
         if not self._encoder:
             return b''
 
         return self._encoder.flush()
 
-    def finish(self):
+    # TODO: was this supposed to return the result of finish()?
+    def finish(self):  # -> bytes:
         if not self._encoder:
             return b''
 
         self._encoder.finish()
 
-    def makeframe(self, requestid, typeid, flags, payload, encoded=False):
+    def makeframe(
+        self,
+        requestid: int,
+        typeid: int,
+        flags: int,
+        payload: bytes,
+        encoded: bool = False,
+    ) -> bytearray:
         """Create a frame to be sent out over this stream.
 
         Only returns the frame instance. Does not actually send it.
@@ -1006,7 +1056,7 @@
             requestid, self.streamid, streamflags, typeid, flags, payload
         )
 
-    def makestreamsettingsframe(self, requestid):
+    def makestreamsettingsframe(self, requestid: int) -> bytearray | None:
         """Create a stream settings frame for this stream.
 
         Returns frame data or None if no stream settings frame is needed or has
@@ -1024,7 +1074,7 @@
         )
 
 
-def ensureserverstream(stream):
+def ensureserverstream(stream: stream) -> None:
     if stream.streamid % 2:
         raise error.ProgrammingError(
             b'server should only write to even '
@@ -1032,7 +1082,7 @@
         )
 
 
-DEFAULT_PROTOCOL_SETTINGS = {
+DEFAULT_PROTOCOL_SETTINGS: dict[bytes, list[bytes]] = {
     b'contentencodings': [b'identity'],
 }
 
@@ -1102,7 +1152,9 @@
     between who responds to what.
     """
 
-    def __init__(self, ui, deferoutput=False):
+    _bufferedframegens: list[Iterator[bytearray]]
+
+    def __init__(self, ui, deferoutput: bool = False) -> None:
         """Construct a new server reactor.
 
         ``deferoutput`` can be used to indicate that no output frames should be
@@ -1134,7 +1186,7 @@
 
         populatestreamencoders()
 
-    def onframerecv(self, frame):
+    def onframerecv(self, frame: frame):
         """Process a frame that has been received off the wire.
 
         Returns a dict with an ``action`` key that details what action,
@@ -1183,7 +1235,9 @@
 
         return meth(frame)
 
-    def oncommandresponsereadyobjects(self, stream, requestid, objs):
+    def oncommandresponsereadyobjects(
+        self, stream, requestid: int, objs
+    ) -> HandleSendFramesReturnT:
         """Signal that objects are ready to be sent to the client.
 
         ``objs`` is an iterable of objects (typically a generator) that will
@@ -1322,7 +1376,7 @@
 
         return self._handlesendframes(sendframes())
 
-    def oninputeof(self):
+    def oninputeof(self) -> tuple[bytes, dict[bytes, Iterator[bytearray]]]:
         """Signals that end of input has been received.
 
         No more frames will be received. All pending activity should be
@@ -1342,7 +1396,9 @@
             b'framegen': makegen(),
         }
 
-    def _handlesendframes(self, framegen):
+    def _handlesendframes(
+        self, framegen: Iterator[bytearray]
+    ) -> HandleSendFramesReturnT:
         if self._deferoutput:
             self._bufferedframegens.append(framegen)
             return b'noop', {}
@@ -1351,7 +1407,9 @@
                 b'framegen': framegen,
             }
 
-    def onservererror(self, stream, requestid, msg):
+    def onservererror(
+        self, stream: stream, requestid: int, msg: bytes
+    ) -> HandleSendFramesReturnT:
         ensureserverstream(stream)
 
         def sendframes():
@@ -1363,7 +1421,9 @@
 
         return self._handlesendframes(sendframes())
 
-    def oncommanderror(self, stream, requestid, message, args=None):
+    def oncommanderror(
+        self, stream: stream, requestid: int, message: bytes, args=None
+    ) -> HandleSendFramesReturnT:
         """Called when a command encountered an error before sending output."""
         ensureserverstream(stream)
 
@@ -1376,7 +1436,7 @@
 
         return self._handlesendframes(sendframes())
 
-    def makeoutputstream(self):
+    def makeoutputstream(self) -> outputstream:
         """Create a stream to be used for sending data to the client.
 
         If this is called before protocol settings frames are received, we
@@ -1398,12 +1458,12 @@
 
         return s
 
-    def _makeerrorresult(self, msg):
+    def _makeerrorresult(self, msg: bytes) -> tuple[bytes, dict[bytes, bytes]]:
         return b'error', {
             b'message': msg,
         }
 
-    def _makeruncommandresult(self, requestid):
+    def _makeruncommandresult(self, requestid: int):
         entry = self._receivingcommands[requestid]
 
         if not entry[b'requestdone']:
@@ -1446,12 +1506,12 @@
             },
         )
 
-    def _makewantframeresult(self):
+    def _makewantframeresult(self) -> tuple[bytes, dict[bytes, bytes]]:
         return b'wantframe', {
             b'state': self._state,
         }
 
-    def _validatecommandrequestframe(self, frame):
+    def _validatecommandrequestframe(self, frame: frame):
         new = frame.flags & FLAG_COMMAND_REQUEST_NEW
         continuation = frame.flags & FLAG_COMMAND_REQUEST_CONTINUATION
 
@@ -1473,7 +1533,7 @@
                 )
             )
 
-    def _onframeinitial(self, frame):
+    def _onframeinitial(self, frame: frame):
         # Called when we receive a frame when in the "initial" state.
         if frame.typeid == FRAME_TYPE_SENDER_PROTOCOL_SETTINGS:
             self._state = b'protocol-settings-receiving'
@@ -1494,7 +1554,7 @@
                 % frame.typeid
             )
 
-    def _onframeprotocolsettings(self, frame):
+    def _onframeprotocolsettings(self, frame: frame):
         assert self._state == b'protocol-settings-receiving'
         assert self._protocolsettingsdecoder is not None
 
@@ -1570,7 +1630,7 @@
 
         return self._makewantframeresult()
 
-    def _onframeidle(self, frame):
+    def _onframeidle(self, frame: frame):
         # The only frame type that should be received in this state is a
         # command request.
         if frame.typeid != FRAME_TYPE_COMMAND_REQUEST:
@@ -1623,7 +1683,7 @@
         self._state = b'command-receiving'
         return self._makewantframeresult()
 
-    def _onframecommandreceiving(self, frame):
+    def _onframecommandreceiving(self, frame: frame):
         if frame.typeid == FRAME_TYPE_COMMAND_REQUEST:
             # Process new command requests as such.
             if frame.flags & FLAG_COMMAND_REQUEST_NEW:
@@ -1701,7 +1761,7 @@
                 _(b'received unexpected frame type: %d') % frame.typeid
             )
 
-    def _handlecommanddataframe(self, frame, entry):
+    def _handlecommanddataframe(self, frame: frame, entry):
         assert frame.typeid == FRAME_TYPE_COMMAND_DATA
 
         # TODO support streaming data instead of buffering it.
@@ -1716,14 +1776,18 @@
             self._state = b'errored'
             return self._makeerrorresult(_(b'command data frame without flags'))
 
-    def _onframeerrored(self, frame):
+    def _onframeerrored(self, frame: frame):
         return self._makeerrorresult(_(b'server already errored'))
 
 
 class commandrequest:
     """Represents a request to run a command."""
 
-    def __init__(self, requestid, name, args, datafh=None, redirect=None):
+    state: bytes
+
+    def __init__(
+        self, requestid: int, name, args, datafh=None, redirect=None
+    ) -> None:
         self.requestid = requestid
         self.name = name
         self.args = args
@@ -1779,13 +1843,16 @@
        respectively.
     """
 
+    _hasmultiplesend: bool
+    _buffersends: bool
+
     def __init__(
         self,
         ui,
-        hasmultiplesend=False,
-        buffersends=True,
+        hasmultiplesend: bool = False,
+        buffersends: bool = True,
         clientcontentencoders=None,
-    ):
+    ) -> None:
         """Create a new instance.
 
         ``hasmultiplesend`` indicates whether multiple sends are supported
@@ -1859,7 +1926,7 @@
                 },
             )
 
-    def flushcommands(self):
+    def flushcommands(self) -> tuple[bytes, dict[bytes, Iterator[bytearray]]]:
         """Request that all queued commands be sent.
 
         If any commands are buffered, this will instruct the caller to send
@@ -1892,7 +1959,9 @@
             b'framegen': makeframes(),
         }
 
-    def _makecommandframes(self, request):
+    def _makecommandframes(
+        self, request: commandrequest
+    ) -> Iterator[bytearray]:
         """Emit frames to issue a command request.
 
         As a side-effect, update request accounting to reflect its changed
@@ -1932,7 +2001,7 @@
 
         request.state = b'sent'
 
-    def onframerecv(self, frame):
+    def onframerecv(self, frame: frame):
         """Process a frame that has been received off the wire.
 
         Returns a 2-tuple of (action, meta) describing further action the
@@ -2004,7 +2073,7 @@
 
         return meth(request, frame)
 
-    def _onstreamsettingsframe(self, frame):
+    def _onstreamsettingsframe(self, frame: frame):
         assert frame.typeid == FRAME_TYPE_STREAM_SETTINGS
 
         more = frame.flags & FLAG_STREAM_ENCODING_SETTINGS_CONTINUATION
@@ -2092,7 +2161,7 @@
 
         return b'noop', {}
 
-    def _oncommandresponseframe(self, request, frame):
+    def _oncommandresponseframe(self, request: commandrequest, frame: frame):
         if frame.flags & FLAG_COMMAND_RESPONSE_EOS:
             request.state = b'received'
             del self._activerequests[request.requestid]
@@ -2107,7 +2176,7 @@
             },
         )
 
-    def _onerrorresponseframe(self, request, frame):
+    def _onerrorresponseframe(self, request: commandrequest, frame: frame):
         request.state = b'errored'
         del self._activerequests[request.requestid]