mercurial/revlog.py
changeset 52163 7346f93be7a4
parent 52162 13815c9decd4
child 52172 72bc29f01570
--- a/mercurial/revlog.py	Wed Jun 19 17:03:13 2024 +0200
+++ b/mercurial/revlog.py	Wed Jun 19 19:10:49 2024 +0200
@@ -17,7 +17,6 @@
 import binascii
 import collections
 import contextlib
-import functools
 import io
 import os
 import struct
@@ -83,6 +82,7 @@
 if typing.TYPE_CHECKING:
     # noinspection PyPackageRequirements
     import attr
+    from .pure.parsers import BaseIndexObject
 
 from . import (
     ancestor,
@@ -381,7 +381,7 @@
         default_compression_header,
     ):
         self.opener = opener
-        self.index = index
+        self.index: BaseIndexObject = index
 
         self.index_file = index_file
         self.data_file = data_file
@@ -528,7 +528,9 @@
         generaldelta = self.delta_config.general_delta
         # Try C implementation.
         try:
-            return self.index.deltachain(rev, stoprev, generaldelta)
+            return self.index.deltachain(
+                rev, stoprev, generaldelta
+            )  # pytype: disable=attribute-error
         except AttributeError:
             pass
 
@@ -1246,6 +1248,71 @@
         return self.canonical_index_file
 
 
+if typing.TYPE_CHECKING:
+    # Tell Pytype what kind of object we expect
+    ProxyBase = BaseIndexObject
+else:
+    ProxyBase = object
+
+
+class RustIndexProxy(ProxyBase):
+    """Wrapper around the Rust index to fake having direct access to the index.
+
+    Rust enforces xor mutability (one mutable reference XOR 1..n non-mutable),
+    so we can't expose the index from Rust directly, since the `InnerRevlog`
+    already has ownership of the index. This object redirects all calls to the
+    index through the Rust-backed `InnerRevlog` glue which defines all
+    necessary forwarding methods.
+    """
+
+    def __init__(self, inner):
+        # Do not rename as it's being used to access the index from Rust
+        self.inner = inner
+
+    # TODO possibly write all index methods manually to save on overhead?
+    def __getattr__(self, name):
+        return getattr(self.inner, f"_index_{name}")
+
+    # Magic methods need to be defined explicitely
+    def __len__(self):
+        return self.inner._index___len__()
+
+    def __getitem__(self, key):
+        return self.inner._index___getitem__(key)
+
+    def __contains__(self, key):
+        return self.inner._index___contains__(key)
+
+    def __delitem__(self, key):
+        return self.inner._index___delitem__(key)
+
+
+class RustVFSWrapper:
+    """Used to wrap a Python VFS to pass it to Rust to lower the overhead of
+    calling back multiple times into Python.
+    """
+
+    def __init__(self, inner):
+        self.inner = inner
+
+    def __call__(
+        self,
+        path: bytes,
+        mode: bytes = b"rb",
+        atomictemp=False,
+        checkambig=False,
+    ):
+        fd = self.inner.__call__(
+            path=path, mode=mode, atomictemp=atomictemp, checkambig=checkambig
+        )
+        # Information that Rust needs to get ownership of the file that's
+        # being opened.
+        return (os.dup(fd.fileno()), fd._tempname if atomictemp else None)
+
+    def __getattr__(self, name):
+        return getattr(self.inner, name)
+
+
 class revlog:
     """
     the underlying revision storage object
@@ -1358,6 +1425,7 @@
         self._trypending = trypending
         self._try_split = try_split
         self._may_inline = may_inline
+        self.uses_rust = False
         self.opener = opener
         if persistentnodemap:
             self._nodemap_file = nodemaputil.get_nodemap_file(self)
@@ -1392,7 +1460,7 @@
         # Maps rev to chain base rev.
         self._chainbasecache = util.lrucachedict(100)
 
-        self.index = None
+        self.index: Optional[BaseIndexObject] = None
         self._docket = None
         self._nodemap_docket = None
         # Mapping of partial identifiers to full nodes.
@@ -1406,8 +1474,8 @@
         # prevent nesting of addgroup
         self._adding_group = None
 
-        chunk_cache = self._loadindex()
-        self._load_inner(chunk_cache)
+        index, chunk_cache = self._loadindex()
+        self._load_inner(index, chunk_cache)
         self._concurrencychecker = concurrencychecker
 
     def _init_opts(self):
@@ -1707,7 +1775,12 @@
         )
 
         use_rust_index = False
-        if rustrevlog is not None and self._nodemap_file is not None:
+        rust_applicable = self._nodemap_file is not None
+        rust_applicable = rust_applicable or self.target[0] == KIND_FILELOG
+        rust_applicable = rust_applicable and getattr(
+            self.opener, "rust_compatible", True
+        )
+        if rustrevlog is not None and rust_applicable:
             # we would like to use the rust_index in all case, especially
             # because it is necessary for AncestorsIterator and LazyAncestors
             # since the 6.7 cycle.
@@ -1717,6 +1790,9 @@
             # repository.
             use_rust_index = True
 
+            if self._format_version != REVLOGV1:
+                use_rust_index = False
+
         self._parse_index = parse_index_v1
         if self._format_version == REVLOGV0:
             self._parse_index = revlogv0.parse_index_v0
@@ -1726,58 +1802,84 @@
             self._parse_index = parse_index_cl_v2
         elif devel_nodemap:
             self._parse_index = parse_index_v1_nodemap
-        elif use_rust_index:
-            self._parse_index = functools.partial(
-                parse_index_v1_rust, default_header=new_header
-            )
-        try:
-            d = self._parse_index(index_data, self._inline)
-            index, chunkcache = d
-            use_nodemap = (
-                not self._inline
-                and self._nodemap_file is not None
-                and hasattr(index, 'update_nodemap_data')
-            )
-            if use_nodemap:
-                nodemap_data = nodemaputil.persisted_data(self)
-                if nodemap_data is not None:
-                    docket = nodemap_data[0]
-                    if (
-                        len(d[0]) > docket.tip_rev
-                        and d[0][docket.tip_rev][7] == docket.tip_node
-                    ):
-                        # no changelog tampering
-                        self._nodemap_docket = docket
-                        index.update_nodemap_data(*nodemap_data)
-        except (ValueError, IndexError):
-            raise error.RevlogError(
-                _(b"index %s is corrupted") % self.display_id
-            )
-        self.index = index
+
+        if use_rust_index:
+            # Let the Rust code parse its own index
+            index, chunkcache = (index_data, None)
+            self.uses_rust = True
+        else:
+            try:
+                d = self._parse_index(index_data, self._inline)
+                index, chunkcache = d
+                self._register_nodemap_info(index)
+            except (ValueError, IndexError):
+                raise error.RevlogError(
+                    _(b"index %s is corrupted") % self.display_id
+                )
         # revnum -> (chain-length, sum-delta-length)
         self._chaininfocache = util.lrucachedict(500)
 
-        return chunkcache
-
-    def _load_inner(self, chunk_cache):
+        return index, chunkcache
+
+    def _load_inner(self, index, chunk_cache):
         if self._docket is None:
             default_compression_header = None
         else:
             default_compression_header = self._docket.default_compression_header
 
-        self._inner = _InnerRevlog(
-            opener=self.opener,
-            index=self.index,
-            index_file=self._indexfile,
-            data_file=self._datafile,
-            sidedata_file=self._sidedatafile,
-            inline=self._inline,
-            data_config=self.data_config,
-            delta_config=self.delta_config,
-            feature_config=self.feature_config,
-            chunk_cache=chunk_cache,
-            default_compression_header=default_compression_header,
+        if self.uses_rust:
+            self._inner = rustrevlog.InnerRevlog(
+                opener=RustVFSWrapper(self.opener),
+                index_data=index,
+                index_file=self._indexfile,
+                data_file=self._datafile,
+                sidedata_file=self._sidedatafile,
+                inline=self._inline,
+                data_config=self.data_config,
+                delta_config=self.delta_config,
+                feature_config=self.feature_config,
+                chunk_cache=chunk_cache,
+                default_compression_header=default_compression_header,
+                revlog_type=self.target[0],
+            )
+            self.index = RustIndexProxy(self._inner)
+            self._register_nodemap_info(self.index)
+            self.uses_rust = True
+        else:
+            self._inner = _InnerRevlog(
+                opener=self.opener,
+                index=index,
+                index_file=self._indexfile,
+                data_file=self._datafile,
+                sidedata_file=self._sidedatafile,
+                inline=self._inline,
+                data_config=self.data_config,
+                delta_config=self.delta_config,
+                feature_config=self.feature_config,
+                chunk_cache=chunk_cache,
+                default_compression_header=default_compression_header,
+            )
+            self.index = self._inner.index
+
+    def _register_nodemap_info(self, index):
+        use_nodemap = (
+            not self._inline
+            and self._nodemap_file is not None
+            and hasattr(index, 'update_nodemap_data')
         )
+        if use_nodemap:
+            nodemap_data = nodemaputil.persisted_data(self)
+            if nodemap_data is not None:
+                docket = nodemap_data[0]
+                if (
+                    len(index) > docket.tip_rev
+                    and index[docket.tip_rev][7] == docket.tip_node
+                ):
+                    # no changelog tampering
+                    self._nodemap_docket = docket
+                    index.update_nodemap_data(
+                        *nodemap_data
+                    )  # pytype: disable=attribute-error
 
     def get_revlog(self):
         """simple function to mirror API of other not-really-revlog API"""
@@ -1869,7 +1971,9 @@
             nodemap_data = nodemaputil.persisted_data(self)
             if nodemap_data is not None:
                 self._nodemap_docket = nodemap_data[0]
-                self.index.update_nodemap_data(*nodemap_data)
+                self.index.update_nodemap_data(
+                    *nodemap_data
+                )  # pytype: disable=attribute-error
 
     def rev(self, node):
         """return the revision number associated with a <nodeid>"""
@@ -2368,23 +2472,26 @@
     def headrevs(self, revs=None, stop_rev=None):
         if revs is None:
             return self.index.headrevs(None, stop_rev)
-        assert stop_rev is None
         if rustdagop is not None and self.index.rust_ext_compat:
             return rustdagop.headrevs(self.index, revs)
         return dagop.headrevs(revs, self._uncheckedparentrevs)
 
     def headrevsdiff(self, start, stop):
         try:
-            return self.index.headrevsdiff(start, stop)
+            return self.index.headrevsdiff(
+                start, stop
+            )  # pytype: disable=attribute-error
         except AttributeError:
             return dagop.headrevsdiff(self._uncheckedparentrevs, start, stop)
 
     def computephases(self, roots):
-        return self.index.computephasesmapsets(roots)
+        return self.index.computephasesmapsets(
+            roots
+        )  # pytype: disable=attribute-error
 
     def _head_node_ids(self):
         try:
-            return self.index.head_node_ids()
+            return self.index.head_node_ids()  # pytype: disable=attribute-error
         except AttributeError:
             return [self.node(r) for r in self.headrevs()]
 
@@ -2442,7 +2549,9 @@
     def _commonancestorsheads(self, *revs):
         """calculate all the heads of the common ancestors of revs"""
         try:
-            ancs = self.index.commonancestorsheads(*revs)
+            ancs = self.index.commonancestorsheads(
+                *revs
+            )  # pytype: disable=attribute-error
         except (AttributeError, OverflowError):  # C implementation failed
             ancs = ancestor.commonancestorsheads(self.parentrevs, *revs)
         return ancs
@@ -2476,7 +2585,7 @@
         try:
             return self.index.reachableroots2(
                 minroot, heads, roots, includepath
-            )
+            )  # pytype: disable=attribute-error
         except AttributeError:
             return dagop._reachablerootspure(
                 self.parentrevs, minroot, roots, heads, includepath
@@ -2487,7 +2596,7 @@
 
         a, b = self.rev(a), self.rev(b)
         try:
-            ancs = self.index.ancestors(a, b)
+            ancs = self.index.ancestors(a, b)  # pytype: disable=attribute-error
         except (AttributeError, OverflowError):
             ancs = ancestor.ancestors(self.parentrevs, a, b)
         if ancs:
@@ -2534,7 +2643,9 @@
         maybewdir = self.nodeconstants.wdirhex.startswith(id)
         ambiguous = False
         try:
-            partial = self.index.partialmatch(id)
+            partial = self.index.partialmatch(
+                id
+            )  # pytype: disable=attribute-error
             if partial and self.hasnode(partial):
                 if maybewdir:
                     # single 'ff...' match in radix tree, ambiguous with wdir
@@ -2636,7 +2747,10 @@
 
         if not getattr(self, 'filteredrevs', None):
             try:
-                length = max(self.index.shortest(node), minlength)
+                shortest = self.index.shortest(
+                    node
+                )  # pytype: disable=attribute-error
+                length = max(shortest, minlength)
                 return disambiguate(hexnode, length)
             except error.RevlogError:
                 if node != self.nodeconstants.wdirid:
@@ -4089,7 +4203,9 @@
             ifh.seek(startrev * self.index.entry_size)
             for i, e in enumerate(new_entries):
                 rev = startrev + i
-                self.index.replace_sidedata_info(rev, *e)
+                self.index.replace_sidedata_info(
+                    rev, *e
+                )  # pytype: disable=attribute-error
                 packed = self.index.entry_binary(rev)
                 if rev == 0 and self._docket is None:
                     header = self._format_flags | self._format_version