Skip to content

Commit

Permalink
Merge pull request #472 from Michael-J-Ward/pyo3-bounds
Browse files Browse the repository at this point in the history
Update Pyo3 bounds
  • Loading branch information
haixuanTao authored Jun 8, 2024
2 parents 685b01e + bf088fc commit d7be6a4
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 124 deletions.
140 changes: 82 additions & 58 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ dora-coordinator = { version = "0.3.4", path = "binaries/coordinator" }
dora-ros2-bridge = { path = "libraries/extensions/ros2-bridge" }
dora-ros2-bridge-msg-gen = { path = "libraries/extensions/ros2-bridge/msg-gen" }
dora-ros2-bridge-python = { path = "libraries/extensions/ros2-bridge/python" }
arrow = "48.0.0"
arrow-schema = "48.0.0"
arrow-data = "48.0.0"
arrow-array = "48.0.0"
pyo3 = "0.20.0"
pythonize = "0.20.0"
arrow = { version = "52" }
arrow-schema = { version = "52" }
arrow-data = { version = "52" }
arrow-array = { version = "52" }
pyo3 = "0.21"
pythonize = "0.21"

[package]
name = "dora-examples"
Expand Down
13 changes: 7 additions & 6 deletions apis/python/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,17 @@ impl Node {
&mut self,
output_id: String,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> eyre::Result<()> {
let parameters = pydict_to_metadata(metadata)?;

if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
self.node
.send_output_bytes(output_id.into(), parameters, data.len(), data)
.wrap_err("failed to send output")?;
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
self.node.send_output(
output_id.into(),
parameters,
Expand Down Expand Up @@ -251,9 +251,10 @@ pub fn start_runtime() -> eyre::Result<()> {
}

#[pymodule]
fn dora(_py: Python, m: &PyModule) -> PyResult<()> {
dora_ros2_bridge_python::create_dora_ros2_bridge_module(m)?;
m.add_function(wrap_pyfunction!(start_runtime, m)?)?;
fn dora(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
dora_ros2_bridge_python::create_dora_ros2_bridge_module(&m)?;

m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
m.add_class::<Node>()?;
m.add_class::<PyEvent>()?;
m.setattr("__version__", env!("CARGO_PKG_VERSION"))?;
Expand Down
16 changes: 10 additions & 6 deletions apis/python/operator/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow::{array::ArrayRef, pyarrow::ToPyArrow};
use dora_node_api::{merged::MergedEvent, Event, Metadata, MetadataParameters};
use eyre::{Context, Result};
use pyo3::{exceptions::PyLookupError, prelude::*, types::PyDict};
use pyo3::{exceptions::PyLookupError, prelude::*, pybacked::PyBackedStr, types::PyDict};

/// Dora Event
#[pyclass]
Expand Down Expand Up @@ -126,11 +126,15 @@ impl From<MergedEvent<PyObject>> for PyEvent {
}
}

pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
let mut default_metadata = MetadataParameters::default();
if let Some(metadata) = dict {
for (key, value) in metadata.iter() {
match key.extract::<&str>().context("Parsing metadata keys")? {
match key
.extract::<PyBackedStr>()
.context("Parsing metadata keys")?
.as_ref()
{
"watermark" => {
default_metadata.watermark =
value.extract().context("parsing watermark failed")?;
Expand All @@ -140,7 +144,7 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
value.extract().context("parsing deadline failed")?;
}
"open_telemetry_context" => {
let otel_context: &str = value
let otel_context: PyBackedStr = value
.extract()
.context("parsing open telemetry context failed")?;
default_metadata.open_telemetry_context = otel_context.to_string();
Expand All @@ -152,8 +156,8 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
Ok(default_metadata)
}

pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> &'a PyDict {
let dict = PyDict::new(py);
pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> pyo3::Bound<'a, PyDict> {
let dict = PyDict::new_bound(py);
dict.set_item(
"open_telemetry_context",
&metadata.parameters.open_telemetry_context,
Expand Down
2 changes: 1 addition & 1 deletion apis/rust/operator/src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub unsafe fn dora_on_event<O: DoraOperator>(
status: DoraStatus::Continue,
};
};
let data = arrow::ffi::from_ffi(data_array, &input.schema);
let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema) };

match data {
Ok(data) => Event::Input {
Expand Down
2 changes: 1 addition & 1 deletion apis/rust/operator/types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ pub fn dora_free_input_id(_input_id: char_p_boxed) {}
#[ffi_export]
pub fn dora_read_data(input: &mut Input) -> Option<safer_ffi::Vec<u8>> {
let data_array = input.data_array.take()?;
let data = arrow::ffi::from_ffi(data_array, &input.schema).ok()?;
let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema).ok()? };
let array = ArrowData(arrow::array::make_array(data));
let bytes: &[u8] = TryFrom::try_from(&array).ok()?;
Some(bytes.to_owned().into())
Expand Down
43 changes: 17 additions & 26 deletions binaries/runtime/src/operator/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use dora_operator_api_types::DoraStatus;
use eyre::{bail, eyre, Context, Result};
use pyo3::{
pyclass,
types::{IntoPyDict, PyDict},
types::{IntoPyDict, PyAnyMethods, PyDict, PyTracebackMethods},
Py, PyAny, Python,
};
use std::{
Expand All @@ -23,7 +23,7 @@ use tokio::sync::{mpsc::Sender, oneshot};
use tracing::{error, field, span, warn};

fn traceback(err: pyo3::PyErr) -> eyre::Report {
let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
let traceback = Python::with_gil(|py| err.traceback_bound(py).and_then(|t| t.format().ok()));
if let Some(traceback) = traceback {
eyre::eyre!("{traceback}\n{err}")
} else {
Expand Down Expand Up @@ -78,7 +78,9 @@ pub fn run(
let parent_path = parent_path
.to_str()
.ok_or_else(|| eyre!("module path is not valid utf8"))?;
let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
let sys = py
.import_bound("sys")
.wrap_err("failed to import `sys` module")?;
let sys_path = sys
.getattr("path")
.wrap_err("failed to import `sys.path` module")?;
Expand All @@ -90,14 +92,14 @@ pub fn run(
.wrap_err("failed to append module path to python search path")?;
}

let module = py.import(module_name).map_err(traceback)?;
let module = py.import_bound(module_name).map_err(traceback)?;
let operator_class = module
.getattr("Operator")
.wrap_err("no `Operator` class found in module")?;

let locals = [("Operator", operator_class)].into_py_dict(py);
let locals = [("Operator", operator_class)].into_py_dict_bound(py);
let operator = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)?;
operator.setattr(
"dataflow_descriptor",
Expand Down Expand Up @@ -141,11 +143,11 @@ pub fn run(
.wrap_err("could not extract operator state as a PyDict")?;
// Reload module
let module = py
.import(module_name)
.import_bound(module_name)
.map_err(traceback)
.wrap_err(format!("Could not retrieve {module_name} while reloading"))?;
let importlib = py
.import("importlib")
.import_bound("importlib")
.wrap_err("failed to import `importlib` module")?;
let module = importlib
.call_method("reload", (module,), None)
Expand All @@ -155,9 +157,9 @@ pub fn run(
.wrap_err("no `Operator` class found in module")?;

// Create a new reloaded operator
let locals = [("Operator", reloaded_operator_class)].into_py_dict(py);
let locals = [("Operator", reloaded_operator_class)].into_py_dict_bound(py);
let operator: Py<pyo3::PyAny> = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)
.wrap_err("Could not initialize reloaded operator")?
.into();
Expand Down Expand Up @@ -185,17 +187,6 @@ pub fn run(
let status = Python::with_gil(|py| -> Result<i32> {
let span = span!(tracing::Level::TRACE, "on_event", input_id = field::Empty);
let _ = span.enter();
// We need to create a new scoped `GILPool` because the dora-runtime
// is currently started through a `start_runtime` wrapper function,
// which is annotated with `#[pyfunction]`. This attribute creates an
// initial `GILPool` that lasts for the entire lifetime of the `dora-runtime`.
// However, we want the `PyBytes` created below to be freed earlier.
// Creating a new scoped `GILPool` tied to this closure will free `PyBytes`
// at the end of the closure.
// See https://github.com/PyO3/pyo3/pull/2864 and
// https://github.com/PyO3/pyo3/issues/2853 for more details.
let pool = unsafe { py.new_pool() };
let py = pool.python();

// Add metadata context if we have a tracer and
// incoming input has some metadata.
Expand Down Expand Up @@ -300,8 +291,8 @@ mod callback_impl {
use eyre::{eyre, Context, Result};
use pyo3::{
pymethods,
types::{PyBytes, PyDict},
PyObject, Python,
types::{PyBytes, PyBytesMethods, PyDict},
Bound, PyObject, Python,
};
use tokio::sync::oneshot;
use tracing::{field, span};
Expand All @@ -318,7 +309,7 @@ mod callback_impl {
&mut self,
output: &str,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> Result<()> {
let parameters = pydict_to_metadata(metadata)
Expand Down Expand Up @@ -354,12 +345,12 @@ mod callback_impl {
}
};

let (sample, type_info) = if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
let (sample, type_info) = if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
let mut sample = allocate_sample(data.len())?;
sample.copy_from_slice(data);
(sample, ArrowTypeInfo::byte_array(data.len()))
} else if let Ok(arrow_array) = ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = ArrayData::from_pyarrow_bound(data.bind(py)) {
let total_len = required_data_size(&arrow_array);
let mut sample = allocate_sample(total_len)?;

Expand Down
2 changes: 1 addition & 1 deletion binaries/runtime/src/operator/shared_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl<'lib> SharedLibraryOperator<'lib> {
..Default::default()
};

let arrow_array = match arrow::ffi::from_ffi(data_array, &schema) {
let arrow_array = match unsafe { arrow::ffi::from_ffi(data_array, &schema) } {
Ok(a) => a,
Err(err) => return DoraResult::from_error(err.to_string()),
};
Expand Down
18 changes: 9 additions & 9 deletions libraries/extensions/ros2-bridge/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use eyre::{eyre, Context, ContextCompat, Result};
use futures::{Stream, StreamExt};
use pyo3::{
prelude::{pyclass, pymethods},
types::{PyDict, PyList, PyModule},
PyAny, PyObject, PyResult, Python,
types::{PyAnyMethods, PyDict, PyList, PyModule, PyModuleMethods},
Bound, PyAny, PyObject, PyResult, Python,
};
use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue};

Expand Down Expand Up @@ -57,7 +57,7 @@ impl Ros2Context {
pub fn new(ros_paths: Option<Vec<PathBuf>>) -> eyre::Result<Self> {
Python::with_gil(|py| -> Result<()> {
let warnings = py
.import("warnings")
.import_bound("warnings")
.wrap_err("failed to import `warnings` module")?;
warnings
.call_method1("warn", ("dora-rs ROS2 Bridge is unstable and may change at any point without it being considered a breaking change",))
Expand Down Expand Up @@ -322,8 +322,8 @@ impl Ros2Publisher {
/// :type data: pyarrow.Array
/// :rtype: None
///
pub fn publish(&self, data: &PyAny) -> eyre::Result<()> {
let pyarrow = PyModule::import(data.py(), "pyarrow")?;
pub fn publish(&self, data: Bound<'_, PyAny>) -> eyre::Result<()> {
let pyarrow = PyModule::import_bound(data.py(), "pyarrow")?;

let data = if data.is_instance_of::<PyDict>() {
// convert to arrow struct scalar
Expand All @@ -332,15 +332,15 @@ impl Ros2Publisher {
data
};

let data = if data.is_instance(pyarrow.getattr("StructScalar")?)? {
let data = if data.is_instance(&pyarrow.getattr("StructScalar")?)? {
// convert to arrow array
let list = PyList::new(data.py(), [data]);
let list = PyList::new_bound(data.py(), [data]);
pyarrow.getattr("array")?.call1((list,))?
} else {
data
};

let value = arrow::array::ArrayData::from_pyarrow(data)?;
let value = arrow::array::ArrayData::from_pyarrow_bound(&data)?;
//// add type info to ensure correct serialization (e.g. struct types
//// and map types need to be serialized differently)
let typed_value = TypedValue {
Expand Down Expand Up @@ -431,7 +431,7 @@ impl Stream for Ros2SubscriptionStream {
}
}

pub fn create_dora_ros2_bridge_module(m: &PyModule) -> PyResult<()> {
pub fn create_dora_ros2_bridge_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Ros2Context>()?;
m.add_class::<Ros2Node>()?;
m.add_class::<Ros2NodeOptions>()?;
Expand Down
18 changes: 10 additions & 8 deletions libraries/extensions/ros2-bridge/python/src/typed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ mod tests {
use arrow::pyarrow::ToPyArrow;

use pyo3::types::IntoPyDict;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyDict;
use pyo3::types::PyList;
use pyo3::types::PyModule;
use pyo3::types::PyTuple;
use pyo3::PyNativeType;
use pyo3::Python;
use serde::de::DeserializeSeed;
use serde::Serialize;
Expand All @@ -61,13 +63,13 @@ mod tests {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); //.join("test_utils.py"); // Adjust this path as needed

// Add the Python module's directory to sys.path
py.run(
py.run_bound(
"import sys; sys.path.append(str(path))",
Some([("path", path)].into_py_dict(py)),
Some(&[("path", path)].into_py_dict_bound(py)),
None,
)?;

let my_module = PyModule::import(py, "test_utils")?;
let my_module = PyModule::import_bound(py, "test_utils")?;

let arrays: &PyList = my_module.getattr("TEST_ARRAYS")?.extract()?;
for array_wrapper in arrays.iter() {
Expand All @@ -77,7 +79,7 @@ mod tests {
println!("Checking {}::{}", package_name, message_name);
let in_pyarrow = arrays.get_item(2)?;

let array = arrow::array::ArrayData::from_pyarrow(in_pyarrow)?;
let array = arrow::array::ArrayData::from_pyarrow_bound(&in_pyarrow.as_borrowed())?;
let type_info = TypeInfo {
package_name: package_name.into(),
message_name: message_name.clone().into(),
Expand All @@ -99,17 +101,17 @@ mod tests {

let out_pyarrow = out_value.to_pyarrow(py)?;

let test_utils = PyModule::import(py, "test_utils")?;
let context = PyDict::new(py);
let test_utils = PyModule::import_bound(py, "test_utils")?;
let context = PyDict::new_bound(py);

context.set_item("test_utils", test_utils)?;
context.set_item("in_pyarrow", in_pyarrow)?;
context.set_item("out_pyarrow", out_pyarrow)?;

let _ = py
.eval(
.eval_bound(
"test_utils.is_subset(in_pyarrow, out_pyarrow)",
Some(context),
Some(&context),
None,
)
.context("could not check if it is a subset")?;
Expand Down
2 changes: 1 addition & 1 deletion tool_nodes/dora-record/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ dora-node-api = { workspace = true, features = ["tracing"] }
eyre = "0.6.8"
chrono = "0.4.31"
dora-tracing = { workspace = true }
parquet = { version = "48.0.0", features = ["async"] }
parquet = { version = "52", features = ["async"] }
1 change: 0 additions & 1 deletion tool_nodes/dora-record/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ async fn main() -> eyre::Result<()> {
let mut writer = AsyncArrowWriter::try_new(
file,
schema.clone(),
0,
Some(
WriterProperties::builder()
.set_compression(parquet::basic::Compression::BROTLI(
Expand Down

0 comments on commit d7be6a4

Please sign in to comment.