diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 9eb6f941036e..6d526061bf5a 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -445,15 +445,19 @@ def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: """ if isinstance(func, Deferred): name = "_" - else: + resolve = func.resolve + elif callable(func): name = next(iter(inspect.signature(func).parameters.keys())) + resolve = func + else: + raise TypeError( + f"`func` must be a Deferred or Callable, got `{type(func).__name__}`" + ) + parameter = ops.Argument( name=name, shape=self.op().shape, dtype=self.type().value_type ) - if isinstance(func, Deferred): - body = func.resolve(parameter.to_expr()) - else: - body = func(parameter.to_expr()) + body = resolve(parameter.to_expr()) return ops.ArrayMap(self, param=parameter.param, body=body).to_expr() def filter( @@ -545,17 +549,20 @@ def filter( """ if isinstance(predicate, Deferred): name = "_" - else: + resolve = predicate.resolve + elif callable(predicate): name = next(iter(inspect.signature(predicate).parameters.keys())) + resolve = predicate + else: + raise TypeError( + f"`predicate` must be a Deferred or Callable, got `{type(predicate).__name__}`" + ) parameter = ops.Argument( name=name, shape=self.op().shape, dtype=self.type().value_type, ) - if isinstance(predicate, Deferred): - body = predicate.resolve(parameter.to_expr()) - else: - body = predicate(parameter.to_expr()) + body = resolve(parameter.to_expr()) return ops.ArrayFilter(self, param=parameter.param, body=body).to_expr() def contains(self, other: ir.Value) -> ir.BooleanValue: diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index aec143083ff5..140cf4e031ba 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -1533,52 +1533,36 @@ def test_array_length_scalar(): assert isinstance(expr.op(), ops.ArrayLength) -def double_int(x): - return x * 2 - - -def double_float(x): - return x * 2.0 - - -def is_negative(x): - return x < 0 - - def test_array_map(): arr = ibis.array([1, 2, 3]) - result_int = arr.map(double_int) - result_float = arr.map(double_float) + r1 = arr.map(_ * 2) + r2 = arr.map(lambda x: x * 2.0) + r3 = arr.map(functools.partial(lambda a, b: a + b, b=2)) - assert result_int.type() == dt.Array(dt.int16) - assert result_float.type() == dt.Array(dt.float64) + assert r1.type() == dt.Array(dt.int16) + assert r2.type() == dt.Array(dt.float64) + assert r3.type() == dt.Array(dt.int16) - -def test_array_map_partial(): - arr = ibis.array([1, 2, 3]) - - def add(x, y): - return x + y - - result = arr.map(functools.partial(add, y=2)) - assert result.type() == dt.Array(dt.int16) + with pytest.raises(TypeError, match="must be a Deferred or Callable"): + # Non-deferred expressions aren't allowed + arr.map(arr[0]) def test_array_filter(): arr = ibis.array([1, 2, 3]) - result = arr.filter(is_negative) - assert result.type() == arr.type() - -def test_array_filter_partial(): - arr = ibis.array([1, 2, 3]) + r1 = arr.filter(lambda x: x < 0) + r2 = arr.filter(_ < 0) + r3 = arr.filter(functools.partial(lambda a, b: a == b, b=2)) - def equal(x, y): - return x == y + assert r1.type() == arr.type() + assert r2.type() == arr.type() + assert r3.type() == arr.type() - result = arr.filter(functools.partial(equal, y=2)) - assert result.type() == arr.type() + with pytest.raises(TypeError, match="must be a Deferred or Callable"): + # Non-deferred expressions aren't allowed + arr.filter(arr[0]) @pytest.mark.parametrize(