feat: python udf implementation #703

Merged: 18 commits, Aug 28, 2023
1 change: 1 addition & 0 deletions Cargo.lock


4 changes: 2 additions & 2 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
@@ -141,8 +141,8 @@ pub fn add_udf_to_dfg(
})
.try_collect()?;

let is_new = dfg.add_udf(udf.clone(), args.iter().map(|i| i.value()).collect())?;
let value = is_any_new(dfg, &args)?;
let value = dfg.add_udf(udf.clone(), args.iter().map(|i| i.value()).collect())?;
let is_new = is_any_new(dfg, &args)?;

let time_domain_check = TimeDomainCheck::Compatible;
let time_domain =
2 changes: 1 addition & 1 deletion crates/sparrow-compiler/src/plan/expression_to_plan.rs
@@ -35,7 +35,7 @@ pub(super) fn dfg_to_plan(
InstKind::FieldRef => "field_ref".to_owned(),
InstKind::Record => "record".to_owned(),
InstKind::Cast(_) => "cast".to_owned(),
InstKind::Udf(udf) => udf.signature().name().to_owned(),
InstKind::Udf(udf) => udf.uuid().to_string(),
};

let result_type =
4 changes: 2 additions & 2 deletions crates/sparrow-instructions/src/evaluators.rs
@@ -44,8 +44,8 @@ use time::*;
#[derive(Debug)]
pub struct StaticInfo<'a> {
inst_kind: &'a InstKind,
args: Vec<StaticArg>,
result_type: &'a DataType,
pub args: Vec<StaticArg>,
pub result_type: &'a DataType,
}

impl<'a> StaticInfo<'a> {
28 changes: 24 additions & 4 deletions crates/sparrow-runtime/src/execute.rs
@@ -4,6 +4,7 @@ use chrono::NaiveDateTime;
use enum_map::EnumMap;
use error_stack::{IntoReport, IntoReportCompat, ResultExt};
use futures::Stream;
use hashbrown::HashMap;
use prost_wkt_types::Timestamp;
use sparrow_api::kaskada::v1alpha::execute_request::Limits;
use sparrow_api::kaskada::v1alpha::{
@@ -12,7 +13,9 @@ use sparrow_api::kaskada::v1alpha::{
};
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_compiler::{hash_compute_plan_proto, DataContext};
use sparrow_instructions::Udf;
use sparrow_qfr::kaskada::sparrow::v1alpha::FlightRecordHeader;
use uuid::Uuid;

use crate::execute::compute_store_guard::ComputeStoreGuard;
use crate::execute::error::Error;
@@ -66,9 +69,15 @@ pub async fn execute(
..ExecutionOptions::default()
};

// let output_at_time = request.final_result_time;

execute_new(plan, destination, data_context, options, None).await
execute_new(
plan,
destination,
data_context,
options,
None,
HashMap::new(),
)
.await
}

#[derive(Default, Debug)]
@@ -214,12 +223,18 @@ async fn load_key_hash_inverse(
/// ----------
/// - key_hash_inverse: If set, specifies the key hash inverses to use. If None, the
/// key hashes will be created.
/// - udfs: contains the mapping of uuid to udf implementation. This is currently used
/// so we can serialize the uuid to the ComputePlan, then look up what implementation to use
/// when creating the evaluator. This works because we are on a single machine, and don't need
/// to pass the plan around. However, we'll eventually need to look into serializing/pickling
/// the callable.
pub async fn execute_new(
plan: ComputePlan,
destination: Destination,
mut data_context: DataContext,
options: ExecutionOptions,
key_hash_inverse: Option<Arc<ThreadSafeKeyHashInverse>>,
udfs: HashMap<Uuid, Arc<dyn Udf>>,
) -> error_stack::Result<impl Stream<Item = error_stack::Result<ExecuteResponse, Error>>, Error> {
let object_stores = Arc::new(ObjectStoreRegistry::default());

@@ -265,6 +280,7 @@ pub async fn execute_new(
output_at_time,
bounded_lateness_ns: options.bounded_lateness_ns,
materialize: options.materialize,
udfs,
};

// Start executing the query. We pass the response channel to the
@@ -326,5 +342,9 @@ pub async fn materialize(
// TODO: the `execute_with_progress` method contains a lot of additional logic that is theoretically not needed,
// as the materialization does not exit, and should not need to handle cleanup tasks that regular
// queries do. We should likely refactor this to use a separate `materialize_with_progress` method.
execute_new(plan, destination, data_context, options, None).await

// TODO: Unimplemented feature - UDFs
let udfs = HashMap::new();

execute_new(plan, destination, data_context, options, None, udfs).await
}
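
The doc comment on `execute_new` above explains the lookup scheme: the compute plan serializes only the udf's uuid, and the concrete implementation is resolved from the `udfs` map when the evaluators are created. A minimal sketch of that resolution step, assuming a hypothetical stand-in for the `sparrow_instructions::Udf` trait (only the uuid accessor is modeled here):

```rust
use std::sync::Arc;

use hashbrown::HashMap;
use uuid::Uuid;

// Hypothetical stand-in for `sparrow_instructions::Udf`; the real trait also
// carries the signature and evaluator construction.
trait Udf: Send + Sync {
    fn uuid(&self) -> &Uuid;
}

/// Resolve a udf implementation from the instruction name that was written
/// into the plan (the uuid rendered as a string).
fn resolve_udf(
    inst: &str,
    udfs: &HashMap<Uuid, Arc<dyn Udf>>,
) -> anyhow::Result<Arc<dyn Udf>> {
    let uuid = Uuid::parse_str(inst)?;
    udfs.get(&uuid)
        .cloned()
        .ok_or_else(|| anyhow::anyhow!("expected udf for uuid {uuid}"))
}
```

In the executor (see the `ExpressionExecutor::try_new` change below), the resolved `Arc<dyn Udf>` becomes an `InstKind::Udf` so an evaluator can be created for it.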
1 change: 1 addition & 0 deletions crates/sparrow-runtime/src/execute/compute_executor.rs
@@ -48,6 +48,7 @@ pub struct ComputeResult {

impl ComputeExecutor {
/// Spawns the compute tasks using the new operation based executor.
#[allow(clippy::too_many_arguments)]
pub async fn try_spawn(
mut context: OperationContext,
plan_hash: PlanHash,
18 changes: 13 additions & 5 deletions crates/sparrow-runtime/src/execute/operation.rs
@@ -42,14 +42,16 @@ use chrono::NaiveDateTime;
use enum_map::EnumMap;
use error_stack::{IntoReport, IntoReportCompat, Report, Result, ResultExt};
use futures::Future;
use hashbrown::HashMap;
use prost_wkt_types::Timestamp;
use sparrow_api::kaskada::v1alpha::operation_plan::tick_operation::TickBehavior;
use sparrow_api::kaskada::v1alpha::{operation_plan, ComputePlan, LateBoundValue, OperationPlan};
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_compiler::DataContext;
use sparrow_instructions::ComputeStore;
use sparrow_instructions::{ComputeStore, Udf};
use tokio::task::JoinHandle;
use tracing::Instrument;
use uuid::Uuid;

use self::final_tick::FinalTickOperation;
use self::input_batch::InputBatch;
@@ -103,6 +105,8 @@ pub(crate) struct OperationContext {
///
/// Derived from the ExecutionOptions,
pub materialize: bool,
/// Mapping of uuid to user-defined functions.
pub udfs: HashMap<Uuid, Arc<dyn Udf>>,
}

impl OperationContext {
@@ -187,10 +191,14 @@ impl OperationExecutor {

let operation_label = operator.label();

let mut expression_executor =
ExpressionExecutor::try_new(operation_label, operation.expressions, late_bindings)
.into_report()
.change_context(Error::internal_msg("unable to create executor"))?;
let mut expression_executor = ExpressionExecutor::try_new(
operation_label,
operation.expressions,
late_bindings,
&context.udfs,
)
.into_report()
.change_context(Error::internal_msg("unable to create executor"))?;

debug_assert_eq!(operator.input_len(), input_channels.len());

@@ -6,15 +6,17 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use enum_map::EnumMap;
use hashbrown::HashMap;
use itertools::Itertools;
use sparrow_api::kaskada::v1alpha::expression_plan::Operator;
use sparrow_api::kaskada::v1alpha::{ExpressionPlan, LateBoundValue, OperationInputRef};
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_instructions::ValueRef;
use sparrow_instructions::{
create_evaluator, ColumnarValue, ComputeStore, Evaluator, GroupingIndices, InstKind, InstOp,
RuntimeInfo, StaticArg, StaticInfo, StoreKey,
};
use sparrow_instructions::{Udf, ValueRef};
use uuid::Uuid;

use crate::execute::operation::InputBatch;
use crate::Batch;
@@ -49,6 +51,7 @@ impl ExpressionExecutor {
operation_label: &'static str,
expressions: Vec<ExpressionPlan>,
late_bindings: &EnumMap<LateBoundValue, Option<ScalarValue>>,
udfs: &HashMap<Uuid, Arc<dyn Udf>>,
) -> anyhow::Result<Self> {
let mut input_columns = Vec::new();

@@ -102,8 +105,15 @@ impl ExpressionExecutor {
} else if inst == "cast" {
InstKind::Cast(data_type.clone())
} else {
let inst_op = InstOp::from_str(&inst)?;
InstKind::Simple(inst_op)
// This assumes we'll never have an InstOp function name that
// matches a uuid, which should be safe.
if let Ok(uuid) = Uuid::from_str(&inst) {
let udf = udfs.get(&uuid).ok_or(anyhow::anyhow!("expected udf"))?;
InstKind::Udf(udf.clone())
} else {
let inst_op = InstOp::from_str(&inst)?;
InstKind::Simple(inst_op)
}
};

let static_info = StaticInfo::new(&inst_kind, args, &data_type);
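
The branch above leans on the comment's assumption that no built-in instruction name ever parses as a uuid. A tiny, self-contained check of that assumption with the `uuid` crate (the literal uuid below is made up for illustration):

```rust
use std::str::FromStr;

use uuid::Uuid;

fn main() {
    // Built-in instruction names such as "add" are not valid uuid syntax,
    // so they fall through to the `InstOp::from_str` branch.
    assert!(Uuid::from_str("add").is_err());

    // A serialized udf instruction is the uuid rendered as a string, so it
    // parses and is used as the key into the `udfs` map.
    let inst = "6f2c1b1e-8f4a-4c9b-9d2e-3a5b7c9d1e2f";
    assert!(Uuid::from_str(inst).is_ok());
}
```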
2 changes: 2 additions & 0 deletions crates/sparrow-runtime/src/execute/operation/scan.rs
@@ -361,6 +361,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use futures::StreamExt;
use hashbrown::HashMap;
use itertools::Itertools;
use sparrow_api::kaskada::v1alpha::compute_table::FileSet;
use sparrow_api::kaskada::v1alpha::operation_input_ref::{self, Column};
@@ -509,6 +510,7 @@
output_at_time: None,
bounded_lateness_ns: None,
materialize: false,
udfs: HashMap::new(),
};

executor
3 changes: 3 additions & 0 deletions crates/sparrow-runtime/src/execute/operation/testing.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
use anyhow::Context;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;
use hashbrown::HashMap;
use itertools::Itertools;
use sparrow_api::kaskada::v1alpha::{ComputePlan, OperationPlan};
use sparrow_compiler::DataContext;
@@ -189,6 +190,7 @@ pub(super) async fn run_operation(
output_at_time: None,
bounded_lateness_ns: None,
materialize: false,
udfs: HashMap::new(),
};
executor
.execute(
@@ -246,6 +248,7 @@ pub(super) async fn run_operation_json(
output_at_time: None,
bounded_lateness_ns: None,
materialize: false,
udfs: HashMap::new(),
};
executor
.execute(
1 change: 1 addition & 0 deletions crates/sparrow-session/Cargo.toml
@@ -11,6 +11,7 @@ The Sparrow session builder.

[dependencies]
arrow-array.workspace = true
hashbrown.workspace = true
arrow-schema.workspace = true
arrow-select.workspace = true
derive_more.workspace = true
18 changes: 13 additions & 5 deletions crates/sparrow-session/src/session.rs
@@ -1,5 +1,5 @@
use hashbrown::HashMap;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use arrow_schema::SchemaRef;
@@ -25,6 +25,13 @@ pub struct Session {
data_context: DataContext,
dfg: Dfg,
key_hash_inverse: HashMap<GroupId, Arc<ThreadSafeKeyHashInverse>>,
/// Keeps track of the uuid mapping.
///
/// We currently do not serialize the `dyn Udf` into the plan, and instead
/// directly use this local mapping to look up the udf from the serialized
/// uuid. Once we run on multiple machines, we'll have to serialize/pickle the
/// udf as well.
udfs: HashMap<Uuid, Arc<dyn Udf>>,
}

#[derive(Default)]
@@ -291,12 +298,11 @@ impl Session {
///
/// This bypasses much of the plumbing of the [ExprOp] required due to our construction
/// of the AST.
#[allow(unused)]
fn add_udf_to_dfg(
pub fn add_udf_to_dfg(
&mut self,
udf: Arc<dyn Udf>,
args: Vec<Expr>,
) -> error_stack::Result<AstDfgRef, Error> {
) -> error_stack::Result<Expr, Error> {
let signature = udf.signature();
let arg_names = signature.arg_names().to_owned();
signature.assert_valid_argument_count(args.len());
@@ -334,7 +340,8 @@ impl Session {
.collect();
Err(Error::Errors(errors))?
} else {
Ok(result)
self.udfs.insert(*udf.uuid(), udf.clone());
Ok(Expr(result))
}
}

@@ -416,6 +423,7 @@ impl Session {
data_context,
options,
Some(key_hash_inverse),
self.udfs.clone(),
))
.change_context(Error::Execute)?
.map_err(|e| e.change_context(Error::Execute))
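
The `udfs` field documented above keeps a uuid-to-implementation map local to the session instead of serializing the callable into the plan. A rough sketch of that bookkeeping, using a hypothetical stand-in trait and a made-up `NegateUdf` type purely for illustration:

```rust
use std::sync::Arc;

use hashbrown::HashMap;
use uuid::Uuid;

// Hypothetical stand-in for `sparrow_instructions::Udf`.
trait Udf: Send + Sync {
    fn uuid(&self) -> &Uuid;
}

// Made-up udf implementation used only to show the registration flow.
struct NegateUdf {
    uuid: Uuid,
}

impl Udf for NegateUdf {
    fn uuid(&self) -> &Uuid {
        &self.uuid
    }
}

fn main() {
    let mut udfs: HashMap<Uuid, Arc<dyn Udf>> = HashMap::new();

    // Registration, as in `Session::add_udf_to_dfg`: key the implementation
    // by its uuid so the plan only needs to carry the uuid.
    let udf: Arc<dyn Udf> = Arc::new(NegateUdf { uuid: Uuid::new_v4() });
    udfs.insert(*udf.uuid(), udf.clone());

    // At execution time the map is cloned and handed to `execute_new`, which
    // threads it into the `OperationContext` for evaluator construction.
    assert!(udfs.contains_key(udf.uuid()));
}
```

On a single machine this side channel is enough; as the field comment notes, running across machines would require serializing (pickling) the callable as well.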
3 changes: 3 additions & 0 deletions python/.vscode/settings.json
@@ -0,0 +1,3 @@
{
"editor.formatOnSave": true
}
6 changes: 6 additions & 0 deletions python/Cargo.lock


5 changes: 5 additions & 0 deletions python/Cargo.toml
@@ -13,6 +13,7 @@ Python library for building and executing temporal queries.

[dependencies]
arrow = { version = "43.0.0", features = ["pyarrow"] }
anyhow = { version = "1.0.70", features = ["backtrace"] }
derive_more = "0.99.17"
error-stack = { version = "0.3.1", features = ["anyhow", "spantrace"] }
futures = "0.3.27"
@@ -23,8 +24,12 @@ mimalloc = { version = "0.1.37", default-features = false, features = ["local_dy
pyo3 = {version = "0.19.1", features = ["abi3-py38", "extension-module", "generate-import-lib"]}
pyo3-asyncio = { version = "0.19.0", features = ["tokio-runtime"] }
sparrow-session = { path = "../crates/sparrow-session" }
sparrow-instructions = { path = "../crates/sparrow-instructions" }
sparrow-runtime = { path = "../crates/sparrow-runtime" }
sparrow-syntax = { path = "../crates/sparrow-syntax" }
tokio = { version = "1.27.0", features = ["sync"] }
tracing = "0.1.37"
uuid = { version = "1.3.0", features = ["v4"] }

[lib]
name = "kaskada"
4 changes: 3 additions & 1 deletion python/pysrc/kaskada/__init__.py
@@ -1,4 +1,4 @@
"""Kaskada query builder and local executon engine."""
"""Kaskada query builder and local execution engine."""
from __future__ import annotations

from . import plot
@@ -10,6 +10,7 @@
from ._timestream import Literal
from ._timestream import Timestream
from ._timestream import record
from .udf import udf


__all__ = [
@@ -21,5 +22,6 @@
"Result",
"sources",
"Timestream",
"udf",
"windows",
]