Skip to content

Commit

Permalink
feat: pyspark and duckdb Expr.name namespace (#1809)
Browse files Browse the repository at this point in the history
* feat: pyspark name namespace

* duckdb

* use anonymousexprerror

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Marco Edward Gorelli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 17, 2025
1 parent fb1d34e commit 0bcb500
Show file tree
Hide file tree
Showing 11 changed files with 327 additions and 94 deletions.
5 changes: 5 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Sequence

from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
from narwhals._duckdb.expr_name import DuckDBExprNameNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.utils import binary_operation_returns_scalar
from narwhals._duckdb.utils import get_column_name
Expand Down Expand Up @@ -571,3 +572,7 @@ def str(self: Self) -> DuckDBExprStringNamespace:
@property
def dt(self: Self) -> DuckDBExprDateTimeNamespace:
return DuckDBExprDateTimeNamespace(self)

@property
def name(self: Self) -> DuckDBExprNameNamespace:
return DuckDBExprNameNamespace(self)
149 changes: 149 additions & 0 deletions narwhals/_duckdb/expr_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Callable

from narwhals.exceptions import AnonymousExprError

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._duckdb.expr import DuckDBExpr


class DuckDBExprNameNamespace:
def __init__(self: Self, expr: DuckDBExpr) -> None:
self._compliant_expr = expr

def keep(self: Self) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.keep"
raise AnonymousExprError.from_expr_name(msg)
return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), root_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=root_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.map"
raise AnonymousExprError.from_expr_name(msg)

output_names = [function(str(name)) for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.prefix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [prefix + str(name) for name in root_names]
return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.suffix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [str(name) + suffix for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.to_lowercase"
raise AnonymousExprError.from_expr_name(msg)

output_names = [str(name).lower() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> DuckDBExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.to_uppercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).upper() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)
5 changes: 5 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Sequence

from narwhals._expression_parsing import infer_new_root_output_names
from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.utils import get_column_name
from narwhals._spark_like.utils import maybe_evaluate
Expand Down Expand Up @@ -499,3 +500,7 @@ def is_null(self: Self) -> Self:
@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)

@property
def name(self: Self) -> SparkLikeExprNameNamespace:
return SparkLikeExprNameNamespace(self)
149 changes: 149 additions & 0 deletions narwhals/_spark_like/expr_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Callable

from narwhals.exceptions import AnonymousExprError

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprNameNamespace:
def __init__(self: Self, expr: SparkLikeExpr) -> None:
self._compliant_expr = expr

def keep(self: Self) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.keep"
raise AnonymousExprError.from_expr_name(msg)

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), root_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=root_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.map"
raise AnonymousExprError.from_expr_name(msg)

output_names = [function(str(name)) for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.prefix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [prefix + str(name) for name in root_names]
return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.suffix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [str(name) + suffix for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.to_lowercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).lower() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> SparkLikeExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.to_uppercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).upper() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
expr.alias(name)
for expr, name in zip(self._compliant_expr._call(df), output_names)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
if (
any(
x in str(metafunc.module)
for x in ("list", "name", "unpivot", "from_dict", "from_numpy", "tail")
for x in ("list", "unpivot", "from_dict", "from_numpy", "tail")
)
and LAZY_CONSTRUCTORS["duckdb"] in constructors
):
Expand Down
19 changes: 3 additions & 16 deletions tests/expr_and_series/name/keep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,21 @@
data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}


def test_keep(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_keep(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo", "BAR") * 2).name.keep())
expected = {k: [e * 2 for e in v] for k, v in data.items()}
assert_equal_data(result, expected)


def test_keep_after_alias(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_keep_after_alias(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select((nw.col("foo")).alias("alias_for_foo").name.keep())
expected = {"foo": data["foo"]}
assert_equal_data(result, expected)


def test_keep_raise_anonymous(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_keep_raise_anonymous(constructor: Constructor) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw)

Expand Down
Loading

0 comments on commit 0bcb500

Please sign in to comment.