changeset 52805:acae91fad6be

rust-pyo3-revlog: standalone NodeTree class. This is the first actual usage of `PyShareable`, although it may not be strictly necessary in this case (we could just reference the `InnerRevlog` Python object, and we do not need to keep additional state).
author Georges Racinet <georges.racinet@cloudcrane.io>
date Sun, 22 Dec 2024 17:02:09 +0100
parents 0ac956db7ea7
children 6a70e4931773
files rust/hg-pyo3/src/exceptions.rs rust/hg-pyo3/src/revision.rs rust/hg-pyo3/src/revlog/index.rs rust/hg-pyo3/src/revlog/mod.rs tests/test-rust-revlog.py
diffstat 5 files changed, 166 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/rust/hg-pyo3/src/exceptions.rs	Sun Dec 22 21:37:29 2024 +0100
+++ b/rust/hg-pyo3/src/exceptions.rs	Sun Dec 22 17:02:09 2024 +0100
@@ -59,7 +59,6 @@
     mercurial_py_errors::RevlogError::new_err((None::<String>,))
 }
 
-#[allow(dead_code)]
 pub fn rev_not_in_index(rev: UncheckedRevision) -> PyErr {
     PyValueError::new_err(format!("revlog index out of range: {}", rev))
 }
--- a/rust/hg-pyo3/src/revision.rs	Sun Dec 22 21:37:29 2024 +0100
+++ b/rust/hg-pyo3/src/revision.rs	Sun Dec 22 17:02:09 2024 +0100
@@ -41,7 +41,6 @@
     }
 }
 
-#[allow(dead_code)]
 pub fn check_revision(
     index: &impl RevlogIndex,
     rev: impl Into<UncheckedRevision>,
--- a/rust/hg-pyo3/src/revlog/index.rs	Sun Dec 22 21:37:29 2024 +0100
+++ b/rust/hg-pyo3/src/revlog/index.rs	Sun Dec 22 17:02:09 2024 +0100
@@ -10,7 +10,38 @@
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyTuple};
 
-use hg::revlog::index::RevisionDataParams;
+use hg::revlog::{
+    index::{Index, RevisionDataParams},
+    Node, Revision, RevlogIndex,
+};
+
+#[derive(derive_more::From)]
+pub struct PySharedIndex {
+    /// The underlying hg-core index
+    inner: &'static Index,
+}
+
+impl PySharedIndex {
+    /// Return a reference to the inner index, bound by `self`
+    pub fn inner(&self) -> &Index {
+        self.inner
+    }
+
+    /// Return an unsafe "faked" `'static` reference to the inner index, for
+    /// the purposes of Python <-> Rust memory sharing.
+    pub unsafe fn static_inner(&self) -> &'static Index {
+        self.inner
+    }
+}
+
+impl RevlogIndex for PySharedIndex {
+    fn len(&self) -> usize {
+        self.inner.len()
+    }
+    fn node(&self, rev: Revision) -> Option<&Node> {
+        self.inner.node(rev)
+    }
+}
 
 pub fn py_tuple_to_revision_data_params(
     tuple: &Bound<'_, PyTuple>,
--- a/rust/hg-pyo3/src/revlog/mod.rs	Sun Dec 22 21:37:29 2024 +0100
+++ b/rust/hg-pyo3/src/revlog/mod.rs	Sun Dec 22 17:02:09 2024 +0100
@@ -12,7 +12,7 @@
 use pyo3::exceptions::PyIndexError;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyBytesMethods, PyList, PyTuple};
-use pyo3_sharedref::PyShareable;
+use pyo3_sharedref::{PyShareable, SharedByPyObject};
 
 use std::sync::{
     atomic::{AtomicUsize, Ordering},
@@ -34,11 +34,12 @@
 
 use crate::{
     exceptions::{
-        map_lock_error, map_try_lock_error, nodemap_error, revlog_error_bare,
+        map_lock_error, map_try_lock_error,
+        nodemap_error, rev_not_in_index, revlog_error_bare,
         revlog_error_from_msg,
     },
     node::{node_from_py_bytes, node_prefix_from_py_bytes, py_node_for_rev},
-    revision::PyRevision,
+    revision::{check_revision, PyRevision},
     store::PyFnCache,
     util::{new_submodule, take_buffer_with_slice},
 };
@@ -48,6 +49,7 @@
 mod index;
 use index::{
     py_tuple_to_revision_data_params, revision_data_params_to_py_tuple,
+    PySharedIndex,
 };
 
 #[pyclass]
@@ -468,11 +470,110 @@
     }
 }
 
+#[pyclass]
+struct NodeTree {
+    nt: RwLock<CoreNodeTree>,
+    index: SharedByPyObject<PySharedIndex>,
+}
+
+#[pymethods]
+impl NodeTree {
+    #[new]
+    // The share/mapping should be set apart to become the PyO3 homolog of
+    // `py_rust_index_to_graph`
+    fn new(index_proxy: &Bound<'_, PyAny>) -> PyResult<Self> {
+        let py_irl = index_proxy.getattr("inner")?;
+        let py_irl_ref = py_irl.downcast::<InnerRevlog>()?.borrow();
+        let shareable_irl = &py_irl_ref.irl;
+
+        // Safety: the owner is the actual one and we do not leak any
+        // internal reference.
+        let index = unsafe {
+            shareable_irl.share_map(&py_irl, |irl| (&irl.index).into())
+        };
+        let nt = CoreNodeTree::default(); // in-RAM, fully mutable
+
+        Ok(Self {
+            nt: nt.into(),
+            index,
+        })
+    }
+
+    /// Tell whether the NodeTree has been invalidated
+    ///
+    /// In case of mutation of the index, the given results are not
+    /// guaranteed to be correct, and in fact, the methods borrowing
+    /// the inner index would fail because of `PySharedRef` poisoning
+    /// (generation-based guard), same as iterating on a `dict` that has
+    /// been mutated in the meantime.
+    fn is_invalidated(&self, py: Python<'_>) -> PyResult<bool> {
+        // Safety: we don't leak any reference derived from self.index, as
+        // we only check errors
+        let result = unsafe { self.index.try_borrow(py) };
+        // two cases for result to be an error:
+        // - the index has previously been mutably borrowed
+        // - there is currently a mutable borrow
+        // in both cases this means that previous results related to
+        // the index can no longer be assumed to be valid.
+        Ok(result.is_err())
+    }
+
+    fn insert(&self, py: Python<'_>, rev: PyRevision) -> PyResult<()> {
+        // Safety: we don't leak any reference derived from self.index,
+        // as `nt.insert` does not store direct references
+        let idx = &*unsafe { self.index.try_borrow(py)? };
+
+        let rev = check_revision(idx, rev)?;
+        if rev == NULL_REVISION {
+            return Err(rev_not_in_index(rev.into()));
+        }
+
+        let entry = idx.inner().get_entry(rev).expect("entry should exist");
+        let mut nt = self.nt.write().map_err(map_lock_error)?;
+        nt.insert(idx, entry.hash(), rev).map_err(nodemap_error)
+    }
+
+    fn shortest(
+        &self,
+        py: Python<'_>,
+        node: &Bound<'_, PyBytes>,
+    ) -> PyResult<usize> {
+        let nt = self.nt.read().map_err(map_lock_error)?;
+        // Safety: we don't leak any reference derived from self.index
+        // as returned type is Copy
+        let idx = &*unsafe { self.index.try_borrow(py)? };
+        nt.unique_prefix_len_node(idx, &node_from_py_bytes(node)?)
+            .map_err(nodemap_error)?
+            .ok_or_else(revlog_error_bare)
+    }
+
+    /// Lookup by node hex prefix in the NodeTree, returning revision number.
+    ///
+    /// This is not part of the classical NodeTree API, but is good enough
+    /// for unit testing, as in `test-rust-revlog.py`.
+    fn prefix_rev_lookup(
+        &self,
+        py: Python<'_>,
+        node_prefix: &Bound<'_, PyBytes>,
+    ) -> PyResult<Option<PyRevision>> {
+        let prefix = node_prefix_from_py_bytes(node_prefix)?;
+        let nt = self.nt.read().map_err(map_lock_error)?;
+        // Safety: we don't leak any reference derived from self.index
+        // as returned type is Copy
+        let idx = &*unsafe { self.index.try_borrow(py)? };
+        Ok(nt
+            .find_bin(idx, prefix)
+            .map_err(nodemap_error)?
+            .map(|r| r.into()))
+    }
+}
+
 pub fn init_module<'py>(
     py: Python<'py>,
     package: &str,
 ) -> PyResult<Bound<'py, PyModule>> {
     let m = new_submodule(py, package, "revlog")?;
     m.add_class::<InnerRevlog>()?;
+    m.add_class::<NodeTree>()?;
     Ok(m)
 }
--- a/tests/test-rust-revlog.py	Sun Dec 22 21:37:29 2024 +0100
+++ b/tests/test-rust-revlog.py	Sun Dec 22 17:02:09 2024 +0100
@@ -126,6 +126,36 @@
         del idx[0::17]
         self.assertEqual(len(idx), 0)
 
+    def test_standalone_nodetree(self):
+        idx = self.parserustindex()
+        nt = self.nodetree(idx)
+        for i in range(4):
+            nt.insert(i)
+
+        # invalidation is upon mutation *of the index*
+        self.assertFalse(nt.is_invalidated())
+
+        bin_nodes = [entry[7] for entry in idx]
+        hex_nodes = [hex(n) for n in bin_nodes]
+
+        for i, node in enumerate(hex_nodes):
+            self.assertEqual(nt.prefix_rev_lookup(node), i)
+            self.assertEqual(nt.prefix_rev_lookup(node[:5]), i)
+
+        # all 4 revisions in idx (standard data set) have different
+        # first nybbles in their Node IDs,
+        # hence `nt.shortest()` should return 1 for them, except when
+        # the leading nybble is 0 (ambiguity with NULL_NODE)
+        for i, (bin_node, hex_node) in enumerate(zip(bin_nodes, hex_nodes)):
+            shortest = nt.shortest(bin_node)
+            expected = 2 if hex_node[0] == ord('0') else 1
+            self.assertEqual(shortest, expected)
+            self.assertEqual(nt.prefix_rev_lookup(hex_node[:shortest]), i)
+
+        # test invalidation (generation poisoning) detection
+        del idx[3]
+        self.assertTrue(nt.is_invalidated())
+
 
 # Conditional skipping done by the base class
 class RustInnerRevlogTest(
@@ -152,33 +182,6 @@
         # let's check bool for an empty one
         self.assertFalse(LazyAncestors(rustidx, [0], 0, False))
 
-    def test_standalone_nodetree(self):
-        idx = self.parserustindex()
-        nt = self.nodetree(idx)
-        for i in range(4):
-            nt.insert(i)
-
-        bin_nodes = [entry[7] for entry in idx]
-        hex_nodes = [hex(n) for n in bin_nodes]
-
-        for i, node in enumerate(hex_nodes):
-            self.assertEqual(nt.prefix_rev_lookup(node), i)
-            self.assertEqual(nt.prefix_rev_lookup(node[:5]), i)
-
-        # all 4 revisions in idx (standard data set) have different
-        # first nybbles in their Node IDs,
-        # hence `nt.shortest()` should return 1 for them, except when
-        # the leading nybble is 0 (ambiguity with NULL_NODE)
-        for i, (bin_node, hex_node) in enumerate(zip(bin_nodes, hex_nodes)):
-            shortest = nt.shortest(bin_node)
-            expected = 2 if hex_node[0] == ord('0') else 1
-            self.assertEqual(shortest, expected)
-            self.assertEqual(nt.prefix_rev_lookup(hex_node[:shortest]), i)
-
-        # test invalidation (generation poisoning) detection
-        del idx[3]
-        self.assertTrue(nt.is_invalidated())
-
 
 # Conditional skipping done by the base class
 class PyO3InnerRevlogTest(