rust/hg-cpython/src/ref_sharing.rs
changeset 43423 945d4dba5e78
parent 43422 b9f791090211
child 43424 0836efe4967b
--- a/rust/hg-cpython/src/ref_sharing.rs	Sat Oct 12 19:10:51 2019 +0900
+++ b/rust/hg-cpython/src/ref_sharing.rs	Sat Oct 12 19:26:23 2019 +0900
@@ -25,6 +25,7 @@
 use crate::exceptions::AlreadyBorrowed;
 use cpython::{PyClone, PyObject, PyResult, Python};
 use std::cell::{Cell, Ref, RefCell, RefMut};
+use std::ops::{Deref, DerefMut};
 
 /// Manages the shared state between Python and Rust
 #[derive(Debug, Default)]
@@ -333,17 +334,29 @@
         }
     }
 
-    /// Returns an immutable reference to the inner value.
-    pub fn get_ref<'a>(&'a self, _py: Python<'a>) -> &'a T {
-        self.data.as_ref().unwrap()
+    /// Immutably borrows the wrapped value.
+    pub fn try_borrow<'a>(
+        &'a self,
+        py: Python<'a>,
+    ) -> PyResult<PyLeakedRef<'a, T>> {
+        Ok(PyLeakedRef {
+            _py: py,
+            data: self.data.as_ref().unwrap(),
+        })
     }
 
-    /// Returns a mutable reference to the inner value.
+    /// Mutably borrows the wrapped value.
     ///
     /// Typically `T` is an iterator. If `T` is an immutable reference,
     /// `get_mut()` is useless since the inner value can't be mutated.
-    pub fn get_mut<'a>(&'a mut self, _py: Python<'a>) -> &'a mut T {
-        self.data.as_mut().unwrap()
+    pub fn try_borrow_mut<'a>(
+        &'a mut self,
+        py: Python<'a>,
+    ) -> PyResult<PyLeakedRefMut<'a, T>> {
+        Ok(PyLeakedRefMut {
+            _py: py,
+            data: self.data.as_mut().unwrap(),
+        })
     }
 
     /// Converts the inner value by the given function.
@@ -389,6 +402,40 @@
     }
 }
 
+/// Immutably borrowed reference to a leaked value.
+pub struct PyLeakedRef<'a, T> {
+    _py: Python<'a>,
+    data: &'a T,
+}
+
+impl<T> Deref for PyLeakedRef<'_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        self.data
+    }
+}
+
+/// Mutably borrowed reference to a leaked value.
+pub struct PyLeakedRefMut<'a, T> {
+    _py: Python<'a>,
+    data: &'a mut T,
+}
+
+impl<T> Deref for PyLeakedRefMut<'_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        self.data
+    }
+}
+
+impl<T> DerefMut for PyLeakedRefMut<'_, T> {
+    fn deref_mut(&mut self) -> &mut T {
+        self.data
+    }
+}
+
 /// Defines a `py_class!` that acts as a Python iterator over a Rust iterator.
 ///
 /// TODO: this is a bit awkward to use, and a better (more complicated)
@@ -457,7 +504,8 @@
             def __next__(&self) -> PyResult<$success_type> {
                 let mut inner_opt = self.inner(py).borrow_mut();
                 if let Some(leaked) = inner_opt.as_mut() {
-                    match leaked.get_mut(py).next() {
+                    let mut iter = leaked.try_borrow_mut(py)?;
+                    match iter.next() {
                         None => {
                             // replace Some(inner) by None, drop $leaked
                             inner_opt.take();
@@ -512,6 +560,28 @@
     }
 
     #[test]
+    fn test_leaked_borrow() {
+        let (gil, owner) = prepare_env();
+        let py = gil.python();
+        let leaked = owner.string_shared(py).leak_immutable().unwrap();
+        let leaked_ref = leaked.try_borrow(py).unwrap();
+        assert_eq!(*leaked_ref, "new");
+    }
+
+    #[test]
+    fn test_leaked_borrow_mut() {
+        let (gil, owner) = prepare_env();
+        let py = gil.python();
+        let leaked = owner.string_shared(py).leak_immutable().unwrap();
+        let mut leaked_iter = unsafe { leaked.map(py, |s| s.chars()) };
+        let mut leaked_ref = leaked_iter.try_borrow_mut(py).unwrap();
+        assert_eq!(leaked_ref.next(), Some('n'));
+        assert_eq!(leaked_ref.next(), Some('e'));
+        assert_eq!(leaked_ref.next(), Some('w'));
+        assert_eq!(leaked_ref.next(), None);
+    }
+
+    #[test]
     fn test_borrow_mut_while_leaked() {
         let (gil, owner) = prepare_env();
         let py = gil.python();