diff -r 1151c731686e -r abc3b9801563 mercurial/wireproto.py --- a/mercurial/wireproto.py Fri Mar 02 18:50:49 2018 -0500 +++ b/mercurial/wireproto.py Fri Mar 02 09:47:37 2018 -0500 @@ -592,9 +592,10 @@ class commandentry(object): """Represents a declared wire protocol command.""" - def __init__(self, func, args=''): + def __init__(self, func, args='', transports=None): self.func = func self.args = args + self.transports = transports or set() def _merge(self, func, args): """Merge this instance with an incoming 2-tuple. @@ -604,7 +605,7 @@ data not captured by the 2-tuple and a new instance containing the union of the two objects is returned. """ - return commandentry(func, args=args) + return commandentry(func, args=args, transports=set(self.transports)) # Old code treats instances as 2-tuples. So expose that interface. def __iter__(self): @@ -640,7 +641,9 @@ if k in self: v = self[k]._merge(v[0], v[1]) else: - v = commandentry(v[0], args=v[1]) + # Use default values from @wireprotocommand. + v = commandentry(v[0], args=v[1], + transports=set(wireprototypes.TRANSPORTS)) else: raise ValueError('command entries must be commandentry instances ' 'or 2-tuples') @@ -649,22 +652,52 @@ def commandavailable(self, command, proto): """Determine if a command is available for the requested protocol.""" - # For now, commands are available for all protocols. So do a simple - # membership test. - return command in self + assert proto.name in wireprototypes.TRANSPORTS + + entry = self.get(command) + + if not entry: + return False + + if proto.name not in entry.transports: + return False + + return True + +# Constants specifying which transports a wire protocol command should be +# available on. For use with @wireprotocommand. +POLICY_ALL = 'all' +POLICY_V1_ONLY = 'v1-only' +POLICY_V2_ONLY = 'v2-only' commands = commanddict() -def wireprotocommand(name, args=''): +def wireprotocommand(name, args='', transportpolicy=POLICY_ALL): """Decorator to declare a wire protocol command. ``name`` is the name of the wire protocol command being provided. ``args`` is a space-delimited list of named arguments that the command accepts. ``*`` is a special value that says to accept all arguments. + + ``transportpolicy`` is a POLICY_* constant denoting which transports + this wire protocol command should be exposed to. By default, commands + are exposed to all wire protocol transports. """ + if transportpolicy == POLICY_ALL: + transports = set(wireprototypes.TRANSPORTS) + elif transportpolicy == POLICY_V1_ONLY: + transports = {k for k, v in wireprototypes.TRANSPORTS.items() + if v['version'] == 1} + elif transportpolicy == POLICY_V2_ONLY: + transports = {k for k, v in wireprototypes.TRANSPORTS.items() + if v['version'] == 2} + else: + raise error.Abort(_('invalid transport policy value: %s') % + transportpolicy) + def register(func): - commands[name] = commandentry(func, args=args) + commands[name] = commandentry(func, args=args, transports=transports) return func return register