mercurial/sshpeer.py
changeset 35976 48a3a9283f09
parent 35940 556218e08e25
child 35977 625038cb4b1d
--- a/mercurial/sshpeer.py	Tue Feb 06 10:51:15 2018 -0800
+++ b/mercurial/sshpeer.py	Tue Feb 06 11:08:36 2018 -0800
@@ -8,6 +8,7 @@
 from __future__ import absolute_import
 
 import re
+import uuid
 
 from .i18n import _
 from . import (
@@ -15,6 +16,7 @@
     pycompat,
     util,
     wireproto,
+    wireprotoserver,
 )
 
 def _serverquote(s):
@@ -162,15 +164,24 @@
         hint = ui.config('ui', 'ssherrorhint')
         raise error.RepoError(msg, hint=hint)
 
-    # The handshake consists of sending 2 wire protocol commands:
-    # ``hello`` and ``between``.
+    # The handshake consists of sending wire protocol commands in reverse
+    # order of protocol implementation and then sniffing for a response
+    # to one of them.
+    #
+    # Those commands (from oldest to newest) are:
     #
-    # The ``hello`` command (which was introduced in Mercurial 0.9.1)
-    # instructs the server to advertise its capabilities.
+    # ``between``
+    #   Asks for the set of revisions between a pair of revisions. Command
+    #   present in all Mercurial server implementations.
     #
-    # The ``between`` command (which has existed in all Mercurial servers
-    # for as long as SSH support has existed), asks for the set of revisions
-    # between a pair of revisions.
+    # ``hello``
+    #   Instructs the server to advertise its capabilities. Introduced in
+    #   Mercurial 0.9.1.
+    #
+    # ``upgrade``
+    #   Requests upgrade from default transport protocol version 1 to
+    #   a newer version. Introduced in Mercurial 4.6 as an experimental
+    #   feature.
     #
     # The ``between`` command is issued with a request for the null
     # range. If the remote is a Mercurial server, this request will
@@ -186,6 +197,18 @@
     # RFC 822 like lines. Of these, the ``capabilities:`` line contains
     # the capabilities of the server.
     #
+    # The ``upgrade`` command isn't really a command in the traditional
+    # sense of version 1 of the transport because it isn't using the
+    # proper mechanism for formatting insteads: instead, it just encodes
+    # arguments on the line, delimited by spaces.
+    #
+    # The ``upgrade`` line looks like ``upgrade <token> <capabilities>``.
+    # If the server doesn't support protocol upgrades, it will reply to
+    # this line with ``0\n``. Otherwise, it emits an
+    # ``upgraded <token> <protocol>`` line to both stdout and stderr.
+    # Content immediately following this line describes additional
+    # protocol and server state.
+    #
     # In addition to the responses to our command requests, the server
     # may emit "banner" output on stdout. SSH servers are allowed to
     # print messages to stdout on login. Issuing commands on connection
@@ -195,6 +218,14 @@
 
     requestlog = ui.configbool('devel', 'debug.peer-request')
 
+    # Generate a random token to help identify responses to version 2
+    # upgrade request.
+    token = bytes(uuid.uuid4())
+    upgradecaps = [
+        ('proto', wireprotoserver.SSHV2),
+    ]
+    upgradecaps = util.urlreq.urlencode(upgradecaps)
+
     try:
         pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
         handshake = [
@@ -204,6 +235,11 @@
             pairsarg,
         ]
 
+        # Request upgrade to version 2 if configured.
+        if ui.configbool('experimental', 'sshpeer.advertise-v2'):
+            ui.debug('sending upgrade request: %s %s\n' % (token, upgradecaps))
+            handshake.insert(0, 'upgrade %s %s\n' % (token, upgradecaps))
+
         if requestlog:
             ui.debug('devel-peer-request: hello\n')
         ui.debug('sending hello command\n')
@@ -217,12 +253,31 @@
     except IOError:
         badresponse()
 
+    # Assume version 1 of wire protocol by default.
+    protoname = wireprotoserver.SSHV1
+    reupgraded = re.compile(b'^upgraded %s (.*)$' % re.escape(token))
+
     lines = ['', 'dummy']
     max_noise = 500
     while lines[-1] and max_noise:
         try:
             l = stdout.readline()
             _forwardoutput(ui, stderr)
+
+            # Look for reply to protocol upgrade request. It has a token
+            # in it, so there should be no false positives.
+            m = reupgraded.match(l)
+            if m:
+                protoname = m.group(1)
+                ui.debug('protocol upgraded to %s\n' % protoname)
+                # If an upgrade was handled, the ``hello`` and ``between``
+                # requests are ignored. The next output belongs to the
+                # protocol, so stop scanning lines.
+                break
+
+            # Otherwise it could be a banner, ``0\n`` response if server
+            # doesn't support upgrade.
+
             if lines[-1] == '1\n' and l == '\n':
                 break
             if l:
@@ -235,20 +290,39 @@
         badresponse()
 
     caps = set()
-    for l in reversed(lines):
-        # Look for response to ``hello`` command. Scan from the back so
-        # we don't misinterpret banner output as the command reply.
-        if l.startswith('capabilities:'):
-            caps.update(l[:-1].split(':')[1].split())
-            break
 
-    # Error if we couldn't find a response to ``hello``. This could
-    # mean:
+    # For version 1, we should see a ``capabilities`` line in response to the
+    # ``hello`` command.
+    if protoname == wireprotoserver.SSHV1:
+        for l in reversed(lines):
+            # Look for response to ``hello`` command. Scan from the back so
+            # we don't misinterpret banner output as the command reply.
+            if l.startswith('capabilities:'):
+                caps.update(l[:-1].split(':')[1].split())
+                break
+    elif protoname == wireprotoserver.SSHV2:
+        # We see a line with number of bytes to follow and then a value
+        # looking like ``capabilities: *``.
+        line = stdout.readline()
+        try:
+            valuelen = int(line)
+        except ValueError:
+            badresponse()
+
+        capsline = stdout.read(valuelen)
+        if not capsline.startswith('capabilities: '):
+            badresponse()
+
+        caps.update(capsline.split(':')[1].split())
+        # Trailing newline.
+        stdout.read(1)
+
+    # Error if we couldn't find capabilities, this means:
     #
     # 1. Remote isn't a Mercurial server
     # 2. Remote is a <0.9.1 Mercurial server
     # 3. Remote is a future Mercurial server that dropped ``hello``
-    #    support.
+    #    and other attempted handshake mechanisms.
     if not caps:
         badresponse()