diff mercurial/bundlerepo.py @ 52857:dabec69bd6fc

refactor: split `bundlerepo.getremotechanges` into three pieces Split the function in three and introduce an explicit class that tracks state. The goal here is to improve error handling, in particular so it's easy to wrap things in a try/with without having to grow an already-significant indentation level. The code can probably be cleaned up further, but I don't want to bite off a piece too large to chew.
author Arseniy Alekseyev <aalekseyev@janestreet.com>
date Mon, 03 Feb 2025 18:26:26 +0000
parents 5cc8deb96b48
children b25208655467
line wrap: on
line diff
--- a/mercurial/bundlerepo.py	Tue Jan 07 12:51:52 2025 +0100
+++ b/mercurial/bundlerepo.py	Mon Feb 03 18:26:26 2025 +0000
@@ -617,6 +617,36 @@
         raise NotImplementedError
 
 
+class getremotechanges_state_tracker:
+    def __init__(self, peer, incoming, common, rheads):
+        # bundle file to be deleted
+        self.bundle = None
+        # bundle repo to be closed
+        self.bundlerepo = None
+        # remote peer connection to be closed
+        self.peer = peer
+        # if peer is remote, `localrepo` will be equal to
+        # `bundlerepo` when bundle is created.
+        self.localrepo = peer.local()
+
+        # `incoming` operation parameters:
+        # (these get mutated by _create_bundle)
+        self.incoming = incoming
+        self.common = common
+        self.rheads = rheads
+
+    def cleanup(self):
+        try:
+            if self.bundlerepo:
+                self.bundlerepo.close()
+        finally:
+            try:
+                if self.bundle:
+                    os.unlink(self.bundle)
+            finally:
+                self.peer.close()
+
+
 def getremotechanges(
     ui, repo, peer, onlyheads=None, bundlename=None, force=False
 ):
@@ -652,10 +682,16 @@
     commonset = set(common)
     rheads = [x for x in rheads if x not in commonset]
 
-    bundle = None
-    bundlerepo = None
-    localrepo = peer.local()
-    if bundlename or not localrepo:
+    state = getremotechanges_state_tracker(peer, incoming, common, rheads)
+    csets = _getremotechanges_slowpath(
+        state, ui, repo, bundlename=bundlename, onlyheads=onlyheads
+    )
+
+    return (state.localrepo, csets, state.cleanup)
+
+
+def _create_bundle(state, ui, repo, bundlename, onlyheads):
+    if True:
         # create a bundle (uncompressed if peer repo is not local)
 
         # developer config: devel.legacy.exchange
@@ -663,17 +699,17 @@
         forcebundle1 = b'bundle2' not in legexc and b'bundle1' in legexc
         canbundle2 = (
             not forcebundle1
-            and peer.capable(b'getbundle')
-            and peer.capable(b'bundle2')
+            and state.peer.capable(b'getbundle')
+            and state.peer.capable(b'bundle2')
         )
         if canbundle2:
-            with peer.commandexecutor() as e:
+            with state.peer.commandexecutor() as e:
                 b2 = e.callcommand(
                     b'getbundle',
                     {
                         b'source': b'incoming',
-                        b'common': common,
-                        b'heads': rheads,
+                        b'common': state.common,
+                        b'heads': state.rheads,
                         b'bundlecaps': exchange.caps20to10(
                             repo, role=b'client'
                         ),
@@ -681,72 +717,93 @@
                     },
                 ).result()
 
-                fname = bundle = changegroup.writechunks(
+                fname = state.bundle = changegroup.writechunks(
                     ui, b2._forwardchunks(), bundlename
                 )
         else:
-            if peer.capable(b'getbundle'):
-                with peer.commandexecutor() as e:
+            if state.peer.capable(b'getbundle'):
+                with state.peer.commandexecutor() as e:
                     cg = e.callcommand(
                         b'getbundle',
                         {
                             b'source': b'incoming',
-                            b'common': common,
-                            b'heads': rheads,
+                            b'common': state.common,
+                            b'heads': state.rheads,
                         },
                     ).result()
-            elif onlyheads is None and not peer.capable(b'changegroupsubset'):
+            elif onlyheads is None and not state.peer.capable(
+                b'changegroupsubset'
+            ):
                 # compat with older servers when pulling all remote heads
 
-                with peer.commandexecutor() as e:
+                with state.peer.commandexecutor() as e:
                     cg = e.callcommand(
                         b'changegroup',
                         {
-                            b'nodes': incoming,
+                            b'nodes': state.incoming,
+                            b'source': b'incoming',
+                        },
+                    ).result()
+
+                state.rheads = None
+            else:
+                with state.peer.commandexecutor() as e:
+                    cg = e.callcommand(
+                        b'changegroupsubset',
+                        {
+                            b'bases': state.incoming,
+                            b'heads': state.rheads,
                             b'source': b'incoming',
                         },
                     ).result()
 
-                rheads = None
-            else:
-                with peer.commandexecutor() as e:
-                    cg = e.callcommand(
-                        b'changegroupsubset',
-                        {
-                            b'bases': incoming,
-                            b'heads': rheads,
-                            b'source': b'incoming',
-                        },
-                    ).result()
-
-            if localrepo:
+            if state.localrepo:
                 bundletype = b"HG10BZ"
             else:
                 bundletype = b"HG10UN"
-            fname = bundle = bundle2.writebundle(ui, cg, bundlename, bundletype)
+            fname = state.bundle = bundle2.writebundle(
+                ui, cg, bundlename, bundletype
+            )
         # keep written bundle?
         if bundlename:
-            bundle = None
-        if not localrepo:
+            state.bundle = None
+
+        return fname
+
+
+def _getremotechanges_slowpath(
+    state, ui, repo, bundlename=None, onlyheads=None
+):
+    if bundlename or not state.localrepo:
+        fname = _create_bundle(
+            state,
+            ui,
+            repo,
+            bundlename=bundlename,
+            onlyheads=onlyheads,
+        )
+        if not state.localrepo:
             # use the created uncompressed bundlerepo
-            localrepo = bundlerepo = makebundlerepository(
+            state.localrepo = state.bundlerepo = makebundlerepository(
                 repo.baseui, repo.root, fname
             )
 
             # this repo contains local and peer now, so filter out local again
-            common = repo.heads()
-    if localrepo:
+            state.common = repo.heads()
+
+    if state.localrepo:
         # Part of common may be remotely filtered
         # So use an unfiltered version
         # The discovery process probably need cleanup to avoid that
-        localrepo = localrepo.unfiltered()
+        state.localrepo = state.localrepo.unfiltered()
 
-    csets = localrepo.changelog.findmissing(common, rheads)
+    csets = state.localrepo.changelog.findmissing(state.common, state.rheads)
 
-    if bundlerepo:
+    if state.bundlerepo:
+        bundlerepo = state.bundlerepo
         reponodes = [ctx.node() for ctx in bundlerepo[bundlerepo.firstnewrev :]]
 
-        with peer.commandexecutor() as e:
+        with state.peer.commandexecutor() as e:
             remotephases = e.callcommand(
                 b'listkeys',
                 {
@@ -755,16 +812,9 @@
             ).result()
 
         pullop = exchange.pulloperation(
-            bundlerepo, peer, path=None, heads=reponodes
+            bundlerepo, state.peer, path=None, heads=reponodes
         )
         pullop.trmanager = bundletransactionmanager()
         exchange._pullapplyphases(pullop, remotephases)
 
-    def cleanup():
-        if bundlerepo:
-            bundlerepo.close()
-        if bundle:
-            os.unlink(bundle)
-        peer.close()
-
-    return (localrepo, csets, cleanup)
+    return csets