Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support intersect all and except distinct/all in DataFrame API #3537

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,7 @@ class LogicalPlanBuilder:
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def except_(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if there's a better name for this, it would be really appreciated.

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
def table_write(
self,
Expand Down
88 changes: 88 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,94 @@ def intersect(self, other: "DataFrame") -> "DataFrame":
builder = self._builder.intersect(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def intersect_all(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames, including duplicates.

Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]})
>>> df2 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]})
>>> df1.intersect_all(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 1 ┆ 4 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 6 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 6 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)

Args:
other (DataFrame): DataFrame to intersect with

Returns:
DataFrame: DataFrame with the intersection of the two DataFrames, including duplicates
"""
builder = self._builder.intersect_all(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def except_distinct(self, other: "DataFrame") -> "DataFrame":
"""Returns the set difference of two DataFrames.

Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
>>> df1.except_distinct(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 2 ┆ 5 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 1 of 1 rows)

Args:
other (DataFrame): DataFrame to except with

Returns:
DataFrame: DataFrame with the set difference of the two DataFrames
"""
builder = self._builder.except_distinct(other._builder)
return DataFrame(builder)

@DataframePublicAPI
def except_all(self, other: "DataFrame") -> "DataFrame":
"""Returns the set difference of two DataFrames, considering duplicates.

Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]})
>>> df2 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]})
>>> df1.except_all(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 1 ┆ 4 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 1 of 1 rows)

Args:
other (DataFrame): DataFrame to except with

Returns:
DataFrame: DataFrame with the set difference of the two DataFrames, considering duplicates
"""
builder = self._builder.except_all(other._builder)
return DataFrame(builder)

def _materialize_results(self) -> None:
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
context = get_context()
Expand Down
12 changes: 12 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, False)
return LogicalPlanBuilder(builder)

def intersect_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, True)
return LogicalPlanBuilder(builder)

def except_distinct(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.except_(other._builder, False)
return LogicalPlanBuilder(builder)

def except_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.except_(other._builder, True)
return LogicalPlanBuilder(builder)

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
builder = self._builder.add_monotonically_increasing_id(column_name)
return LogicalPlanBuilder(builder)
Expand Down
46 changes: 45 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{iter::repeat, sync::Arc};

use arrow2::offset::OffsetsBuffer;
use arrow2::offset::{Offsets, OffsetsBuffer};
use common_error::DaftResult;
use indexmap::{
map::{raw_entry_v1::RawEntryMut, RawEntryApiV1},
Expand Down Expand Up @@ -255,6 +255,31 @@ fn list_sort_helper_fixed_size(
.collect()
}

fn general_list_fill_helper(element: &Series, num_array: &Int64Array) -> DaftResult<Vec<Series>> {
let num_iter = create_iter(num_array, element.len());
let mut result = vec![];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we preallocate the capacity here?

let mut result = Vec::with_capacity(...)

let element_data = element.as_physical()?;
for (row_index, num) in num_iter.enumerate() {
let list_arr = if element.is_valid(row_index) {
let mut list_growable = make_growable(
element.name(),
element.data_type(),
vec![&element_data],
false,
num as usize,
);
for _ in 0..num {
list_growable.extend(0, row_index, 1);
}
list_growable.build()?
} else {
Series::full_null(element.name(), element.data_type(), num as usize)
};
result.push(list_arr);
}
Ok(result)
}

impl ListArray {
pub fn value_counts(&self) -> DaftResult<MapArray> {
struct IndexRef {
Expand Down Expand Up @@ -625,6 +650,25 @@ impl ListArray {
self.validity().cloned(),
))
}

pub fn list_fill(elem: &Series, num_array: &Int64Array) -> DaftResult<Self> {
let generated = general_list_fill_helper(elem, num_array)?;
let generated_refs: Vec<&Series> = generated.iter().collect();
let lengths = generated.iter().map(|arr| arr.len());
let offsets = Offsets::try_from_lengths(lengths)?;
let flat_child = if generated_refs.is_empty() {
// when there's no output, we should create an empty series
Series::empty(elem.name(), elem.data_type())
} else {
Series::concat(&generated_refs)?
};
Ok(Self::new(
elem.field().to_list_field()?,
flat_child,
offsets.into(),
None,
))
}
}

impl FixedSizeListArray {
Expand Down
13 changes: 12 additions & 1 deletion src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use common_error::{DaftError, DaftResult};
use daft_schema::field::Field;

use crate::{
array::ListArray,
datatypes::{DataType, UInt64Array, Utf8Array},
prelude::CountMode,
prelude::{CountMode, Int64Array},
series::{IntoSeries, Series},
};

Expand Down Expand Up @@ -217,4 +218,14 @@ impl Series {
))),
}
}

/// Given a series of data T, repeat each data T with num times to create a list, returns
/// a series of repeated list.
/// # Example
/// ```txt
/// repeat([1, 2, 3], [2, 0, 1]) --> [[1, 1], [], [3]]
/// ```
pub fn list_fill(&self, num: &Int64Array) -> DaftResult<Self> {
ListArray::list_fill(self, num).map(|arr| arr.into_series())
}
}
63 changes: 63 additions & 0 deletions src/daft-functions/src/list/list_fill.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add a couple tests for listfill in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, PTAL.

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::{DataType, Field},
prelude::{Schema, Series},
};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListFill {}

#[typetag::serde]
impl ScalarUDF for ListFill {
fn as_any(&self) -> &dyn std::any::Any {
self
}

Check warning on line 19 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L17-L19

Added lines #L17 - L19 were not covered by tests

fn name(&self) -> &'static str {
"fill"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
"fill"
"list_fill"

}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[n, elem] => {
let num_field = n.to_field(schema)?;
let elem_field = elem.to_field(schema)?;
if !num_field.dtype.is_integer() {
return Err(DaftError::TypeError(format!(
"Expected num field to be of numeric type, received: {}",
num_field.dtype
)));

Check warning on line 34 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L31-L34

Added lines #L31 - L34 were not covered by tests
}
elem_field.to_list_field()
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 41 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L38-L41

Added lines #L38 - L41 were not covered by tests
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[num, elem] => {
let num = num.cast(&DataType::Int64)?;
let num_array = num.i64()?;
elem.list_fill(num_array)
}
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 55 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L52-L55

Added lines #L52 - L55 were not covered by tests
}
}
}

#[must_use]
pub fn list_fill(n: ExprRef, elem: ExprRef) -> ExprRef {
ScalarFunction::new(ListFill {}, vec![n, elem]).into()
}
2 changes: 2 additions & 0 deletions src/daft-functions/src/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod count;
mod explode;
mod get;
mod join;
mod list_fill;
mod max;
mod mean;
mod min;
Expand All @@ -17,6 +18,7 @@ pub use count::{list_count as count, ListCount};
pub use explode::{explode, Explode};
pub use get::{list_get as get, ListGet};
pub use join::{list_join as join, ListJoin};
pub use list_fill::list_fill;
pub use max::{list_max as max, ListMax};
pub use mean::{list_mean as mean, ListMean};
pub use min::{list_min as min, ListMin};
Expand Down
15 changes: 14 additions & 1 deletion src/daft-logical-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,17 @@ impl LogicalPlanBuilder {
pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_optimized_join()?;
.to_logical_plan()?;
Ok(self.with_new_plan(logical_plan))
}

pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Except::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_logical_plan()?;
Ok(self.with_new_plan(logical_plan))
}

pub fn union(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)?
Expand Down Expand Up @@ -861,6 +869,11 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.intersect(&other.builder, is_all)?.into())
}

#[pyo3(name = "except_")]
pub fn except(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
Ok(self.builder.except(&other.builder, is_all)?.into())
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
Ok(self
.builder
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::{Intersect, Union};
pub use set_operations::{Except, Intersect, Union};
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Expand Down
Loading
Loading