diff -r a09435c0eb14 -r 6fc31e7bd5db mercurial/bundle2.py --- a/mercurial/bundle2.py Wed Jul 10 16:04:53 2024 -0400 +++ b/mercurial/bundle2.py Wed Jul 10 17:09:34 2024 -0400 @@ -153,6 +153,7 @@ import string import struct import sys +import typing from .i18n import _ from .node import ( @@ -181,6 +182,17 @@ ) from .interfaces import repository +if typing.TYPE_CHECKING: + from typing import ( + Dict, + List, + Optional, + Tuple, + Union, + ) + + Capabilities = Dict[bytes, Union[List[bytes], Tuple[bytes, ...]]] + urlerr = util.urlerr urlreq = util.urlreq @@ -602,7 +614,7 @@ ) -def decodecaps(blob): +def decodecaps(blob: bytes) -> "Capabilities": """decode a bundle2 caps bytes blob into a dictionary The blob is a list of capabilities (one per line) @@ -662,11 +674,14 @@ _magicstring = b'HG20' - def __init__(self, ui, capabilities=()): + def __init__(self, ui, capabilities: "Optional[Capabilities]" = None): + if capabilities is None: + capabilities = {} + self.ui = ui self._params = [] self._parts = [] - self.capabilities = dict(capabilities) + self.capabilities: "Capabilities" = dict(capabilities) self._compengine = util.compengines.forbundletype(b'UN') self._compopts = None # If compression is being handled by a consumer of the raw @@ -1612,7 +1627,7 @@ # These are only the static capabilities. # Check the 'getrepocaps' function for the rest. -capabilities = { +capabilities: "Capabilities" = { b'HG20': (), b'bookmarks': (), b'error': (b'abort', b'unsupportedcontent', b'pushraced', b'pushkey'), @@ -1626,7 +1641,8 @@ } -def getrepocaps(repo, allowpushback=False, role=None): +# TODO: drop the default value for 'role' +def getrepocaps(repo, allowpushback: bool = False, role=None) -> "Capabilities": """return the bundle2 capabilities for a given repo Exists to allow extensions (like evolution) to mutate the capabilities. @@ -1675,7 +1691,7 @@ return caps -def bundle2caps(remote): +def bundle2caps(remote) -> "Capabilities": """return the bundle capabilities of a peer as dict""" raw = remote.capable(b'bundle2') if not raw and raw != b'': @@ -1684,7 +1700,7 @@ return decodecaps(capsblob) -def obsmarkersversion(caps): +def obsmarkersversion(caps: "Capabilities"): """extract the list of supported obsmarkers versions from a bundle2caps dict""" obscaps = caps.get(b'obsmarkers', ()) return [int(c[1:]) for c in obscaps if c.startswith(b'V')] @@ -1725,7 +1741,7 @@ msg %= count raise error.ProgrammingError(msg) - caps = {} + caps: "Capabilities" = {} if opts.get(b'obsolescence', False): caps[b'obsmarkers'] = (b'V1',) stream_version = opts.get(b'stream', b"")