--- a/mercurial/state.py Fri Jul 17 20:24:42 2020 +0200
+++ b/mercurial/state.py Tue Jul 14 13:36:57 2020 -0700
@@ -19,6 +19,8 @@
from __future__ import absolute_import
+import contextlib
+
from .i18n import _
from . import (
@@ -119,6 +121,7 @@
reportonly,
continueflag,
stopflag,
+ childopnames,
cmdmsg,
cmdhint,
statushint,
@@ -132,6 +135,8 @@
self._reportonly = reportonly
self._continueflag = continueflag
self._stopflag = stopflag
+ self._childopnames = childopnames
+ self._delegating = False
self._cmdmsg = cmdmsg
self._cmdhint = cmdhint
self._statushint = statushint
@@ -181,12 +186,15 @@
"""
if self._opname == b'merge':
return len(repo[None].parents()) > 1
+ elif self._delegating:
+ return False
else:
return repo.vfs.exists(self._fname)
# A list of statecheck objects for multistep operations like graft.
_unfinishedstates = []
+_unfinishedstatesbyname = {}
def addunfinished(
@@ -197,6 +205,7 @@
reportonly=False,
continueflag=False,
stopflag=False,
+ childopnames=None,
cmdmsg=b"",
cmdhint=b"",
statushint=b"",
@@ -218,6 +227,8 @@
`--continue` option or not.
stopflag is a boolean that determines whether or not a command supports
--stop flag
+ childopnames is a list of other opnames this op uses as sub-steps of its
+ own execution. They must already be added.
cmdmsg is used to pass a different status message in case standard
message of the format "abort: cmdname in progress" is not desired.
cmdhint is used to pass a different hint message in case standard
@@ -230,6 +241,7 @@
continuefunc stores the function required to finish an interrupted
operation.
"""
+ childopnames = childopnames or []
statecheckobj = _statecheck(
opname,
fname,
@@ -238,17 +250,98 @@
reportonly,
continueflag,
stopflag,
+ childopnames,
cmdmsg,
cmdhint,
statushint,
abortfunc,
continuefunc,
)
+
if opname == b'merge':
_unfinishedstates.append(statecheckobj)
else:
+ # This check enforces that for any op 'foo' which depends on op 'bar',
+ # 'foo' comes before 'bar' in _unfinishedstates. This ensures that
+ # getrepostate() always returns the most specific applicable answer.
+ for childopname in childopnames:
+ if childopname not in _unfinishedstatesbyname:
+ raise error.ProgrammingError(
+ _(b'op %s depends on unknown op %s') % (opname, childopname)
+ )
+
_unfinishedstates.insert(0, statecheckobj)
+ if opname in _unfinishedstatesbyname:
+ raise error.ProgrammingError(_(b'op %s registered twice') % opname)
+ _unfinishedstatesbyname[opname] = statecheckobj
+
+
+def _getparentandchild(opname, childopname):
+ p = _unfinishedstatesbyname.get(opname, None)
+ if not p:
+ raise error.ProgrammingError(_(b'unknown op %s') % opname)
+ if childopname not in p._childopnames:
+ raise error.ProgrammingError(
+ _(b'op %s does not delegate to %s') % (opname, childopname)
+ )
+ c = _unfinishedstatesbyname[childopname]
+ return p, c
+
+
+@contextlib.contextmanager
+def delegating(repo, opname, childopname):
+ """context wrapper for delegations from opname to childopname.
+
+ requires that childopname was specified when opname was registered.
+
+ Usage:
+ def my_command_foo_that_uses_rebase(...):
+ ...
+ with state.delegating(repo, 'foo', 'rebase'):
+ _run_rebase(...)
+ ...
+ """
+
+ p, c = _getparentandchild(opname, childopname)
+ if p._delegating:
+ raise error.ProgrammingError(
+ _(b'cannot delegate from op %s recursively') % opname
+ )
+ p._delegating = True
+ try:
+ yield
+ except error.ConflictResolutionRequired as e:
+ # Rewrite conflict resolution advice for the parent opname.
+ if e.opname == childopname:
+ raise error.ConflictResolutionRequired(opname)
+ raise e
+ finally:
+ p._delegating = False
+
+
+def ischildunfinished(repo, opname, childopname):
+ """Returns true if both opname and childopname are unfinished."""
+
+ p, c = _getparentandchild(opname, childopname)
+ return (p._delegating or p.isunfinished(repo)) and c.isunfinished(repo)
+
+
+def continuechild(ui, repo, opname, childopname):
+ """Checks that childopname is in progress, and continues it."""
+
+ p, c = _getparentandchild(opname, childopname)
+ if not ischildunfinished(repo, opname, childopname):
+ raise error.ProgrammingError(
+ _(b'child op %s of parent %s is not unfinished')
+ % (childopname, opname)
+ )
+ if not c.continuefunc:
+ raise error.ProgrammingError(
+ _(b'op %s has no continue function') % childopname
+ )
+ return c.continuefunc(ui, repo)
+
addunfinished(
b'update',