Skip to content

Commit acb8118

Browse files
conradsoonColin Ho
and
Colin Ho
authored
[FEAT]: Support .clip function (#3136)
Closes #1907. --------- Co-authored-by: Colin Ho <[email protected]>
1 parent 81b4306 commit acb8118

File tree

18 files changed

+548
-0
lines changed

18 files changed

+548
-0
lines changed

daft/daft/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,7 @@ class ConnectionHandle:
12111211
def abs(expr: PyExpr) -> PyExpr: ...
12121212
def cbrt(expr: PyExpr) -> PyExpr: ...
12131213
def ceil(expr: PyExpr) -> PyExpr: ...
1214+
def clip(expr: PyExpr, min: PyExpr, max: PyExpr) -> PyExpr: ...
12141215
def exp(expr: PyExpr) -> PyExpr: ...
12151216
def floor(expr: PyExpr) -> PyExpr: ...
12161217
def log2(expr: PyExpr) -> PyExpr: ...
@@ -1377,6 +1378,7 @@ class PySeries:
13771378
def floor(self) -> PySeries: ...
13781379
def sign(self) -> PySeries: ...
13791380
def round(self, decimal: int) -> PySeries: ...
1381+
def clip(self, min: PySeries, max: PySeries) -> PySeries: ...
13801382
def sqrt(self) -> PySeries: ...
13811383
def cbrt(self) -> PySeries: ...
13821384
def sin(self) -> PySeries: ...

daft/expressions/expressions.py

+12
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,18 @@ def floor(self) -> Expression:
622622
expr = native.floor(self._expr)
623623
return Expression._from_pyexpr(expr)
624624

625+
def clip(self, min: Expression | None = None, max: Expression | None = None) -> Expression:
626+
"""Clips an expression to the given minimum and maximum values (``expr.clip(min, max)``).
627+
628+
Args:
629+
min: Minimum value to clip to. If None (or column value is Null), no lower clipping is applied.
630+
max: Maximum value to clip to. If None (or column value is Null), no upper clipping is applied.
631+
632+
"""
633+
min_expr = Expression._to_expression(min)
634+
max_expr = Expression._to_expression(max)
635+
return Expression._from_pyexpr(native.clip(self._expr, min_expr._expr, max_expr._expr))
636+
625637
def sign(self) -> Expression:
626638
"""The sign of a numeric expression (``expr.sign()``)"""
627639
expr = native.sign(self._expr)

daft/series.py

+3
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ def sign(self) -> Series:
318318
def round(self, decimal: int) -> Series:
319319
return Series._from_pyseries(self._series.round(decimal))
320320

321+
def clip(self, min: Series, max: Series) -> Series:
322+
return Series._from_pyseries(self._series.clip(min._series, max._series))
323+
321324
def sqrt(self) -> Series:
322325
return Series._from_pyseries(self._series.sqrt())
323326

docs/source/api_docs/expressions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Numeric
5555
Expression.floor
5656
Expression.sign
5757
Expression.round
58+
Expression.clip
5859
Expression.sqrt
5960
Expression.cbrt
6061
Expression.sin

src/daft-core/src/array/ops/clip.rs

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
use arrow2::array::PrimitiveArray;
2+
use common_error::{DaftError, DaftResult};
3+
use num_traits::{clamp, clamp_max, clamp_min};
4+
5+
use crate::{array::DataArray, datatypes::DaftNumericType, prelude::AsArrow};
6+
7+
impl<T> DataArray<T>
8+
where
9+
T: DaftNumericType,
10+
T::Native: PartialOrd,
11+
{
12+
/// Clips the values in the array to the provided left and right bounds.
13+
///
14+
/// # Arguments
15+
///
16+
/// * `left_bound` - The lower bound for clipping.
17+
/// * `right_bound` - The upper bound for clipping.
18+
///
19+
/// # Returns
20+
///
21+
/// * `DaftResult<Self>` - The clipped DataArray.
22+
pub fn clip(&self, left_bound: &Self, right_bound: &Self) -> DaftResult<Self> {
23+
match (self.len(), left_bound.len(), right_bound.len()) {
24+
// Case where all arrays have the same length
25+
(array_size, lbound_size, rbound_size)
26+
if array_size == lbound_size && array_size == rbound_size =>
27+
{
28+
let result = self
29+
.as_arrow()
30+
.values_iter() // Fine to use values_iter since we will apply the validity later, saves us 1 branch.
31+
.zip(left_bound.as_arrow().iter())
32+
.zip(right_bound.as_arrow().iter())
33+
.map(|((value, left), right)| match (left, right) {
34+
(Some(l), Some(r)) => Some(clamp(*value, *l, *r)),
35+
(Some(l), None) => Some(clamp_min(*value, *l)),
36+
(None, Some(r)) => Some(clamp_max(*value, *r)),
37+
(None, None) => Some(*value),
38+
});
39+
let result = PrimitiveArray::<T::Native>::from_trusted_len_iter(result);
40+
let data_array = Self::from((self.name(), Box::new(result)))
41+
.with_validity(self.validity().cloned())?;
42+
Ok(data_array)
43+
}
44+
// Case where left_bound has the same length as self and right_bound has length 1
45+
(array_size, lbound_size, 1) if array_size == lbound_size => {
46+
let right = right_bound.get(0);
47+
// We check the validity of right_bound here, since it has length 1.
48+
// This avoids a validity check in the clamp function
49+
match right {
50+
Some(r) => {
51+
// Right is valid, so we just clamp/clamp_max the values depending on the left bound
52+
let result = self
53+
.as_arrow()
54+
.values_iter()
55+
.zip(left_bound.as_arrow().iter())
56+
.map(move |(value, left)| match left {
57+
Some(l) => Some(clamp(*value, *l, r)),
58+
None => Some(clamp_max(*value, r)), // If left is null, we can just clamp_max
59+
});
60+
let result = PrimitiveArray::<T::Native>::from_trusted_len_iter(result);
61+
let data_array = Self::from((self.name(), Box::new(result)))
62+
.with_validity(self.validity().cloned())?;
63+
Ok(data_array)
64+
}
65+
None => {
66+
// In this case, right_bound is null, so we can just do a simple clamp_min
67+
let result = self
68+
.as_arrow()
69+
.values_iter()
70+
.zip(left_bound.as_arrow().iter())
71+
.map(|(value, left)| match left {
72+
Some(l) => Some(clamp_min(*value, *l)),
73+
None => Some(*value), // Left null, and right null, so we just don't do anything
74+
});
75+
let result = PrimitiveArray::<T::Native>::from_trusted_len_iter(result);
76+
let data_array = Self::from((self.name(), Box::new(result)))
77+
.with_validity(self.validity().cloned())?;
78+
Ok(data_array)
79+
}
80+
}
81+
}
82+
// Case where right_bound has the same length as self and left_bound has length 1
83+
(array_size, 1, rbound_size) if array_size == rbound_size => {
84+
let left = left_bound.get(0);
85+
match left {
86+
Some(l) => {
87+
let result = self
88+
.as_arrow()
89+
.values_iter()
90+
.zip(right_bound.as_arrow().iter())
91+
.map(move |(value, right)| match right {
92+
Some(r) => Some(clamp(*value, l, *r)),
93+
None => Some(clamp_min(*value, l)), // Right null, so we can just clamp_min
94+
});
95+
let result = PrimitiveArray::<T::Native>::from_trusted_len_iter(result);
96+
let data_array = Self::from((self.name(), Box::new(result)))
97+
.with_validity(self.validity().cloned())?;
98+
Ok(data_array)
99+
}
100+
None => {
101+
let result = self
102+
.as_arrow()
103+
.values_iter()
104+
.zip(right_bound.as_arrow().iter())
105+
.map(|(value, right)| match right {
106+
Some(r) => Some(clamp_max(*value, *r)),
107+
None => Some(*value),
108+
});
109+
let result = PrimitiveArray::<T::Native>::from_trusted_len_iter(result);
110+
let data_array = Self::from((self.name(), Box::new(result)))
111+
.with_validity(self.validity().cloned())?;
112+
Ok(data_array)
113+
}
114+
}
115+
}
116+
// Case where both left_bound and right_bound have length 1
117+
(_, 1, 1) => {
118+
let left = left_bound.get(0);
119+
let right = right_bound.get(0);
120+
match (left, right) {
121+
(Some(l), Some(r)) => self.apply(|value| clamp(value, l, r)),
122+
(Some(l), None) => self.apply(|value| clamp_min(value, l)),
123+
(None, Some(r)) => self.apply(|value| clamp_max(value, r)),
124+
(None, None) => {
125+
// Not doing anything here, so we can just return self
126+
Ok(self.clone())
127+
}
128+
}
129+
}
130+
// Handle incompatible lengths
131+
_ => Err(DaftError::ValueError(format!(
132+
"Unable to clip incompatible length arrays: {}: {}, {}: {}, {}: {}",
133+
self.name(),
134+
self.len(),
135+
left_bound.name(),
136+
left_bound.len(),
137+
right_bound.name(),
138+
right_bound.len()
139+
))),
140+
}
141+
}
142+
}

src/daft-core/src/array/ops/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub(crate) mod broadcast;
1212
pub(crate) mod cast;
1313
mod cbrt;
1414
mod ceil;
15+
mod clip;
1516
mod compare_agg;
1617
mod comparison;
1718
mod concat;

src/daft-core/src/datatypes/infer_datatype.rs

+43
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,49 @@ impl<'a> InferDataType<'a> {
6161
}
6262
}
6363

64+
pub fn clip_op(&self, min_infer_type: &Self, max_infer_type: &Self) -> DaftResult<DataType> {
65+
match (&self.0, &min_infer_type.0, &max_infer_type.0) {
66+
// Error cases first
67+
(input_type, _, _) if !input_type.is_numeric() => Err(DaftError::TypeError(format!(
68+
"Expected input to be numeric, got {}",
69+
input_type
70+
))),
71+
(_, min_type, _) if !min_type.is_numeric() && !min_type.is_null() => {
72+
Err(DaftError::TypeError(format!(
73+
"Expected min input to be numeric or null, got {}",
74+
min_type
75+
)))
76+
}
77+
(_, _, max_type) if !max_type.is_numeric() && !max_type.is_null() => {
78+
Err(DaftError::TypeError(format!(
79+
"Expected max input to be numeric or null, got {}",
80+
max_type
81+
)))
82+
}
83+
// Main logic for valid inputs
84+
(input_type, min_type, max_type) => {
85+
// This path gets called when the Python bindings pass in a Series, but note that there can still be nulls within the series.
86+
let mut output_type = (*input_type).clone();
87+
88+
// Check compatibility with min_infer_type
89+
if !min_type.is_null() {
90+
let (_, _, new_output_type) =
91+
InferDataType::from(&output_type).comparison_op(min_infer_type)?;
92+
output_type = new_output_type;
93+
}
94+
95+
// Check compatibility with max_infer_type
96+
if !max_type.is_null() {
97+
let (_, _, new_output_type) =
98+
InferDataType::from(&output_type).comparison_op(max_infer_type)?;
99+
output_type = new_output_type;
100+
}
101+
102+
Ok(output_type)
103+
}
104+
}
105+
}
106+
64107
pub fn comparison_op(
65108
&self,
66109
other: &Self,

src/daft-core/src/python/series.rs

+3
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ impl PySeries {
149149
}
150150
Ok(self.series.round(decimal)?.into())
151151
}
152+
pub fn clip(&self, min: &Self, max: &Self) -> PyResult<Self> {
153+
Ok(self.series.clip(&min.series, &max.series)?.into())
154+
}
152155

153156
pub fn sqrt(&self) -> PyResult<Self> {
154157
Ok(self.series.sqrt()?.into())

src/daft-core/src/series/ops/clip.rs

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use common_error::{DaftError, DaftResult};
2+
3+
use crate::{
4+
datatypes::InferDataType,
5+
series::{IntoSeries, Series},
6+
with_match_numeric_daft_types,
7+
};
8+
9+
impl Series {
10+
/// Alias for .clip()
11+
pub fn clamp(&self, min: &Self, max: &Self) -> DaftResult<Self> {
12+
self.clip(min, max)
13+
}
14+
15+
/// Clip function to clamp values to a range
16+
pub fn clip(&self, min: &Self, max: &Self) -> DaftResult<Self> {
17+
let output_type = InferDataType::clip_op(
18+
&InferDataType::from(self.data_type()),
19+
&InferDataType::from(min.data_type()),
20+
&InferDataType::from(max.data_type()),
21+
)?;
22+
23+
// It's possible that we pass in something like .clamp(None, 2) on the Python binding side,
24+
// in which case we need to cast the None to the output type.
25+
let create_null_series = |name: &str| Self::full_null(name, &output_type, 1);
26+
let min = min
27+
.data_type()
28+
.is_null()
29+
.then(|| create_null_series(min.name()))
30+
.unwrap_or_else(|| min.clone());
31+
let max = max
32+
.data_type()
33+
.is_null()
34+
.then(|| create_null_series(max.name()))
35+
.unwrap_or_else(|| max.clone());
36+
37+
match &output_type {
38+
output_type if output_type.is_numeric() => {
39+
with_match_numeric_daft_types!(output_type, |$T| {
40+
let self_casted = self.cast(output_type)?;
41+
let min_casted = min.cast(output_type)?;
42+
let max_casted = max.cast(output_type)?;
43+
44+
let self_downcasted = self_casted.downcast::<<$T as DaftDataType>::ArrayType>()?;
45+
let min_downcasted = min_casted.downcast::<<$T as DaftDataType>::ArrayType>()?;
46+
let max_downcasted = max_casted.downcast::<<$T as DaftDataType>::ArrayType>()?;
47+
Ok(self_downcasted.clip(min_downcasted, max_downcasted)?.into_series())
48+
})
49+
}
50+
dt => Err(DaftError::TypeError(format!(
51+
"clip not implemented for {}",
52+
dt
53+
))),
54+
}
55+
}
56+
}

src/daft-core/src/series/ops/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod broadcast;
1111
pub mod cast;
1212
pub mod cbrt;
1313
pub mod ceil;
14+
pub mod clip;
1415
pub mod comparison;
1516
pub mod concat;
1617
pub mod downcast;

0 commit comments

Comments
 (0)