--- a/mercurial/wireproto.py Fri Mar 02 18:50:49 2018 -0500
+++ b/mercurial/wireproto.py Fri Mar 02 09:47:37 2018 -0500
@@ -592,9 +592,10 @@
class commandentry(object):
"""Represents a declared wire protocol command."""
- def __init__(self, func, args=''):
+ def __init__(self, func, args='', transports=None):
self.func = func
self.args = args
+ self.transports = transports or set()
def _merge(self, func, args):
"""Merge this instance with an incoming 2-tuple.
@@ -604,7 +605,7 @@
data not captured by the 2-tuple and a new instance containing
the union of the two objects is returned.
"""
- return commandentry(func, args=args)
+ return commandentry(func, args=args, transports=set(self.transports))
# Old code treats instances as 2-tuples. So expose that interface.
def __iter__(self):
@@ -640,7 +641,9 @@
if k in self:
v = self[k]._merge(v[0], v[1])
else:
- v = commandentry(v[0], args=v[1])
+ # Use default values from @wireprotocommand.
+ v = commandentry(v[0], args=v[1],
+ transports=set(wireprototypes.TRANSPORTS))
else:
raise ValueError('command entries must be commandentry instances '
'or 2-tuples')
@@ -649,22 +652,52 @@
def commandavailable(self, command, proto):
"""Determine if a command is available for the requested protocol."""
- # For now, commands are available for all protocols. So do a simple
- # membership test.
- return command in self
+ assert proto.name in wireprototypes.TRANSPORTS
+
+ entry = self.get(command)
+
+ if not entry:
+ return False
+
+ if proto.name not in entry.transports:
+ return False
+
+ return True
+
+# Constants specifying which transports a wire protocol command should be
+# available on. For use with @wireprotocommand.
+POLICY_ALL = 'all'
+POLICY_V1_ONLY = 'v1-only'
+POLICY_V2_ONLY = 'v2-only'
commands = commanddict()
-def wireprotocommand(name, args=''):
+def wireprotocommand(name, args='', transportpolicy=POLICY_ALL):
"""Decorator to declare a wire protocol command.
``name`` is the name of the wire protocol command being provided.
``args`` is a space-delimited list of named arguments that the command
accepts. ``*`` is a special value that says to accept all arguments.
+
+ ``transportpolicy`` is a POLICY_* constant denoting which transports
+ this wire protocol command should be exposed to. By default, commands
+ are exposed to all wire protocol transports.
"""
+ if transportpolicy == POLICY_ALL:
+ transports = set(wireprototypes.TRANSPORTS)
+ elif transportpolicy == POLICY_V1_ONLY:
+ transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+ if v['version'] == 1}
+ elif transportpolicy == POLICY_V2_ONLY:
+ transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+ if v['version'] == 2}
+ else:
+ raise error.Abort(_('invalid transport policy value: %s') %
+ transportpolicy)
+
def register(func):
- commands[name] = commandentry(func, args=args)
+ commands[name] = commandentry(func, args=args, transports=transports)
return func
return register