rust/hg-pyo3/src/revlog/mod.rs
changeset 52799 798355e46d03
parent 52798 88d62995a65b
child 52800 ebcbd2b7a3b6
--- a/rust/hg-pyo3/src/revlog/mod.rs	Mon Dec 23 20:06:56 2024 +0100
+++ b/rust/hg-pyo3/src/revlog/mod.rs	Mon Dec 23 20:44:26 2024 +0100
@@ -11,7 +11,7 @@
 use pyo3::conversion::IntoPyObject;
 use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
-use pyo3::types::{PyBytes, PyBytesMethods, PyList, PyTuple};
+use pyo3::types::{PyBytes, PyBytesMethods, PyDict, PyList, PySet, PyTuple};
 use pyo3_sharedref::{PyShareable, SharedByPyObject};
 
 use std::collections::HashSet;
@@ -21,12 +21,13 @@
 };
 
 use hg::{
+    errors::HgError,
     revlog::{
-        index::{Index, RevisionDataParams},
+        index::{Index, RevisionDataParams, SnapshotsCache, INDEX_ENTRY_SIZE},
         inner_revlog::InnerRevlog as CoreInnerRevlog,
         nodemap::{NodeMap, NodeMapError, NodeTree as CoreNodeTree},
         options::RevlogOpenOptions,
-        RevlogIndex, RevlogType,
+        RevlogError, RevlogIndex, RevlogType,
     },
     utils::files::get_path_from_bytes,
     vfs::FnCacheVfs,
@@ -87,6 +88,38 @@
     }
 }
 
+struct PySnapshotsCache<'a, 'py: 'a>(&'a Bound<'py, PyDict>);
+
+impl<'a, 'py> PySnapshotsCache<'a, 'py> {
+    fn insert_for_with_py_result(
+        &self,
+        rev: BaseRevision,
+        value: BaseRevision,
+    ) -> PyResult<()> {
+        match self.0.get_item(rev)? {
+            Some(obj) => obj.downcast::<PySet>()?.add(value),
+            None => {
+                let set = PySet::new(self.0.py(), vec![value])?;
+                self.0.set_item(rev, set)
+            }
+        }
+    }
+}
+
+impl<'a, 'py> SnapshotsCache for PySnapshotsCache<'a, 'py> {
+    fn insert_for(
+        &mut self,
+        rev: BaseRevision,
+        value: BaseRevision,
+    ) -> Result<(), RevlogError> {
+        self.insert_for_with_py_result(rev, value).map_err(|_| {
+            RevlogError::Other(HgError::unsupported(
+                "Error in Python caches handling",
+            ))
+        })
+    }
+}
+
 #[pyclass]
 #[allow(dead_code)]
 struct InnerRevlog {
@@ -436,12 +469,27 @@
         let rev = Self::with_index_read(slf, |idx| {
             idx.check_revision(rev).ok_or_else(|| rev_not_in_index(rev))
         })?;
-        Self::with_core_read(slf, |irl| {
+        Self::with_core_read(slf, |_self_ref, irl| {
             irl.is_snapshot(rev)
                 .map_err(|e| PyValueError::new_err(e.to_string()))
         })
     }
 
+    /// Gather snapshot data in a cache dict
+    fn _index_findsnapshots(
+        slf: &Bound<'_, Self>,
+        cache: &Bound<'_, PyDict>,
+        start_rev: PyRevision,
+        end_rev: PyRevision,
+    ) -> PyResult<()> {
+        let mut cache = PySnapshotsCache(cache);
+        Self::with_index_read(slf, |idx| {
+            idx.find_snapshots(start_rev.into(), end_rev.into(), &mut cache)
+                .map_err(|_| revlog_error_bare())
+        })?;
+        Ok(())
+    }
+
     fn _index___len__(slf: &Bound<'_, Self>) -> PyResult<usize> {
         Self::with_index_read(slf, |idx| Ok(idx.len()))
     }