diff -r c87bd1fe3da2 -r 5322e738be0f mercurial/state.py --- a/mercurial/state.py Fri Jul 17 20:24:42 2020 +0200 +++ b/mercurial/state.py Tue Jul 14 13:36:57 2020 -0700 @@ -19,6 +19,8 @@ from __future__ import absolute_import +import contextlib + from .i18n import _ from . import ( @@ -119,6 +121,7 @@ reportonly, continueflag, stopflag, + childopnames, cmdmsg, cmdhint, statushint, @@ -132,6 +135,8 @@ self._reportonly = reportonly self._continueflag = continueflag self._stopflag = stopflag + self._childopnames = childopnames + self._delegating = False self._cmdmsg = cmdmsg self._cmdhint = cmdhint self._statushint = statushint @@ -181,12 +186,15 @@ """ if self._opname == b'merge': return len(repo[None].parents()) > 1 + elif self._delegating: + return False else: return repo.vfs.exists(self._fname) # A list of statecheck objects for multistep operations like graft. _unfinishedstates = [] +_unfinishedstatesbyname = {} def addunfinished( @@ -197,6 +205,7 @@ reportonly=False, continueflag=False, stopflag=False, + childopnames=None, cmdmsg=b"", cmdhint=b"", statushint=b"", @@ -218,6 +227,8 @@ `--continue` option or not. stopflag is a boolean that determines whether or not a command supports --stop flag + childopnames is a list of other opnames this op uses as sub-steps of its + own execution. They must already be added. cmdmsg is used to pass a different status message in case standard message of the format "abort: cmdname in progress" is not desired. cmdhint is used to pass a different hint message in case standard @@ -230,6 +241,7 @@ continuefunc stores the function required to finish an interrupted operation. """ + childopnames = childopnames or [] statecheckobj = _statecheck( opname, fname, @@ -238,17 +250,98 @@ reportonly, continueflag, stopflag, + childopnames, cmdmsg, cmdhint, statushint, abortfunc, continuefunc, ) + if opname == b'merge': _unfinishedstates.append(statecheckobj) else: + # This check enforces that for any op 'foo' which depends on op 'bar', + # 'foo' comes before 'bar' in _unfinishedstates. This ensures that + # getrepostate() always returns the most specific applicable answer. + for childopname in childopnames: + if childopname not in _unfinishedstatesbyname: + raise error.ProgrammingError( + _(b'op %s depends on unknown op %s') % (opname, childopname) + ) + _unfinishedstates.insert(0, statecheckobj) + if opname in _unfinishedstatesbyname: + raise error.ProgrammingError(_(b'op %s registered twice') % opname) + _unfinishedstatesbyname[opname] = statecheckobj + + +def _getparentandchild(opname, childopname): + p = _unfinishedstatesbyname.get(opname, None) + if not p: + raise error.ProgrammingError(_(b'unknown op %s') % opname) + if childopname not in p._childopnames: + raise error.ProgrammingError( + _(b'op %s does not delegate to %s') % (opname, childopname) + ) + c = _unfinishedstatesbyname[childopname] + return p, c + + +@contextlib.contextmanager +def delegating(repo, opname, childopname): + """context wrapper for delegations from opname to childopname. + + requires that childopname was specified when opname was registered. + + Usage: + def my_command_foo_that_uses_rebase(...): + ... + with state.delegating(repo, 'foo', 'rebase'): + _run_rebase(...) + ... + """ + + p, c = _getparentandchild(opname, childopname) + if p._delegating: + raise error.ProgrammingError( + _(b'cannot delegate from op %s recursively') % opname + ) + p._delegating = True + try: + yield + except error.ConflictResolutionRequired as e: + # Rewrite conflict resolution advice for the parent opname. + if e.opname == childopname: + raise error.ConflictResolutionRequired(opname) + raise e + finally: + p._delegating = False + + +def ischildunfinished(repo, opname, childopname): + """Returns true if both opname and childopname are unfinished.""" + + p, c = _getparentandchild(opname, childopname) + return (p._delegating or p.isunfinished(repo)) and c.isunfinished(repo) + + +def continuechild(ui, repo, opname, childopname): + """Checks that childopname is in progress, and continues it.""" + + p, c = _getparentandchild(opname, childopname) + if not ischildunfinished(repo, opname, childopname): + raise error.ProgrammingError( + _(b'child op %s of parent %s is not unfinished') + % (childopname, opname) + ) + if not c.continuefunc: + raise error.ProgrammingError( + _(b'op %s has no continue function') % childopname + ) + return c.continuefunc(ui, repo) + addunfinished( b'update',