Skip to content

Commit

Permalink
improve signature of ffi::PyIter_Send & add PyIterator::send
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Nov 29, 2024
1 parent 82ab509 commit 00708ec
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
6 changes: 5 additions & 1 deletion pyo3-ffi/src/abstract_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ extern "C" {
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
#[cfg(all(not(PyPy), Py_3_10))]
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
pub fn PyIter_Send(
iter: *mut PyObject,
arg: *mut PyObject,
presult: *mut *mut PyObject,
) -> PySendResult;

#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;
Expand Down
65 changes: 63 additions & 2 deletions src/types/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,31 @@ impl PyIterator {
}
}

#[derive(Debug)]
pub enum PySendResult<'py> {
Next(Bound<'py, PyAny>),
Return(Bound<'py, PyAny>),
}

impl<'py> Bound<'py, PyIterator> {
/// Sends a value into the iterator.
#[inline]
#[cfg(all(not(PyPy), Py_3_10))]
pub fn send(&self, value: &Bound<'py, PyAny>) -> PyResult<PySendResult<'py>> {
let py = self.py();
let mut result = std::ptr::null_mut();
match unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut result) } {
ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)),
ffi::PySendResult::PYGEN_RETURN => Ok(PySendResult::Return(unsafe {
result.assume_owned_unchecked(py)
})),
ffi::PySendResult::PYGEN_NEXT => Ok(PySendResult::Next(unsafe {
result.assume_owned_unchecked(py)
})),
}
}
}

impl<'py> Iterator for Bound<'py, PyIterator> {
type Item = PyResult<Bound<'py, PyAny>>;

Expand Down Expand Up @@ -105,9 +130,9 @@ impl PyTypeCheck for PyIterator {

#[cfg(test)]
mod tests {
use super::PyIterator;
use super::{PyIterator, PySendResult};
use crate::exceptions::PyTypeError;
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods};
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods, PyNone};
use crate::{ffi, IntoPyObject, Python};

#[test]
Expand Down Expand Up @@ -201,6 +226,42 @@ def fibonacci(target):
});
}

#[test]
#[cfg(all(not(PyPy), Py_3_10))]
fn send_generator() {
let generator = ffi::c_str!(
r#"
def gen():
value = None
while(True):
value = yield value
if value is None:
return
"#
);

Python::with_gil(|py| {
let context = PyDict::new(py);
py.run(generator, None, Some(&context)).unwrap();

let generator = py.eval(ffi::c_str!("gen()"), None, Some(&context)).unwrap();

let one = 1i32.into_pyobject(py).unwrap();
assert!(matches!(
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
PySendResult::Next(value) if value.is_none()
));
assert!(matches!(
generator.try_iter().unwrap().send(&one).unwrap(),
PySendResult::Next(value) if value.is(&one)
));
assert!(matches!(
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
PySendResult::Return(value) if value.is_none()
));
});
}

#[test]
fn fibonacci_generator_bound() {
use crate::types::any::PyAnyMethods;
Expand Down

0 comments on commit 00708ec

Please sign in to comment.