mercurial/bundlerepo.py
changeset 52844 dabec69bd6fc
parent 52643 5cc8deb96b48
child 52845 b25208655467
--- a/mercurial/bundlerepo.py	Tue Jan 07 12:51:52 2025 +0100
+++ b/mercurial/bundlerepo.py	Mon Feb 03 18:26:26 2025 +0000
@@ -617,6 +617,36 @@
         raise NotImplementedError
 
 
+class getremotechanges_state_tracker:
+    def __init__(self, peer, incoming, common, rheads):
+        # bundle file to be deleted
+        self.bundle = None
+        # bundle repo to be closed
+        self.bundlerepo = None
+        # remote peer connection to be closed
+        self.peer = peer
+        # if peer is remote, `localrepo` will be equal to
+        # `bundlerepo` when bundle is created.
+        self.localrepo = peer.local()
+
+        # `incoming` operation parameters:
+        # (these get mutated by _create_bundle)
+        self.incoming = incoming
+        self.common = common
+        self.rheads = rheads
+
+    def cleanup(self):
+        try:
+            if self.bundlerepo:
+                self.bundlerepo.close()
+        finally:
+            try:
+                if self.bundle:
+                    os.unlink(self.bundle)
+            finally:
+                self.peer.close()
+
+
 def getremotechanges(
     ui, repo, peer, onlyheads=None, bundlename=None, force=False
 ):
@@ -652,10 +682,16 @@
     commonset = set(common)
     rheads = [x for x in rheads if x not in commonset]
 
-    bundle = None
-    bundlerepo = None
-    localrepo = peer.local()
-    if bundlename or not localrepo:
+    state = getremotechanges_state_tracker(peer, incoming, common, rheads)
+    csets = _getremotechanges_slowpath(
+        state, ui, repo, bundlename=bundlename, onlyheads=onlyheads
+    )
+
+    return (state.localrepo, csets, state.cleanup)
+
+
+def _create_bundle(state, ui, repo, bundlename, onlyheads):
+    if True:
         # create a bundle (uncompressed if peer repo is not local)
 
         # developer config: devel.legacy.exchange
@@ -663,17 +699,17 @@
         forcebundle1 = b'bundle2' not in legexc and b'bundle1' in legexc
         canbundle2 = (
             not forcebundle1
-            and peer.capable(b'getbundle')
-            and peer.capable(b'bundle2')
+            and state.peer.capable(b'getbundle')
+            and state.peer.capable(b'bundle2')
         )
         if canbundle2:
-            with peer.commandexecutor() as e:
+            with state.peer.commandexecutor() as e:
                 b2 = e.callcommand(
                     b'getbundle',
                     {
                         b'source': b'incoming',
-                        b'common': common,
-                        b'heads': rheads,
+                        b'common': state.common,
+                        b'heads': state.rheads,
                         b'bundlecaps': exchange.caps20to10(
                             repo, role=b'client'
                         ),
@@ -681,72 +717,93 @@
                     },
                 ).result()
 
-                fname = bundle = changegroup.writechunks(
+                fname = state.bundle = changegroup.writechunks(
                     ui, b2._forwardchunks(), bundlename
                 )
         else:
-            if peer.capable(b'getbundle'):
-                with peer.commandexecutor() as e:
+            if state.peer.capable(b'getbundle'):
+                with state.peer.commandexecutor() as e:
                     cg = e.callcommand(
                         b'getbundle',
                         {
                             b'source': b'incoming',
-                            b'common': common,
-                            b'heads': rheads,
+                            b'common': state.common,
+                            b'heads': state.rheads,
                         },
                     ).result()
-            elif onlyheads is None and not peer.capable(b'changegroupsubset'):
+            elif onlyheads is None and not state.peer.capable(
+                b'changegroupsubset'
+            ):
                 # compat with older servers when pulling all remote heads
 
-                with peer.commandexecutor() as e:
+                with state.peer.commandexecutor() as e:
                     cg = e.callcommand(
                         b'changegroup',
                         {
-                            b'nodes': incoming,
+                            b'nodes': state.incoming,
+                            b'source': b'incoming',
+                        },
+                    ).result()
+
+                state.rheads = None
+            else:
+                with state.peer.commandexecutor() as e:
+                    cg = e.callcommand(
+                        b'changegroupsubset',
+                        {
+                            b'bases': state.incoming,
+                            b'heads': state.rheads,
                             b'source': b'incoming',
                         },
                     ).result()
 
-                rheads = None
-            else:
-                with peer.commandexecutor() as e:
-                    cg = e.callcommand(
-                        b'changegroupsubset',
-                        {
-                            b'bases': incoming,
-                            b'heads': rheads,
-                            b'source': b'incoming',
-                        },
-                    ).result()
-
-            if localrepo:
+            if state.localrepo:
                 bundletype = b"HG10BZ"
             else:
                 bundletype = b"HG10UN"
-            fname = bundle = bundle2.writebundle(ui, cg, bundlename, bundletype)
+            fname = state.bundle = bundle2.writebundle(
+                ui, cg, bundlename, bundletype
+            )
         # keep written bundle?
         if bundlename:
-            bundle = None
-        if not localrepo:
+            state.bundle = None
+
+        return fname
+
+
+def _getremotechanges_slowpath(
+    state, ui, repo, bundlename=None, onlyheads=None
+):
+    if bundlename or not state.localrepo:
+        fname = _create_bundle(
+            state,
+            ui,
+            repo,
+            bundlename=bundlename,
+            onlyheads=onlyheads,
+        )
+        if not state.localrepo:
             # use the created uncompressed bundlerepo
-            localrepo = bundlerepo = makebundlerepository(
+            state.localrepo = state.bundlerepo = makebundlerepository(
                 repo.baseui, repo.root, fname
             )
 
             # this repo contains local and peer now, so filter out local again
-            common = repo.heads()
-    if localrepo:
+            state.common = repo.heads()
+
+    if state.localrepo:
         # Part of common may be remotely filtered
         # So use an unfiltered version
         # The discovery process probably need cleanup to avoid that
-        localrepo = localrepo.unfiltered()
+        state.localrepo = state.localrepo.unfiltered()
 
-    csets = localrepo.changelog.findmissing(common, rheads)
+    csets = state.localrepo.changelog.findmissing(state.common, state.rheads)
 
-    if bundlerepo:
+    if state.bundlerepo:
+        bundlerepo = state.bundlerepo
         reponodes = [ctx.node() for ctx in bundlerepo[bundlerepo.firstnewrev :]]
 
-        with peer.commandexecutor() as e:
+        with state.peer.commandexecutor() as e:
             remotephases = e.callcommand(
                 b'listkeys',
                 {
@@ -755,16 +812,9 @@
             ).result()
 
         pullop = exchange.pulloperation(
-            bundlerepo, peer, path=None, heads=reponodes
+            bundlerepo, state.peer, path=None, heads=reponodes
         )
         pullop.trmanager = bundletransactionmanager()
         exchange._pullapplyphases(pullop, remotephases)
 
-    def cleanup():
-        if bundlerepo:
-            bundlerepo.close()
-        if bundle:
-            os.unlink(bundle)
-        peer.close()
-
-    return (localrepo, csets, cleanup)
+    return csets