mercurial/cmd_impls/graft.py
changeset 52329 5ab77b93567c
parent 52328 f2fc0a91faca
child 52330 8572e80f978c
--- a/mercurial/cmd_impls/graft.py	Tue Nov 19 15:46:12 2024 +0100
+++ b/mercurial/cmd_impls/graft.py	Tue Nov 19 15:54:20 2024 +0100
@@ -7,6 +7,27 @@
 
 def cmd_graft(ui, repo, *revs, **opts):
     """implement the graft command as defined in mercuria/commands.py"""
+    ret = _process_args(ui, repo, *revs, **opts)
+    if ret is None:
+        return -1
+    action, args = ret
+    if action == "ABORT":
+        return cmdutil.abortgraft(ui, repo, *args)
+    elif action == "STOP":
+        return _stopgraft(ui, repo, *args)
+    elif action == "GRAFT":
+        return _graft_revisions(ui, repo, *args)
+    else:
+        raise error.ProgrammingError(b'unknown action: %s' % action)
+    return 0
+
+
+def _process_args(ui, repo, *revs, **opts):
+    """process the graft command argument to figure out what to do
+
+    This also filter the selected revision to skip the one that cannot be graft
+    or were alredy grafted.
+    """
     if revs and opts.get('rev'):
         ui.warn(
             _(
@@ -52,7 +73,7 @@
                 'rev',
             ],
         )
-        return _stopgraft(ui, repo, graftstate)
+        return "STOP", [graftstate]
     elif opts.get('abort'):
         cmdutil.check_incompatible_arguments(
             opts,
@@ -67,7 +88,7 @@
                 'rev',
             ],
         )
-        return cmdutil.abortgraft(ui, repo, graftstate)
+        return "ABORT", [graftstate]
     elif opts.get('continue'):
         cont = True
         if revs:
@@ -89,9 +110,9 @@
             revs = [repo[node].rev() for node in nodes]
         else:
             cmdutil.wrongtooltocontinue(repo, _(b'graft'))
+    elif not revs:
+        raise error.InputError(_(b'no revisions specified'))
     else:
-        if not revs:
-            raise error.InputError(_(b'no revisions specified'))
         cmdutil.checkunfinished(repo)
         cmdutil.bailifchanged(repo)
         revs = logcmdutil.revrange(repo, revs)
@@ -107,7 +128,7 @@
             skipped.add(rev)
     revs = [r for r in revs if r not in skipped]
     if not revs:
-        return -1
+        return None
     if basectx is not None and len(revs) != 1:
         raise error.InputError(_(b'only one revision allowed with --base '))
 
@@ -126,7 +147,7 @@
         revs = [r for r in revs if r not in ancestors]
 
         if not revs:
-            return -1
+            return None
 
         # analyze revs for earlier grafts
         ids = {}
@@ -188,12 +209,28 @@
                     )
                     revs.remove(r)
         if not revs:
-            return -1
+            return None
 
     if opts.get('no_commit'):
         statedata[b'no_commit'] = True
     if opts.get('base'):
         statedata[b'base'] = opts['base']
+
+    return "GRAFT", [graftstate, statedata, revs, editor, basectx, cont, opts]
+
+
+def _graft_revisions(
+    ui,
+    repo,
+    graftstate,
+    statedata,
+    revs,
+    editor,
+    basectx,
+    cont=False,
+    opts,
+):
+    """actually graft some revisions"""
     for pos, ctx in enumerate(repo.set(b"%ld", revs)):
         desc = b'%d:%s "%s"' % (
             ctx.rev(),