diff -r 465187fec06f -r 69e46c1834ac mercurial/wireproto.py --- a/mercurial/wireproto.py Fri Apr 06 16:49:57 2018 -0700 +++ b/mercurial/wireproto.py Fri Apr 06 17:14:06 2018 -0700 @@ -713,8 +713,11 @@ ``name`` is the name of the wire protocol command being provided. - ``args`` is a space-delimited list of named arguments that the command - accepts. ``*`` is a special value that says to accept all arguments. + ``args`` defines the named arguments accepted by the command. It is + ideally a dict mapping argument names to their types. For backwards + compatibility, it can be a space-delimited list of argument names. For + version 1 transports, ``*`` denotes a special value that says to accept + all named arguments. ``transportpolicy`` is a POLICY_* constant denoting which transports this wire protocol command should be exposed to. By default, commands @@ -752,6 +755,17 @@ 'got %s; expected "push" or "pull"' % permission) + if 1 in transportversions and not isinstance(args, bytes): + raise error.ProgrammingError('arguments for version 1 commands must ' + 'be declared as bytes') + + if isinstance(args, bytes): + dictargs = {arg: b'legacy' for arg in args.split()} + elif isinstance(args, dict): + dictargs = args + else: + raise ValueError('args must be bytes or a dict') + def register(func): if 1 in transportversions: if name in commands: @@ -764,7 +778,8 @@ if name in commandsv2: raise error.ProgrammingError('%s command already registered ' 'for version 2' % name) - commandsv2[name] = commandentry(func, args=args, + + commandsv2[name] = commandentry(func, args=dictargs, transports=transports, permission=permission) @@ -1304,7 +1319,7 @@ for command, entry in commandsv2.items(): caps['commands'][command] = { - 'args': sorted(entry.args.split()) if entry.args else [], + 'args': entry.args, 'permissions': [entry.permission], } @@ -1325,7 +1340,11 @@ return wireprototypes.cborresponse(caps) -@wireprotocommand('heads', args='publiconly', permission='pull', +@wireprotocommand('heads', + args={ + 'publiconly': False, + }, + permission='pull', transportpolicy=POLICY_V2_ONLY) def headsv2(repo, proto, publiconly=False): if publiconly: @@ -1333,14 +1352,22 @@ return wireprototypes.cborresponse(repo.heads()) -@wireprotocommand('known', 'nodes', permission='pull', +@wireprotocommand('known', + args={ + 'nodes': [b'deadbeef'], + }, + permission='pull', transportpolicy=POLICY_V2_ONLY) def knownv2(repo, proto, nodes=None): nodes = nodes or [] result = b''.join(b'1' if n else b'0' for n in repo.known(nodes)) return wireprototypes.cborresponse(result) -@wireprotocommand('listkeys', 'namespace', permission='pull', +@wireprotocommand('listkeys', + args={ + 'namespace': b'ns', + }, + permission='pull', transportpolicy=POLICY_V2_ONLY) def listkeysv2(repo, proto, namespace=None): keys = repo.listkeys(encoding.tolocal(namespace))