changeset 52835:4f41a8acf350

rust-pyo3: add a `with_pybytes_buffer` util. This is very similar to the one we have in `hg-cpython`, and serves the same purpose. Explanations inline.
author Raphaël Gomès <rgomes@octobus.net>
date Fri, 03 Jan 2025 12:43:52 +0100
parents d90a78ca0bdd
children 9435a212a773
files rust/hg-pyo3/src/util.rs
diffstat 1 files changed, 100 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/rust/hg-pyo3/src/util.rs	Fri Jan 03 01:05:22 2025 +0100
+++ b/rust/hg-pyo3/src/util.rs	Fri Jan 03 12:43:52 2025 +0100
@@ -1,7 +1,10 @@
+use hg::errors::HgError;
+use hg::revlog::inner_revlog::RevisionBuffer;
 use pyo3::buffer::{Element, PyBuffer};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
-use pyo3::types::PyDict;
+use pyo3::types::{PyBytes, PyDict};
+
 /// Create the module, with `__package__` given from parent
 ///
 /// According to PyO3 documentation, which links to
@@ -77,3 +80,99 @@
 
     Ok((buf, Box::new(bytes)))
 }
+
+/// Allocates a Python `bytes` object of `len` bytes and lets `init` fill it
+/// in place through a [`RevisionBuffer`], to save on a (potentially large)
+/// allocation and copy. Errors from the allocation or from `init` are
+/// returned as `Err`; writing fewer than `len` bytes panics in `finish`.
+pub fn with_pybytes_buffer<F>(
+    py: Python,
+    len: usize,
+    init: F,
+) -> Result<Py<PyBytes>, hg::revlog::RevlogError>
+where
+    F: FnOnce(
+        &mut dyn RevisionBuffer<Target = Py<PyBytes>>,
+    ) -> Result<(), hg::revlog::RevlogError>,
+{
+    // Largely inspired by code in PyO3
+    // https://pyo3.rs/main/doc/pyo3/types/struct.pybytes#method.new_bound_with
+    unsafe {
+        let pyptr = pyo3::ffi::PyBytes_FromStringAndSize(
+            std::ptr::null(),
+            len as pyo3::ffi::Py_ssize_t,
+        );
+        let pybytes = Bound::from_owned_ptr_or_err(py, pyptr)
+            .map_err(|e| HgError::abort_simple(e.to_string()))?
+            .downcast_into_unchecked();
+        let buffer: *mut u8 = pyo3::ffi::PyBytes_AsString(pyptr).cast();
+        debug_assert!(!buffer.is_null());
+        let mut rev_buf = PyRevisionBuffer::new(pybytes.unbind(), buffer, len);
+        // `init` fills the bytestring through `rev_buf`. If it returns an
+        // Err, the bytes object is dropped together with `rev_buf`.
+        init(&mut rev_buf).map(|_| rev_buf.finish())
+    }
+}
+
+/// Python `bytes` object plus a write cursor into its memory, so the revision
+/// contents are written in place, saving a large allocation + copy.
+struct PyRevisionBuffer {
+    py_bytes: Py<PyBytes>, // handed back to the caller by `finish`
+    _buf: *mut u8, // start of the buffer, as given to `new`
+    len: usize, // total number of bytes that may be written
+    current_buf: *mut u8, // where the next `extend_from_slice` writes
+    current_len: usize, // number of bytes written so far
+}
+
+impl PyRevisionBuffer {
+    /// # Safety
+    ///
+    /// `buf` must be the start of the allocated bytes of `bytes`, and `len`
+    /// exactly the length of said allocated bytes: writes rely on both.
+    #[inline]
+    unsafe fn new(bytes: Py<PyBytes>, buf: *mut u8, len: usize) -> Self {
+        Self {
+            py_bytes: bytes,
+            _buf: buf,
+            len,
+            current_len: 0,
+            current_buf: buf,
+        }
+    }
+
+    /// Number of bytes written so far. Will be different to the total
+    /// allocated length of the buffer unless the revision is done being
+    /// written (`finish` asserts that both are equal).
+    #[inline]
+    fn current_len(&self) -> usize {
+        self.current_len
+    }
+}
+
+impl RevisionBuffer for PyRevisionBuffer {
+    type Target = Py<PyBytes>;
+
+    #[inline]
+    fn extend_from_slice(&mut self, slice: &[u8]) {
+        assert!(self.current_len + slice.len() <= self.len);
+        unsafe {
+            // SAFETY: the assert above keeps this write inside the `len`
+            // bytes promised to `new`. `copy_from` (overlap-tolerant) is used
+            // since `slice` may alias this memory via [`PyBytesDeref`].
+            self.current_buf.copy_from(slice.as_ptr(), slice.len());
+            self.current_buf = self.current_buf.add(slice.len());
+        }
+        self.current_len += slice.len()
+    }
+
+    #[inline]
+    fn finish(self) -> Self::Target {
+        // panic on unwritten (unzeroed) bytes before they become undefined behavior
+        assert_eq!(
+            self.current_len(),
+            self.len,
+            "not enough bytes read for revision"
+        );
+        self.py_bytes
+    }
+}