Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Pyo3 bounds #472

Merged
merged 5 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 82 additions & 58 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ dora-coordinator = { version = "0.3.4", path = "binaries/coordinator" }
dora-ros2-bridge = { path = "libraries/extensions/ros2-bridge" }
dora-ros2-bridge-msg-gen = { path = "libraries/extensions/ros2-bridge/msg-gen" }
dora-ros2-bridge-python = { path = "libraries/extensions/ros2-bridge/python" }
arrow = "48.0.0"
arrow-schema = "48.0.0"
arrow-data = "48.0.0"
arrow-array = "48.0.0"
pyo3 = "0.20.0"
pythonize = "0.20.0"
arrow = { version = "52" }
arrow-schema = { version = "52" }
arrow-data = { version = "52" }
arrow-array = { version = "52" }
pyo3 = "0.21"
pythonize = "0.21"

[package]
name = "dora-examples"
Expand Down
13 changes: 7 additions & 6 deletions apis/python/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,17 @@ impl Node {
&mut self,
output_id: String,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> eyre::Result<()> {
let parameters = pydict_to_metadata(metadata)?;

if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
self.node
.send_output_bytes(output_id.into(), parameters, data.len(), data)
.wrap_err("failed to send output")?;
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
self.node.send_output(
output_id.into(),
parameters,
Expand Down Expand Up @@ -251,9 +251,10 @@ pub fn start_runtime() -> eyre::Result<()> {
}

#[pymodule]
fn dora(_py: Python, m: &PyModule) -> PyResult<()> {
dora_ros2_bridge_python::create_dora_ros2_bridge_module(m)?;
m.add_function(wrap_pyfunction!(start_runtime, m)?)?;
fn dora(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
dora_ros2_bridge_python::create_dora_ros2_bridge_module(&m)?;

m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
m.add_class::<Node>()?;
m.add_class::<PyEvent>()?;
m.setattr("__version__", env!("CARGO_PKG_VERSION"))?;
Expand Down
16 changes: 10 additions & 6 deletions apis/python/operator/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow::{array::ArrayRef, pyarrow::ToPyArrow};
use dora_node_api::{merged::MergedEvent, Event, Metadata, MetadataParameters};
use eyre::{Context, Result};
use pyo3::{exceptions::PyLookupError, prelude::*, types::PyDict};
use pyo3::{exceptions::PyLookupError, prelude::*, pybacked::PyBackedStr, types::PyDict};

/// Dora Event
#[pyclass]
Expand Down Expand Up @@ -126,11 +126,15 @@ impl From<MergedEvent<PyObject>> for PyEvent {
}
}

pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
let mut default_metadata = MetadataParameters::default();
if let Some(metadata) = dict {
for (key, value) in metadata.iter() {
match key.extract::<&str>().context("Parsing metadata keys")? {
match key
.extract::<PyBackedStr>()
.context("Parsing metadata keys")?
.as_ref()
{
"watermark" => {
default_metadata.watermark =
value.extract().context("parsing watermark failed")?;
Expand All @@ -140,7 +144,7 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
value.extract().context("parsing deadline failed")?;
}
"open_telemetry_context" => {
let otel_context: &str = value
let otel_context: PyBackedStr = value
.extract()
.context("parsing open telemetry context failed")?;
default_metadata.open_telemetry_context = otel_context.to_string();
Expand All @@ -152,8 +156,8 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
Ok(default_metadata)
}

pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> &'a PyDict {
let dict = PyDict::new(py);
pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> pyo3::Bound<'a, PyDict> {
let dict = PyDict::new_bound(py);
dict.set_item(
"open_telemetry_context",
&metadata.parameters.open_telemetry_context,
Expand Down
2 changes: 1 addition & 1 deletion apis/rust/operator/src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub unsafe fn dora_on_event<O: DoraOperator>(
status: DoraStatus::Continue,
};
};
let data = arrow::ffi::from_ffi(data_array, &input.schema);
let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema) };

match data {
Ok(data) => Event::Input {
Expand Down
2 changes: 1 addition & 1 deletion apis/rust/operator/types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub fn dora_free_input_id(_input_id: char_p_boxed) {}
#[ffi_export]
pub fn dora_read_data(input: &mut Input) -> Option<safer_ffi::Vec<u8>> {
let data_array = input.data_array.take()?;
let data = arrow::ffi::from_ffi(data_array, &input.schema).ok()?;
let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema).ok()? };
let array = ArrowData(arrow::array::make_array(data));
let bytes: &[u8] = TryFrom::try_from(&array).ok()?;
Some(bytes.to_owned().into())
Expand Down
43 changes: 17 additions & 26 deletions binaries/runtime/src/operator/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use dora_operator_api_types::DoraStatus;
use eyre::{bail, eyre, Context, Result};
use pyo3::{
pyclass,
types::{IntoPyDict, PyDict},
types::{IntoPyDict, PyAnyMethods, PyDict, PyTracebackMethods},
Py, PyAny, Python,
};
use std::{
Expand All @@ -23,7 +23,7 @@ use tokio::sync::{mpsc::Sender, oneshot};
use tracing::{error, field, span, warn};

fn traceback(err: pyo3::PyErr) -> eyre::Report {
let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
let traceback = Python::with_gil(|py| err.traceback_bound(py).and_then(|t| t.format().ok()));
if let Some(traceback) = traceback {
eyre::eyre!("{traceback}\n{err}")
} else {
Expand Down Expand Up @@ -78,7 +78,9 @@ pub fn run(
let parent_path = parent_path
.to_str()
.ok_or_else(|| eyre!("module path is not valid utf8"))?;
let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
let sys = py
.import_bound("sys")
.wrap_err("failed to import `sys` module")?;
let sys_path = sys
.getattr("path")
.wrap_err("failed to import `sys.path` module")?;
Expand All @@ -90,14 +92,14 @@ pub fn run(
.wrap_err("failed to append module path to python search path")?;
}

let module = py.import(module_name).map_err(traceback)?;
let module = py.import_bound(module_name).map_err(traceback)?;
let operator_class = module
.getattr("Operator")
.wrap_err("no `Operator` class found in module")?;

let locals = [("Operator", operator_class)].into_py_dict(py);
let locals = [("Operator", operator_class)].into_py_dict_bound(py);
let operator = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)?;
operator.setattr(
"dataflow_descriptor",
Expand Down Expand Up @@ -141,11 +143,11 @@ pub fn run(
.wrap_err("could not extract operator state as a PyDict")?;
// Reload module
let module = py
.import(module_name)
.import_bound(module_name)
.map_err(traceback)
.wrap_err(format!("Could not retrieve {module_name} while reloading"))?;
let importlib = py
.import("importlib")
.import_bound("importlib")
.wrap_err("failed to import `importlib` module")?;
let module = importlib
.call_method("reload", (module,), None)
Expand All @@ -155,9 +157,9 @@ pub fn run(
.wrap_err("no `Operator` class found in module")?;

// Create a new reloaded operator
let locals = [("Operator", reloaded_operator_class)].into_py_dict(py);
let locals = [("Operator", reloaded_operator_class)].into_py_dict_bound(py);
let operator: Py<pyo3::PyAny> = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)
.wrap_err("Could not initialize reloaded operator")?
.into();
Expand Down Expand Up @@ -185,17 +187,6 @@ pub fn run(
let status = Python::with_gil(|py| -> Result<i32> {
let span = span!(tracing::Level::TRACE, "on_event", input_id = field::Empty);
let _ = span.enter();
// We need to create a new scoped `GILPool` because the dora-runtime
// is currently started through a `start_runtime` wrapper function,
// which is annotated with `#[pyfunction]`. This attribute creates an
// initial `GILPool` that lasts for the entire lifetime of the `dora-runtime`.
// However, we want the `PyBytes` created below to be freed earlier.
// Creating a new scoped `GILPool` tied to this closure will free the `PyBytes`
// at the end of the closure.
// See https://github.com/PyO3/pyo3/pull/2864 and
// https://github.com/PyO3/pyo3/issues/2853 for more details.
let pool = unsafe { py.new_pool() };
let py = pool.python();

// Add metadata context if we have a tracer and
// incoming input has some metadata.
Expand Down Expand Up @@ -300,8 +291,8 @@ mod callback_impl {
use eyre::{eyre, Context, Result};
use pyo3::{
pymethods,
types::{PyBytes, PyDict},
PyObject, Python,
types::{PyBytes, PyBytesMethods, PyDict},
Bound, PyObject, Python,
};
use tokio::sync::oneshot;
use tracing::{field, span};
Expand All @@ -318,7 +309,7 @@ mod callback_impl {
&mut self,
output: &str,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> Result<()> {
let parameters = pydict_to_metadata(metadata)
Expand Down Expand Up @@ -354,12 +345,12 @@ mod callback_impl {
}
};

let (sample, type_info) = if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
let (sample, type_info) = if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
let mut sample = allocate_sample(data.len())?;
sample.copy_from_slice(data);
(sample, ArrowTypeInfo::byte_array(data.len()))
} else if let Ok(arrow_array) = ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = ArrayData::from_pyarrow_bound(data.bind(py)) {
let total_len = required_data_size(&arrow_array);
let mut sample = allocate_sample(total_len)?;

Expand Down
2 changes: 1 addition & 1 deletion binaries/runtime/src/operator/shared_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl<'lib> SharedLibraryOperator<'lib> {
..Default::default()
};

let arrow_array = match arrow::ffi::from_ffi(data_array, &schema) {
let arrow_array = match unsafe { arrow::ffi::from_ffi(data_array, &schema) } {
Ok(a) => a,
Err(err) => return DoraResult::from_error(err.to_string()),
};
Expand Down
18 changes: 9 additions & 9 deletions libraries/extensions/ros2-bridge/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use eyre::{eyre, Context, ContextCompat, Result};
use futures::{Stream, StreamExt};
use pyo3::{
prelude::{pyclass, pymethods},
types::{PyDict, PyList, PyModule},
PyAny, PyObject, PyResult, Python,
types::{PyAnyMethods, PyDict, PyList, PyModule, PyModuleMethods},
Bound, PyAny, PyObject, PyResult, Python,
};
use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue};

Expand Down Expand Up @@ -57,7 +57,7 @@ impl Ros2Context {
pub fn new(ros_paths: Option<Vec<PathBuf>>) -> eyre::Result<Self> {
Python::with_gil(|py| -> Result<()> {
let warnings = py
.import("warnings")
.import_bound("warnings")
.wrap_err("failed to import `warnings` module")?;
warnings
.call_method1("warn", ("dora-rs ROS2 Bridge is unstable and may change at any point without it being considered a breaking change",))
Expand Down Expand Up @@ -322,8 +322,8 @@ impl Ros2Publisher {
/// :type data: pyarrow.Array
/// :rtype: None
///
pub fn publish(&self, data: &PyAny) -> eyre::Result<()> {
let pyarrow = PyModule::import(data.py(), "pyarrow")?;
pub fn publish(&self, data: Bound<'_, PyAny>) -> eyre::Result<()> {
let pyarrow = PyModule::import_bound(data.py(), "pyarrow")?;

let data = if data.is_instance_of::<PyDict>() {
// convert to arrow struct scalar
Expand All @@ -332,15 +332,15 @@ impl Ros2Publisher {
data
};

let data = if data.is_instance(pyarrow.getattr("StructScalar")?)? {
let data = if data.is_instance(&pyarrow.getattr("StructScalar")?)? {
// convert to arrow array
let list = PyList::new(data.py(), [data]);
let list = PyList::new_bound(data.py(), [data]);
pyarrow.getattr("array")?.call1((list,))?
} else {
data
};

let value = arrow::array::ArrayData::from_pyarrow(data)?;
let value = arrow::array::ArrayData::from_pyarrow_bound(&data)?;
//// add type info to ensure correct serialization (e.g. struct types
//// and map types need to be serialized differently)
let typed_value = TypedValue {
Expand Down Expand Up @@ -431,7 +431,7 @@ impl Stream for Ros2SubscriptionStream {
}
}

pub fn create_dora_ros2_bridge_module(m: &PyModule) -> PyResult<()> {
pub fn create_dora_ros2_bridge_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Ros2Context>()?;
m.add_class::<Ros2Node>()?;
m.add_class::<Ros2NodeOptions>()?;
Expand Down
18 changes: 10 additions & 8 deletions libraries/extensions/ros2-bridge/python/src/typed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ mod tests {
use arrow::pyarrow::ToPyArrow;

use pyo3::types::IntoPyDict;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyDict;
use pyo3::types::PyList;
use pyo3::types::PyModule;
use pyo3::types::PyTuple;
use pyo3::PyNativeType;
use pyo3::Python;
use serde::de::DeserializeSeed;
use serde::Serialize;
Expand All @@ -61,13 +63,13 @@ mod tests {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); //.join("test_utils.py"); // Adjust this path as needed

// Add the Python module's directory to sys.path
py.run(
py.run_bound(
"import sys; sys.path.append(str(path))",
Some([("path", path)].into_py_dict(py)),
Some(&[("path", path)].into_py_dict_bound(py)),
None,
)?;

let my_module = PyModule::import(py, "test_utils")?;
let my_module = PyModule::import_bound(py, "test_utils")?;

let arrays: &PyList = my_module.getattr("TEST_ARRAYS")?.extract()?;
for array_wrapper in arrays.iter() {
Expand All @@ -77,7 +79,7 @@ mod tests {
println!("Checking {}::{}", package_name, message_name);
let in_pyarrow = arrays.get_item(2)?;

let array = arrow::array::ArrayData::from_pyarrow(in_pyarrow)?;
let array = arrow::array::ArrayData::from_pyarrow_bound(&in_pyarrow.as_borrowed())?;
let type_info = TypeInfo {
package_name: package_name.into(),
message_name: message_name.clone().into(),
Expand All @@ -99,17 +101,17 @@ mod tests {

let out_pyarrow = out_value.to_pyarrow(py)?;

let test_utils = PyModule::import(py, "test_utils")?;
let context = PyDict::new(py);
let test_utils = PyModule::import_bound(py, "test_utils")?;
let context = PyDict::new_bound(py);

context.set_item("test_utils", test_utils)?;
context.set_item("in_pyarrow", in_pyarrow)?;
context.set_item("out_pyarrow", out_pyarrow)?;

let _ = py
.eval(
.eval_bound(
"test_utils.is_subset(in_pyarrow, out_pyarrow)",
Some(context),
Some(&context),
None,
)
.context("could not check if it is a subset")?;
Expand Down
2 changes: 1 addition & 1 deletion tool_nodes/dora-record/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ dora-node-api = { workspace = true, features = ["tracing"] }
eyre = "0.6.8"
chrono = "0.4.31"
dora-tracing = { workspace = true }
parquet = { version = "48.0.0", features = ["async"] }
parquet = { version = "52", features = ["async"] }
1 change: 0 additions & 1 deletion tool_nodes/dora-record/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ async fn main() -> eyre::Result<()> {
let mut writer = AsyncArrowWriter::try_new(
file,
schema.clone(),
0,
Some(
WriterProperties::builder()
.set_compression(parquet::basic::Compression::BROTLI(
Expand Down
Loading