Skip to content

Commit

Permalink
Update the pyo3 dependencies with the new bound api.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 6, 2024
1 parent 965c370 commit 603250d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 37 deletions.
4 changes: 2 additions & 2 deletions yomikomi-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ name = "yomikomi"
crate-type = ["cdylib"]

[dependencies]
numpy = "0.20.0"
pyo3 = "0.20.0"
numpy = "0.22.0"
pyo3 = "0.22.0"
yomikomi = { path = "../yomikomi" }
76 changes: 41 additions & 35 deletions yomikomi-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use numpy::prelude::*;
use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -76,7 +77,7 @@ impl Iterable for Warc {

struct Filter {
inner: Arc<dyn Iterable + 'static + Send + Sync>,
filter_fn: PyObject,
filter_fn: Arc<PyObject>,
field: Option<String>,
}

Expand All @@ -102,7 +103,7 @@ impl Iterable for Filter {
},
};
let value = filter_fn.call1(py, (value,)).map_err(w_py)?;
let value = value.is_true(py).map_err(w_py)?;
let value = value.is_truthy(py).map_err(w_py)?;
Ok(value)
})?;
Ok(value)
Expand All @@ -113,7 +114,7 @@ impl Iterable for Filter {

struct AndThen {
inner: Arc<dyn Iterable + 'static + Send + Sync>,
and_then_fn: PyObject,
and_then_fn: Arc<PyObject>,
}

impl Iterable for AndThen {
Expand All @@ -134,12 +135,13 @@ impl Iterable for AndThen {
if value.is_none(py) {
Ok(None)
} else {
let value = match value.as_ref(py).downcast::<pyo3::types::PyDict>() {
let value = match value.downcast_bound::<pyo3::types::PyDict>(py) {
Ok(value) => value,
Err(_) => {
let value = value.downcast_bound(py).map_err(Error::msg)?;
bail!(
"map-fn returned an object that is not a dict, {:?}",
value.as_ref(py).get_type()
value.get_type()
)
}
};
Expand All @@ -150,7 +152,7 @@ impl Iterable for AndThen {
Ok(str) => str.to_string_lossy().to_string(),
Err(_) => bail!("key is not a string, got {:?}", key.get_type()),
};
let value = py_to_array(value)?;
let value = py_to_array(py, value.as_unbound())?;
Ok((key, value))
})
.collect::<Result<HashMap<String, Array>>>()?;
Expand All @@ -165,7 +167,7 @@ impl Iterable for AndThen {

struct KeyTransform {
inner: Arc<dyn Iterable + 'static + Send + Sync>,
map_fn: PyObject,
map_fn: Arc<PyObject>,
field: String,
}

Expand All @@ -182,7 +184,7 @@ impl Iterable for KeyTransform {
let value = Python::with_gil(|py| {
let value = array_to_py(&value, py)?;
let value = map_fn.call1(py, (value,)).map_err(w_py)?;
py_to_array(value.as_ref(py))
py_to_array(py, &value)
})?;
sample.insert(field.to_string(), value);
Ok(Some(sample))
Expand Down Expand Up @@ -317,15 +319,15 @@ fn array_to_py(v: &Array, py: Python<'_>) -> Result<PyObject> {
}
1 => {
let v = v.to_vec1::<T>()?;
numpy::PyArray1::from_vec(py, v).into_py(py)
numpy::PyArray1::from_vec_bound(py, v).into_py(py)
}
2 => {
let v = v.to_vec2::<T>()?;
numpy::PyArray2::from_vec2(py, &v).map_err(Error::wrap)?.into_py(py)
numpy::PyArray2::from_vec2_bound(py, &v).map_err(Error::wrap)?.into_py(py)
}
3 => {
let v = v.to_vec3::<T>()?;
numpy::PyArray3::from_vec3(py, &v).map_err(Error::wrap)?.into_py(py)
numpy::PyArray3::from_vec3_bound(py, &v).map_err(Error::wrap)?.into_py(py)
}
r => bail!("unsupported rank for numpy conversion {r}"),
};
Expand Down Expand Up @@ -383,6 +385,7 @@ impl YkIterable {

#[pyo3(signature = (f, *, field))]
fn key_transform(&self, f: PyObject, field: String) -> PyResult<Self> {
    // `KeyTransform::map_fn` is an `Arc<PyObject>`, so the callable is wrapped on the way in.
    let transform = KeyTransform { inner: self.inner.clone(), map_fn: Arc::new(f), field };
    // Box the new stage behind an `Arc` so it can act as the inner iterable of a fresh wrapper.
    Ok(Self { inner: Arc::new(transform) })
}
Expand All @@ -392,12 +395,14 @@ impl YkIterable {
/// passed the value associated to this field rather than a whole dictionary.
#[pyo3(signature = (f, *, field=None))]
fn filter(&self, f: PyObject, field: Option<String>) -> PyResult<Self> {
    // `Filter::filter_fn` is an `Arc<PyObject>`, so the predicate is wrapped on the way in.
    let filtered = Filter { inner: self.inner.clone(), filter_fn: Arc::new(f), field };
    // Box the new stage behind an `Arc` so it can act as the inner iterable of a fresh wrapper.
    Ok(Self { inner: Arc::new(filtered) })
}

#[pyo3(signature = (f))]
fn map(&self, f: PyObject) -> PyResult<Self> {
    // `AndThen::and_then_fn` is an `Arc<PyObject>`, so the callable is wrapped on the way in.
    let mapped = AndThen { inner: self.inner.clone(), and_then_fn: Arc::new(f) };
    // Box the new stage behind an `Arc` so it can act as the inner iterable of a fresh wrapper.
    Ok(Self { inner: Arc::new(mapped) })
}
Expand Down Expand Up @@ -620,67 +625,67 @@ struct YkPyIterator {
field: Option<String>,
}

fn py_to_array(value: &PyAny) -> Result<Array> {
let py = value.py();
fn py_to_array(py: Python<'_>, value: &PyObject) -> Result<Array> {
// Be cautious in these conversions. Calling `downcast_exact` on a numpy array of float32 with
// a target dtype of u8 would succeed but yield incorrect results. So instead we first extract
// the dtype and, based on it, do the appropriate downcasting.
if let Ok(value) = value.downcast_exact::<numpy::PyUntypedArray>() {
if let Ok(value) = value.downcast_bound::<numpy::PyUntypedArray>(py) {
let dtype = value.dtype();
let shape = value.shape();
if dtype.is_equiv_to(numpy::dtype::<u8>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<u8>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<u8>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
if dtype.is_equiv_to(numpy::dtype::<i8>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<i8>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<i8>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
if dtype.is_equiv_to(numpy::dtype::<u32>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<u32>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<u32>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
if dtype.is_equiv_to(numpy::dtype::<i64>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<i64>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<i64>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
if dtype.is_equiv_to(numpy::dtype::<f32>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<f32>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<f32>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
if dtype.is_equiv_to(numpy::dtype::<f64>(py)) {
if dtype.is_equiv_to(&numpy::dtype_bound::<f64>(py)) {
if let Ok(value) = value.downcast_exact::<numpy::PyArrayDyn<f64>>() {
let value = value.to_vec().map_err(Error::msg)?;
return Array::from(value).reshape(shape);
}
}
bail!("unsupported dtype for np.array {}", dtype)
}
if let Ok(value) = value.extract::<String>() {
if let Ok(value) = value.extract::<String>(py) {
return Ok(Array::from(value.into_bytes()));
}
if let Ok(value) = value.extract::<i64>() {
if let Ok(value) = value.extract::<i64>(py) {
return Ok(Array::from(value));
}
if let Ok(value) = value.extract::<f64>() {
if let Ok(value) = value.extract::<f64>(py) {
return Ok(Array::from(value));
}
if let Ok(value) = value.extract::<Vec<i64>>() {
if let Ok(value) = value.extract::<Vec<i64>>(py) {
return Ok(Array::from(value));
}
if let Ok(value) = value.extract::<Vec<f64>>() {
if let Ok(value) = value.extract::<Vec<f64>>(py) {
return Ok(Array::from(value));
}
let value = value.downcast_bound(py).map_err(Error::msg)?;
bail!("unsupported types for conversion to array {:?}", value.get_type())
}

Expand All @@ -696,31 +701,30 @@ impl Stream for YkPyIterator {
};
match &self.field {
None => {
let next = match next.as_ref(py).downcast::<pyo3::types::PyDict>() {
let next = match next.downcast_bound::<pyo3::types::PyDict>(py) {
Ok(next) => next,
Err(_) => {
bail!(
"iterator returned an object that is not a dict, {:?}",
next.as_ref(py).get_type()
)
let ty = next.into_bound(py).get_type();
bail!("iterator returned an object that is not a dict, {ty:?}",)
}
};
let next = next
.iter()
.map(|(key, value)| {
let value = value.as_unbound();
let key = match key.downcast::<pyo3::types::PyString>() {
Ok(str) => str.to_string_lossy().to_string(),
Err(_) => bail!("key is not a string, got {:?}", key.get_type()),
};
let value = py_to_array(value)?;
let value = py_to_array(py, value)?;
Ok((key, value))
})
.collect::<Result<HashMap<String, Array>>>()?;
Ok(Some(next))
}
Some(field) => {
let mut sample = HashMap::new();
let next = py_to_array(next.as_ref(py))?;
let next = py_to_array(py, &next)?;
sample.insert(field.clone(), next);
Ok(Some(sample))
}
Expand All @@ -736,8 +740,10 @@ struct YkPyIterable {

impl Iterable for YkPyIterable {
fn iter(&self) -> PyResult<StreamIter> {
let iterator =
Python::with_gil(|py| self.iterable.as_ref(py).iter().map(|v| v.to_object(py)))?;
let iterator = Python::with_gil(|py| {
let iterable = self.iterable.downcast_bound(py)?;
pyo3::types::PyAnyMethods::iter(iterable).map(|v| v.to_object(py))
})?;
let stream = YkPyIterator { iterator, field: self.field.clone() };
Ok(StreamIter { stream: Box::new(stream) })
}
Expand All @@ -754,7 +760,7 @@ fn stream(iterable: PyObject, field: Option<String>) -> PyResult<YkIterable> {
}

#[pymodule]
fn yomikomi(_py: Python, m: &PyModule) -> PyResult<()> {
fn yomikomi(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<StreamIter>()?;
m.add_class::<JsonFilter>()?;
m.add_class::<YkIterable>()?;
Expand Down

0 comments on commit 603250d

Please sign in to comment.