diff mercurial/wireproto.py @ 37535:69e46c1834ac

wireproto: define and expose types of wire command arguments Exposing the set of argument names is cool. But with wire protocol version 2, we're using CBOR to transport arguments and this allows us to have typing for arguments. Typed arguments are much nicer because they will cut down on transfer overhead and processing overhead for decoding values. This commit teaches @wireprotocommand to accept a dictionary for arguments. The arguments registered for version 2 transports are canonically stored as dictionaries rather than a space-delimited string. It is an error to defined arguments with a dictionary for commands using version 1 transports. This reinforces my intent to fully decouple command handlers for version 2 transports. Differential Revision: https://phab.mercurial-scm.org/D3202
author Gregory Szorc <gregory.szorc@gmail.com>
date Fri, 06 Apr 2018 17:14:06 -0700
parents 465187fec06f
children 2003da12f49b
line wrap: on
line diff
--- a/mercurial/wireproto.py	Fri Apr 06 16:49:57 2018 -0700
+++ b/mercurial/wireproto.py	Fri Apr 06 17:14:06 2018 -0700
@@ -713,8 +713,11 @@
 
     ``name`` is the name of the wire protocol command being provided.
 
-    ``args`` is a space-delimited list of named arguments that the command
-    accepts. ``*`` is a special value that says to accept all arguments.
+    ``args`` defines the named arguments accepted by the command. It is
+    ideally a dict mapping argument names to their types. For backwards
+    compatibility, it can be a space-delimited list of argument names. For
+    version 1 transports, ``*`` denotes a special value that says to accept
+    all named arguments.
 
     ``transportpolicy`` is a POLICY_* constant denoting which transports
     this wire protocol command should be exposed to. By default, commands
@@ -752,6 +755,17 @@
                                      'got %s; expected "push" or "pull"' %
                                      permission)
 
+    if 1 in transportversions and not isinstance(args, bytes):
+        raise error.ProgrammingError('arguments for version 1 commands must '
+                                     'be declared as bytes')
+
+    if isinstance(args, bytes):
+        dictargs = {arg: b'legacy' for arg in args.split()}
+    elif isinstance(args, dict):
+        dictargs = args
+    else:
+        raise ValueError('args must be bytes or a dict')
+
     def register(func):
         if 1 in transportversions:
             if name in commands:
@@ -764,7 +778,8 @@
             if name in commandsv2:
                 raise error.ProgrammingError('%s command already registered '
                                              'for version 2' % name)
-            commandsv2[name] = commandentry(func, args=args,
+
+            commandsv2[name] = commandentry(func, args=dictargs,
                                             transports=transports,
                                             permission=permission)
 
@@ -1304,7 +1319,7 @@
 
     for command, entry in commandsv2.items():
         caps['commands'][command] = {
-            'args': sorted(entry.args.split()) if entry.args else [],
+            'args': entry.args,
             'permissions': [entry.permission],
         }
 
@@ -1325,7 +1340,11 @@
 
     return wireprototypes.cborresponse(caps)
 
-@wireprotocommand('heads', args='publiconly', permission='pull',
+@wireprotocommand('heads',
+                  args={
+                      'publiconly': False,
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def headsv2(repo, proto, publiconly=False):
     if publiconly:
@@ -1333,14 +1352,22 @@
 
     return wireprototypes.cborresponse(repo.heads())
 
-@wireprotocommand('known', 'nodes', permission='pull',
+@wireprotocommand('known',
+                  args={
+                      'nodes': [b'deadbeef'],
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def knownv2(repo, proto, nodes=None):
     nodes = nodes or []
     result = b''.join(b'1' if n else b'0' for n in repo.known(nodes))
     return wireprototypes.cborresponse(result)
 
-@wireprotocommand('listkeys', 'namespace', permission='pull',
+@wireprotocommand('listkeys',
+                  args={
+                      'namespace': b'ns',
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def listkeysv2(repo, proto, namespace=None):
     keys = repo.listkeys(encoding.tolocal(namespace))