--- a/mercurial/wireproto.py Sun Mar 04 21:16:36 2018 -0500
+++ b/mercurial/wireproto.py Tue Mar 06 14:32:14 2018 -0800
@@ -672,6 +672,11 @@
commands = commanddict()
+# Maps wire protocol name to operation type. This is used for permissions
+# checking. All defined @wireiprotocommand should have an entry in this
+# dict.
+permissions = {}
+
def wireprotocommand(name, args='', transportpolicy=POLICY_ALL):
"""Decorator to declare a wire protocol command.
@@ -701,6 +706,8 @@
return func
return register
+# TODO define a more appropriate permissions type to use for this.
+permissions['batch'] = 'pull'
@wireprotocommand('batch', 'cmds *')
def batch(repo, proto, cmds, others):
repo = repo.filtered("served")
@@ -713,6 +720,17 @@
n, v = a.split('=')
vals[unescapearg(n)] = unescapearg(v)
func, spec = commands[op]
+
+ # If the protocol supports permissions checking, perform that
+ # checking on each batched command.
+ # TODO formalize permission checking as part of protocol interface.
+ if util.safehasattr(proto, 'checkperm'):
+ # Assume commands with no defined permissions are writes / for
+ # pushes. This is the safest from a security perspective because
+ # it doesn't allow commands with undefined semantics from
+ # bypassing permissions checks.
+ proto.checkperm(permissions.get(op, 'push'))
+
if spec:
keys = spec.split()
data = {}
@@ -740,6 +758,7 @@
return bytesresponse(';'.join(res))
+permissions['between'] = 'pull'
@wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY)
def between(repo, proto, pairs):
pairs = [decodelist(p, '-') for p in pairs.split(" ")]
@@ -749,6 +768,7 @@
return bytesresponse(''.join(r))
+permissions['branchmap'] = 'pull'
@wireprotocommand('branchmap')
def branchmap(repo, proto):
branchmap = repo.branchmap()
@@ -760,6 +780,7 @@
return bytesresponse('\n'.join(heads))
+permissions['branches'] = 'pull'
@wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY)
def branches(repo, proto, nodes):
nodes = decodelist(nodes)
@@ -769,6 +790,7 @@
return bytesresponse(''.join(r))
+permissions['clonebundles'] = 'pull'
@wireprotocommand('clonebundles', '')
def clonebundles(repo, proto):
"""Server command for returning info for available bundles to seed clones.
@@ -821,10 +843,12 @@
# If you are writing an extension and consider wrapping this function. Wrap
# `_capabilities` instead.
+permissions['capabilities'] = 'pull'
@wireprotocommand('capabilities')
def capabilities(repo, proto):
return bytesresponse(' '.join(_capabilities(repo, proto)))
+permissions['changegroup'] = 'pull'
@wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY)
def changegroup(repo, proto, roots):
nodes = decodelist(roots)
@@ -834,6 +858,7 @@
gen = iter(lambda: cg.read(32768), '')
return streamres(gen=gen)
+permissions['changegroupsubset'] = 'pull'
@wireprotocommand('changegroupsubset', 'bases heads',
transportpolicy=POLICY_V1_ONLY)
def changegroupsubset(repo, proto, bases, heads):
@@ -845,6 +870,7 @@
gen = iter(lambda: cg.read(32768), '')
return streamres(gen=gen)
+permissions['debugwireargs'] = 'pull'
@wireprotocommand('debugwireargs', 'one two *')
def debugwireargs(repo, proto, one, two, others):
# only accept optional args from the known set
@@ -852,6 +878,7 @@
return bytesresponse(repo.debugwireargs(one, two,
**pycompat.strkwargs(opts)))
+permissions['getbundle'] = 'pull'
@wireprotocommand('getbundle', '*')
def getbundle(repo, proto, others):
opts = options('getbundle', gboptsmap.keys(), others)
@@ -918,11 +945,13 @@
return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
+permissions['heads'] = 'pull'
@wireprotocommand('heads')
def heads(repo, proto):
h = repo.heads()
return bytesresponse(encodelist(h) + '\n')
+permissions['hello'] = 'pull'
@wireprotocommand('hello')
def hello(repo, proto):
"""Called as part of SSH handshake to obtain server info.
@@ -938,11 +967,13 @@
caps = capabilities(repo, proto).data
return bytesresponse('capabilities: %s\n' % caps)
+permissions['listkeys'] = 'pull'
@wireprotocommand('listkeys', 'namespace')
def listkeys(repo, proto, namespace):
d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
return bytesresponse(pushkeymod.encodekeys(d))
+permissions['lookup'] = 'pull'
@wireprotocommand('lookup', 'key')
def lookup(repo, proto, key):
try:
@@ -955,11 +986,13 @@
success = 0
return bytesresponse('%d %s\n' % (success, r))
+permissions['known'] = 'pull'
@wireprotocommand('known', 'nodes *')
def known(repo, proto, nodes, others):
v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
return bytesresponse(v)
+permissions['pushkey'] = 'push'
@wireprotocommand('pushkey', 'namespace key old new')
def pushkey(repo, proto, namespace, key, old, new):
# compatibility with pre-1.8 clients which were accidentally
@@ -981,6 +1014,7 @@
output = output.getvalue() if output else ''
return bytesresponse('%d\n%s' % (int(r), output))
+permissions['stream_out'] = 'pull'
@wireprotocommand('stream_out')
def stream(repo, proto):
'''If the server supports streaming clone, it advertises the "stream"
@@ -989,6 +1023,7 @@
'''
return streamres_legacy(streamclone.generatev1wireproto(repo))
+permissions['unbundle'] = 'push'
@wireprotocommand('unbundle', 'heads')
def unbundle(repo, proto, heads):
their_heads = decodelist(heads)