Mercurial > public > mercurial-scm > hg
changeset 52535:507fec66014f
rust-pyo3: MissingAncestors
author | Georges Racinet <georges.racinet@cloudcrane.io> |
---|---|
date | Sat, 07 Dec 2024 16:43:30 +0100 |
parents | 9af0330788a5 |
children | 98dcbe752dfe |
files | rust/hg-pyo3/src/ancestors.rs tests/test-rust-ancestor.py |
diffstat | 2 files changed, 134 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/rust/hg-pyo3/src/ancestors.rs Sat Dec 07 14:55:42 2024 +0100 +++ b/rust/hg-pyo3/src/ancestors.rs Sat Dec 07 16:43:30 2024 +0100 @@ -10,9 +10,12 @@ //! and can be used as replacement for the the pure `ancestor` Python module. use cpython::UnsafePyLeaked; use pyo3::prelude::*; +use pyo3::types::PyTuple; +use std::collections::HashSet; use std::sync::RwLock; +use hg::MissingAncestors as CoreMissing; use vcsgraph::lazy_ancestors::{ AncestorsIterator as VCGAncestorsIterator, LazyAncestors as VCGLazyAncestors, @@ -153,6 +156,130 @@ } } +#[pyclass] +struct MissingAncestors { + inner: RwLock<UnsafePyLeaked<CoreMissing<PySharedIndex>>>, + proxy_index: PyObject, +} + +#[pymethods] +impl MissingAncestors { + #[new] + fn new( + index_proxy: &Bound<'_, PyAny>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<Self> { + let cloned_proxy = index_proxy.clone().unbind(); + let bases_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + let (py, leaked_idx) = proxy_index_py_leak(index_proxy)?; + + // Safety: we don't leak the "faked" reference out of + // `UnsafePyLeaked` + let inner = unsafe { + leaked_idx.map(py, |idx| CoreMissing::new(idx, bases_vec)) + }; + Ok(Self { + inner: inner.into(), + proxy_index: cloned_proxy, + }) + } + + fn hasbases(slf: PyRef<'_, Self>) -> PyResult<bool> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner.has_bases()) + } + + fn addbases( + slf: PyRefMut<'_, Self>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<()> { + let index_proxy = slf.proxy_index.bind(slf.py()); + let bases_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + inner.add_bases(bases_vec); + Ok(()) + } + + fn bases(slf: PyRef<'_, Self>) -> PyResult<HashSet<PyRevision>> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner.get_bases().iter().map(|r| PyRevision(r.0)).collect()) + } + + fn basesheads(slf: PyRef<'_, Self>) -> PyResult<HashSet<PyRevision>> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner + .bases_heads() + .map_err(GraphError::from_hg)? + .iter() + .map(|r| PyRevision(r.0)) + .collect()) + } + + fn removeancestorsfrom( + slf: PyRef<'_, Self>, + revs: &Bound<'_, PyAny>, + ) -> PyResult<()> { + // Original comment from hg-cpython: + // this is very lame: we convert to a Rust set, update it in place + // and then convert back to Python, only to have Python remove the + // excess (thankfully, Python is happy with a list or even an + // iterator) + // Leads to improve this: + // - have the CoreMissing instead do something emit revisions to + // discard + // - define a trait for sets of revisions in the core and implement + // it for a Python set rewrapped with the GIL marker + // PyO3 additional comment: the trait approach would probably be + // simpler because we can implement it without a Py wrappper, just + // on &Bound<'py, PySet> + let index_proxy = slf.proxy_index.bind(slf.py()); + let mut revs_set: HashSet<_> = + rev_pyiter_collect_with_py_index(revs, index_proxy)?; + + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + + inner + .remove_ancestors_from(&mut revs_set) + .map_err(GraphError::from_hg)?; + // convert as Python tuple and discard from original `revs` + let remaining_tuple = + PyTuple::new(slf.py(), revs_set.iter().map(|r| PyRevision(r.0)))?; + revs.call_method("intersection_update", (remaining_tuple,), None)?; + Ok(()) + } + + fn missingancestors( + slf: PyRefMut<'_, Self>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<Vec<PyRevision>> { + let index_proxy = slf.proxy_index.bind(slf.py()); + let revs_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + + let missing_vec = inner + .missing_ancestors(revs_vec) + .map_err(GraphError::from_hg)?; + Ok(missing_vec.iter().map(|r| PyRevision(r.0)).collect()) + } +} + pub fn init_module<'py>( py: Python<'py>, package: &str, @@ -160,5 +287,6 @@ let m = new_submodule(py, package, "ancestor")?; m.add_class::<AncestorsIterator>()?; m.add_class::<LazyAncestors>()?; + m.add_class::<MissingAncestors>()?; Ok(m) }
--- a/tests/test-rust-ancestor.py Sat Dec 07 14:55:42 2024 +0100 +++ b/tests/test-rust-ancestor.py Sat Dec 07 16:43:30 2024 +0100 @@ -172,12 +172,6 @@ idx = self.parserustindex() self.assertEqual(dagop.headrevs(idx, [1, 2, 3]), {3}) - -class RustCPythonAncestorsTest( - revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin -): - rustext_pkg = rustext - def testmissingancestors(self): MissingAncestors = self.ancestors_mod().MissingAncestors @@ -200,6 +194,12 @@ self.assertEqual(revs, {2, 3}) +class RustCPythonAncestorsTest( + revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin +): + rustext_pkg = rustext + + class PyO3AncestorsTest( revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin ):