--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/rust/pyo3-sharedref/src/lib.rs Sun Dec 15 16:32:24 2024 +0100
@@ -0,0 +1,607 @@
+// Copyright (c) 2019 Raphaël Gomès <rgomes@octobus.net>,
+// Yuya Nishihara <yuya@tcha.org>
+// 2024 Georges Racinet <georges.racinet@cloudcrane.io>
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to
+// deal in the Software without restriction, including without limitation the
+// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+// sell copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+// IN THE SOFTWARE.
+
+//! Utility to share Rust reference across Python objects.
+
+use pyo3::exceptions::PyRuntimeError;
+use pyo3::prelude::*;
+
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{
+ RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError, TryLockResult,
+};
+
+/// A mutable memory location shareable immutably across Python objects.
+///
+/// This data structure is meant to be used as a field in a Python class
+/// definition.
+/// It provides interior mutability in a way that allows it to be immutably
+/// referenced by other Python objects defined in Rust than its owner, in
+/// a more general form than references to the whole data.
+/// These immutable references are stored in the referencing Python objects as
+/// [`UnsafePyLeaked`] fields.
+///
+/// The primary use case is to implement a Python iterator over a Rust
+/// iterator: since a Python object cannot hold a lifetime-bound object,
+/// `Iter<'a, T>` cannot be a data field of the Python iterator object.
+/// While `&'a T` can be replaced with [`std::sync::Arc`], this is typically
+/// not suited for more complex objects that are created from such references
+/// and re-expose the lifetime on their types, such as iterators.
+/// The [`PySharedRef::leak_immutable()`] and [`UnsafePyLeaked::map()`] methods
+/// provide a way around this issue.
+///
+/// [`PySharedRefCell`] is [`Sync`]. It works internally with locks and
+/// a "generation" counter that keeps track of mutations.
+///
+/// [`PySharedRefCell`] is merely a data struct to be stored in its
+/// owner Python object.
+/// Any further operation will be performed through [`PySharedRef`], which is
+/// a lifetime-bound reference to the [`PySharedRefCell`].
+///
+/// # Example
+///
+/// ```
+/// use pyo3::prelude::*;
+/// use pyo3_sharedref::*;
+///
+/// use pyo3::ffi::c_str;
+/// use pyo3::types::PyDictMethods;
+/// use pyo3::types::{PyDict, PyTuple};
+/// use std::collections::{hash_set::Iter as IterHashSet, HashSet};
+/// use pyo3::exceptions::PyRuntimeError;
+/// use std::ffi::CStr;
+/// use std::vec::Vec;
+///
+/// #[pyclass(sequence)]
+/// struct Set {
+/// rust_set: PySharedRefCell<HashSet<i32>>,
+/// }
+///
+/// #[pymethods]
+/// impl Set {
+/// #[new]
+/// fn new(values: &Bound<'_, PyTuple>) -> PyResult<Self> {
+/// let as_vec = values.extract::<Vec<i32>>()?;
+/// let s: HashSet<_> = as_vec.iter().copied().collect();
+/// Ok(Self {
+/// rust_set: PySharedRefCell::new(s),
+/// })
+/// }
+///
+/// fn __iter__(slf: &Bound<'_, Self>) -> SetIterator {
+/// SetIterator::new(slf)
+/// }
+///
+/// fn add(slf: &Bound<'_, Self>, i: i32) -> PyResult<()> {
+/// let rust_set = &slf.borrow().rust_set;
+/// let shared_ref = unsafe { rust_set.borrow(slf) };
+/// let mut set_ref = shared_ref.borrow_mut();
+/// set_ref.insert(i);
+/// Ok(())
+/// }
+/// }
+///
+/// #[pyclass]
+/// struct SetIterator {
+/// rust_iter: UnsafePyLeaked<IterHashSet<'static, i32>>,
+/// }
+///
+/// #[pymethods]
+/// impl SetIterator {
+/// #[new]
+/// fn new(s: &Bound<'_, Set>) -> Self {
+/// let py = s.py();
+/// let rust_set = &s.borrow().rust_set;
+/// let shared_ref = unsafe { rust_set.borrow(s) };
+/// let leaked_set = shared_ref.leak_immutable();
+/// let iter = unsafe { leaked_set.map(py, |o| o.iter()) };
+/// Self {
+/// rust_iter: iter.into(),
+/// }
+/// }
+///
+/// fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+/// slf
+/// }
+///
+/// fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<Option<i32>> {
+/// let py = slf.py();
+/// let leaked = &mut slf.rust_iter;
+/// let mut inner = unsafe { leaked.try_borrow_mut(py) }?;
+/// Ok(inner.next().copied())
+/// }
+/// }
+///
+/// /// a shortcut similar to `[pyo3::py_run!]`, allowing inspection of PyErr
+/// fn py_run(statement: &CStr, locals: &Bound<'_, PyDict>) -> PyResult<()> {
+/// locals.py().run(statement, None, Some(locals))
+/// }
+///
+/// # pyo3::prepare_freethreaded_python();
+/// Python::with_gil(|py| {
+/// let tuple = PyTuple::new(py, vec![2, 1, 2])?;
+/// let set = Bound::new(py, Set::new(&tuple)?)?;
+/// let iter = Bound::new(py, Set::__iter__(&set))?;
+/// let locals = PyDict::new(py);
+/// locals.set_item("rust_set", set).unwrap();
+/// locals.set_item("rust_iter", iter).unwrap();
+///
+/// /// iterating on our Rust set just works
+/// py_run(
+/// c_str!("assert sorted(i for i in rust_iter) == [1, 2]"),
+/// &locals,
+/// )?;
+///
+/// /// however, if any mutation occurs on the Rust set, the iterator
+/// /// becomes invalid. Attempts to use it raise `RuntimeError`.
+/// py_run(c_str!("rust_set.add(3)"), &locals)?;
+/// let err = py_run(c_str!("next(rust_iter)"), &locals).unwrap_err();
+///
+/// let exc_repr = format!("{:?}", err.value(py));
+/// assert_eq!(
+/// exc_repr,
+/// "RuntimeError('Cannot access to leaked reference after mutation')"
+/// );
+/// # Ok::<(), PyErr>(())
+/// })
+/// # .expect("This example should not return an error");
+/// ```
+///
+/// The borrow rules are enforced dynamically in a similar manner to the
+/// Python iterator.
+#[derive(Debug)]
+pub struct PySharedRefCell<T: ?Sized> {
+ state: PySharedState,
+ data: RwLock<T>,
+}
+
+impl<T> PySharedRefCell<T> {
+ /// Creates a new `PySharedRefCell` containing `value`.
+ pub fn new(value: T) -> PySharedRefCell<T> {
+ Self {
+ state: PySharedState::new(),
+ data: value.into(),
+ }
+ }
+
+ /// Borrows the shared data and its state, keeping a reference
+ /// on the owner Python object.
+ ///
+ /// # Safety
+ ///
+ /// The `data` must be owned by the `owner`. Otherwise, calling
+ /// `leak_immutable()` on the shared ref would create an invalid reference.
+ pub unsafe fn borrow<'py>(
+ &'py self,
+ owner: &'py Bound<'py, PyAny>,
+ ) -> PySharedRef<'py, T> {
+ PySharedRef {
+ owner,
+ state: &self.state,
+ data: &self.data,
+ }
+ }
+}
+
+/// Errors that can happen in `leak_immutable()`
+#[derive(Debug, PartialEq, Eq)]
+pub enum TryLeakError {
+ /// The inner lock is poisoned and we do not want to implement recovery
+ InnerLockPoisoned,
+ /// The inner lock would block and we are expecting to take it immediately
+ InnerLockWouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for TryLeakError {
+ fn from(e: TryLockError<T>) -> Self {
+ match e {
+ TryLockError::Poisoned(_) => Self::InnerLockPoisoned,
+ TryLockError::WouldBlock => Self::InnerLockWouldBlock,
+ }
+ }
+}
+
+/// A reference to [`PySharedRefCell`] owned by a Python object.
+///
+/// This is a lifetime-bound reference to the [`PySharedRefCell`] data field.
+pub struct PySharedRef<'py, T: 'py + ?Sized> {
+ owner: &'py Bound<'py, PyAny>,
+ state: &'py PySharedState,
+ data: &'py RwLock<T>, // TODO perhaps this needs Pin
+}
+
+impl<'py, T: ?Sized> PySharedRef<'py, T> {
+ /// Creates a reference to the given `PySharedRefCell` owned by the
+ /// given `PyObject`.
+ ///
+ /// # Safety
+ ///
+ /// The `data` must be owned by the `owner`. Otherwise, `leak_immutable()`
+ /// would create an invalid reference.
+ #[doc(hidden)]
+ pub unsafe fn new(
+ owner: &'py Bound<'py, PyAny>,
+ data: &'py PySharedRefCell<T>,
+ ) -> Self {
+ Self {
+ owner,
+ state: &data.state,
+ data: &data.data,
+ }
+ }
+
+ /// Immutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the value is currently mutably borrowed.
+ pub fn borrow(&self) -> RwLockReadGuard<'py, T> {
+ self.try_borrow().expect("already mutably borrowed")
+ }
+
+ /// Immutably borrows the wrapped value, returning an error if the value
+ /// is currently mutably borrowed.
+ pub fn try_borrow(&self) -> TryLockResult<RwLockReadGuard<'py, T>> {
+ // state isn't involved since
+ // - data.try_borrow() would fail if self is mutably borrowed,
+ // - and data.try_borrow_mut() would fail while self is borrowed.
+ self.data.try_read()
+ }
+
+ /// Mutably borrows the wrapped value.
+ ///
+ /// Any existing leaked references will be invalidated.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the value is currently borrowed.
+ pub fn borrow_mut(&self) -> RwLockWriteGuard<'py, T> {
+ self.try_borrow_mut().expect("already borrowed")
+ }
+
+ /// Mutably borrows the wrapped value, returning an error if the value
+ /// is currently borrowed.
+ pub fn try_borrow_mut(&self) -> TryLockResult<RwLockWriteGuard<'py, T>> {
+ // the value may be immutably borrowed through UnsafePyLeaked
+ if self.state.current_borrow_count(self.py()) > 0 {
+ // propagate borrow-by-leaked state to data to get BorrowMutError
+ let _dummy = self.data.read();
+ let _unused = self.data.try_write()?;
+ unreachable!("BorrowMutError should have been returned");
+ }
+
+ let data_ref = self.data.try_write()?;
+ self.state.increment_generation(self.py());
+ Ok(data_ref)
+ }
+
+ /// Creates an immutable reference which is not bound to lifetime.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the value is currently mutably borrowed.
+ pub fn leak_immutable(&self) -> UnsafePyLeaked<&'static T> {
+ self.try_leak_immutable().expect("already mutably borrowed")
+ }
+
+ /// Creates an immutable reference which is not bound to lifetime,
+ /// returning an error if the value is currently mutably borrowed.
+ pub fn try_leak_immutable(
+ &self,
+ ) -> Result<UnsafePyLeaked<&'static T>, TryLeakError> {
+ // make sure self.data isn't mutably borrowed; otherwise the
+ // generation number wouldn't be trusted.
+ let data_ref = self.try_borrow()?;
+
+ // keep reference to the owner so the data and state are alive,
+ // but the data pointer can be invalidated by borrow_mut().
+ // the state wouldn't since it is immutable.
+ let state_ptr: *const PySharedState = self.state;
+ let data_ptr: *const T = &*data_ref;
+ Ok(UnsafePyLeaked::<&'static T> {
+ owner: self.owner.clone().unbind(),
+ state: unsafe { &*state_ptr },
+ generation: self.state.current_generation(self.py()),
+ data: unsafe { &*data_ptr },
+ })
+ }
+
+ /// Retrieve the GIL handle
+ fn py(&self) -> Python<'py> {
+ // Since this is a smart pointer implying the GIL lifetime,
+ // we might as well use `assume_gil_acquired`, but the method
+ // of `Bound` does it for us.
+ self.owner.py()
+ }
+}
+
+/// The shared state between Python and Rust
+///
+/// `PySharedState` is owned by `PySharedRefCell`, and is shared across its
+/// derived references. The consistency of these references are guaranteed
+/// as follows:
+///
+/// - The immutability of `PycCass` object fields. Any mutation of
+/// [`PySharedRefCell`] is allowed only through its `borrow_mut()`.
+/// - The `py: Python<'_>` token, which makes sure that any data access is
+/// synchronized by the GIL.
+/// - The underlying `RefCell`, which prevents `PySharedRefCell` value from
+/// being directly borrowed or leaked while it is mutably borrowed.
+/// - The `borrow_count`, which is the number of references borrowed from
+/// `UnsafePyLeaked`. Just like `RefCell`, mutation is prohibited while
+/// `UnsafePyLeaked` is borrowed.
+/// - The `generation` counter, which increments on `borrow_mut()`.
+/// `UnsafePyLeaked` reference is valid only if the `current_generation()`
+/// equals to the `generation` at the time of `leak_immutable()`.
+#[derive(Debug)]
+struct PySharedState {
+ // The counter variable could be Cell<usize> since any operation on
+ // PySharedState is synchronized by the GIL, but being "atomic" makes
+ // PySharedState inherently Sync. The ordering requirement doesn't
+ // matter thanks to the GIL. That's why Ordering::Relaxed is used
+ // everywhere.
+ /// The number of immutable references borrowed through leaked reference.
+ borrow_count: AtomicUsize,
+ /// The mutation counter of the underlying value.
+ generation: AtomicUsize,
+}
+
+impl PySharedState {
+ const fn new() -> PySharedState {
+ PySharedState {
+ borrow_count: AtomicUsize::new(0),
+ generation: AtomicUsize::new(0),
+ }
+ }
+
+ fn current_borrow_count(&self, _py: Python) -> usize {
+ self.borrow_count.load(Ordering::Relaxed)
+ }
+
+ fn increase_borrow_count(&self, _py: Python) {
+ // this wraps around if there are more than usize::MAX borrowed
+ // references, which shouldn't happen due to memory limit.
+ self.borrow_count.fetch_add(1, Ordering::Relaxed);
+ }
+
+ fn decrease_borrow_count(&self, _py: Python) {
+ let prev_count = self.borrow_count.fetch_sub(1, Ordering::Relaxed);
+ assert!(prev_count > 0);
+ }
+
+ fn current_generation(&self, _py: Python) -> usize {
+ self.generation.load(Ordering::Relaxed)
+ }
+
+ fn increment_generation(&self, py: Python) {
+ assert_eq!(self.current_borrow_count(py), 0);
+ // this wraps around to the same value if mutably borrowed
+ // usize::MAX times, which wouldn't happen in practice.
+ self.generation.fetch_add(1, Ordering::Relaxed);
+ }
+}
+
+/// Helper to keep the borrow count updated while the shared object is
+/// immutably borrowed without using the `RefCell` interface.
+struct BorrowPyShared<'a> {
+ py: Python<'a>,
+ state: &'a PySharedState,
+}
+
+impl<'a> BorrowPyShared<'a> {
+ fn new(py: Python<'a>, state: &'a PySharedState) -> BorrowPyShared<'a> {
+ state.increase_borrow_count(py);
+ BorrowPyShared { py, state }
+ }
+}
+
+impl<'a> Drop for BorrowPyShared<'a> {
+ fn drop(&mut self) {
+ self.state.decrease_borrow_count(self.py);
+ }
+}
+
+/// An immutable reference to [`PySharedRefCell`] value, not bound to lifetime.
+///
+/// The reference will be invalidated once the original value is mutably
+/// borrowed.
+///
+/// # Safety
+///
+/// Even though [`UnsafePyLeaked`] tries to enforce the real lifetime of the
+/// underlying object, the object having the artificial `'static` lifetime
+/// may be exposed to your Rust code. You must be careful to not make a bare
+/// reference outlive the actual object lifetime.
+///
+/// TODO these two examples would not compile if [`UnsafePyLeaked::map()`]
+/// would only accept [`Fn`] instead of [`FnOnce`].
+///
+/// ```ignore
+/// let outer;
+/// unsafe { leaked.map(py, |o| { outer = o }) }; // Bad
+/// ```
+///
+/// ```ignore
+/// let outer;
+/// let mut leaked_iter = leaked.map(py, |o| o.iter());
+/// {
+/// let mut iter = unsafe { leaked_iter.try_borrow_mut(py) };
+/// let inner = iter.next(); // Good, in borrow scope
+/// outer = inner; // Bad, &'static T may outlive
+/// }
+/// ```
+pub struct UnsafePyLeaked<T: ?Sized> {
+ owner: PyObject,
+ state: &'static PySharedState,
+ /// Generation counter of data `T` captured when UnsafePyLeaked is
+ /// created.
+ generation: usize,
+ /// Underlying data of artificial lifetime, which is valid only when
+ /// state.generation == self.generation.
+ data: T,
+}
+
+// DO NOT implement Deref for UnsafePyLeaked<T>! Dereferencing UnsafePyLeaked
+// without taking Python GIL wouldn't be safe. Also, the underling reference
+// is invalid if generation != state.generation.
+
+impl<T: ?Sized> UnsafePyLeaked<T> {
+ // No panicking version of borrow() and borrow_mut() are implemented
+ // because the underlying value is supposed to be mutated in Python
+ // world, and the Rust library designer can't prevent it.
+
+ // try_borrow() and try_borrow_mut() are unsafe because self.data may
+ // have a function returning the inner &'static reference.
+ // If T is &'static U, its lifetime can be easily coerced to &'a U, but
+ // how could we do that for Whatever<'static> in general?
+
+ /// Immutably borrows the wrapped value.
+ ///
+ /// Borrowing fails if the underlying reference has been invalidated.
+ ///
+ /// # Safety
+ ///
+ /// The lifetime of the innermost object is artificial. Do not obtain and
+ /// copy it out of the borrow scope.
+ pub unsafe fn try_borrow<'a>(
+ &'a self,
+ py: Python<'a>,
+ ) -> PyResult<PyLeakedRef<'a, T>> {
+ self.validate_generation(py)?;
+ Ok(PyLeakedRef {
+ _borrow: BorrowPyShared::new(py, self.state),
+ data: &self.data,
+ })
+ }
+
+ /// Mutably borrows the wrapped value.
+ ///
+ /// Borrowing fails if the underlying reference has been invalidated.
+ ///
+ /// Typically `T` is an iterator. If `T` is an immutable reference,
+ /// `get_mut()` is useless since the inner value can't be mutated.
+ ///
+ /// # Safety
+ ///
+ /// The lifetime of the innermost object is artificial. Do not obtain and
+ /// copy it out of the borrow scope.
+ pub unsafe fn try_borrow_mut<'a>(
+ &'a mut self,
+ py: Python<'a>,
+ ) -> PyResult<PyLeakedRefMut<'a, T>> {
+ self.validate_generation(py)?;
+ Ok(PyLeakedRefMut {
+ _borrow: BorrowPyShared::new(py, self.state),
+ data: &mut self.data,
+ })
+ }
+
+ fn validate_generation(&self, py: Python) -> PyResult<()> {
+ if self.state.current_generation(py) == self.generation {
+ Ok(())
+ } else {
+ Err(PyRuntimeError::new_err(
+ "Cannot access to leaked reference after mutation",
+ ))
+ }
+ }
+}
+
+impl<T> UnsafePyLeaked<T> {
+ /// Converts the inner value by the given function.
+ ///
+ /// Typically `T` is a static reference to a collection, and `U` is an
+ /// iterator of that collection.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the underlying reference has been invalidated.
+ ///
+ /// This is typically called immediately after the `UnsafePyLeaked` is
+ /// obtained. At this time, the reference must be valid and no panic
+ /// would occur.
+ ///
+ /// # Safety
+ ///
+ /// The lifetime of the object passed in to the function `f` is artificial.
+ /// It's typically a static reference, but is valid only while the
+ /// corresponding `UnsafePyLeaked` is alive. Do not copy it out of the
+ /// function call.
+ /// TODO would it be safe with `f: impl Fn(T) -> U` then?
+ pub unsafe fn map<U>(
+ self,
+ py: Python,
+ f: impl FnOnce(T) -> U,
+ ) -> UnsafePyLeaked<U> {
+ // Needs to test the generation value to make sure self.data reference
+ // is still intact.
+ self.validate_generation(py)
+ .expect("map() over invalidated leaked reference");
+
+ // f() could make the self.data outlive. That's why map() is unsafe.
+ // In order to make this function safe, maybe we'll need a way to
+ // temporarily restrict the lifetime of self.data and translate the
+ // returned object back to Something<'static>.
+ let new_data = f(self.data);
+ UnsafePyLeaked {
+ owner: self.owner,
+ state: self.state,
+ generation: self.generation,
+ data: new_data,
+ }
+ }
+}
+
+/// An immutably borrowed reference to a leaked value.
+pub struct PyLeakedRef<'a, T: 'a + ?Sized> {
+ _borrow: BorrowPyShared<'a>,
+ data: &'a T,
+}
+
+impl<'a, T: ?Sized> Deref for PyLeakedRef<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &T {
+ self.data
+ }
+}
+
+/// A mutably borrowed reference to a leaked value.
+pub struct PyLeakedRefMut<'a, T: 'a + ?Sized> {
+ _borrow: BorrowPyShared<'a>,
+ data: &'a mut T,
+}
+
+impl<'a, T: ?Sized> Deref for PyLeakedRefMut<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &T {
+ self.data
+ }
+}
+
+impl<'a, T: ?Sized> DerefMut for PyLeakedRefMut<'a, T> {
+ fn deref_mut(&mut self) -> &mut T {
+ self.data
+ }
+}