changeset 52556:bd65cb043aa5

rust-pyo3: error propagation for UnsafePyLeaked wrapping a result This `py_leaked_or_map_err` is the PyO3 version of the function of the same name in `hg-cpython/src/ancestors.rs`.
author Georges Racinet <georges.racinet@cloudcrane.io>
date Sat, 07 Dec 2024 18:18:09 +0100
parents 1dd673c1ab3b
children 736551565871
files rust/hg-pyo3/src/convert_cpython.rs
diffstat 1 files changed, 38 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/rust/hg-pyo3/src/convert_cpython.rs	Mon Dec 09 09:38:57 2024 +0100
+++ b/rust/hg-pyo3/src/convert_cpython.rs	Sat Dec 07 18:18:09 2024 +0100
@@ -212,3 +212,41 @@
     };
     Ok(py_shared.inner)
 }
+
+/// Error propagation for an [`UnsafePyLeaked`] wrapping a [`Result`]
+///
+/// TODO (will consider when implementing UnsafePyLeaked in PyO3):
+/// It would be nice for UnsafePyLeaked to provide this directly as a variant
+/// of the `map` method with a signature such as:
+///
+/// ```
+///   unsafe fn map_or_err(&self,
+///                        py: Python,
+///                        f: impl FnOnce(T) -> Result(U, E),
+///                        convert_err: impl FnOnce(E) -> PyErr)
+/// ```
+///
+/// This would spare users of the `cpython` crate the additional `unsafe` deref
+/// to inspect the error and return it outside `UnsafePyLeaked`, and the
+/// subsequent unwrapping that this function performs.
+#[allow(dead_code)]
+pub(crate) fn py_leaked_or_map_err<T, E: std::fmt::Debug + Copy>(
+    py: cpython::Python,
+    leaked: cpython::UnsafePyLeaked<Result<T, E>>,
+    convert_err: impl FnOnce(E) -> PyErr,
+) -> PyResult<cpython::UnsafePyLeaked<T>> {
+    // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked`
+    if let Err(e) = *unsafe {
+        leaked
+            .try_borrow(py)
+            .map_err(|e| from_cpython_pyerr(py, e))?
+    } {
+        return Err(convert_err(e));
+    }
+    // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked`
+    Ok(unsafe {
+        leaked.map(py, |res| {
+            res.expect("Error case should have already be treated")
+        })
+    })
+}