Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python): Add cat.to_local() expression #15090

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ impl CategoricalNameSpace {
self.0
.apply_private(CategoricalFunction::GetCategories.into())
}

pub fn to_local(self) -> Expr {
self.0.map_private(CategoricalFunction::ToLocal.into())
}
}
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ use crate::map;
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum CategoricalFunction {
GetCategories,
ToLocal,
}

impl CategoricalFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use CategoricalFunction::*;
match self {
GetCategories => mapper.with_dtype(DataType::String),
ToLocal => mapper.map_global_cat_to_local(),
}
}
}
Expand All @@ -21,6 +23,7 @@ impl Display for CategoricalFunction {
use CategoricalFunction::*;
let s = match self {
GetCategories => "get_categories",
ToLocal => "to_local",
};
write!(f, "cat.{s}")
}
Expand All @@ -31,6 +34,7 @@ impl From<CategoricalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
use CategoricalFunction::*;
match func {
GetCategories => map!(get_categories),
ToLocal => map!(to_local),
}
}
}
Expand All @@ -48,3 +52,7 @@ fn get_categories(s: &Series) -> PolarsResult<Series> {
let arr = rev_map.get_categories().clone().boxed();
Series::try_from((ca.name(), arr))
}

fn to_local(s: &Series) -> PolarsResult<Series> {
Ok(s.categorical()?.to_local().into_series())
}
17 changes: 17 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,23 @@ impl<'a> FieldsMapper<'a> {
self.map_dtype(|dtype| dtype.to_physical())
}

#[cfg(feature = "dtype-categorical")]
/// Map a global categorical to local
pub fn map_global_cat_to_local(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| match dtype {
DataType::Categorical(rev_map, ordering) => DataType::Categorical(
rev_map.clone().map(|rm| match &*rm {
RevMapping::Global(..) => {
Arc::new(RevMapping::build_local(rm.get_categories().clone()))
},
_ => rm,
}),
*ordering,
),
_ => dtype.clone(),
})
}

/// Map a single dtype with a potentially failing mapper function.
pub fn try_map_dtype(
&self,
Expand Down
41 changes: 41 additions & 0 deletions py-polars/polars/expr/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,44 @@ def get_categories(self) -> Expr:
└──────┘
"""
return wrap_expr(self._pyexpr.cat_get_categories())

def to_local(self) -> Expr:
"""
Convert a categorical column to its local representation.

This may change the underlying physical representation of the column.

See the documentation of :func:`StringCache` for more information on the
difference between local and global categoricals.

Examples
--------
Compare the global and local representations of a categorical.

>>> with pl.StringCache():
... _ = pl.Series("x", ["a", "b", "a"], dtype=pl.Categorical)
... df = pl.Series("y", ["c", "b", "d"], dtype=pl.Categorical).to_frame()
>>> df.select(pl.col("y").to_physical())
shape: (3, 1)
┌─────┐
│ y │
│ --- │
│ u32 │
╞═════╡
│ 2 │
│ 1 │
│ 3 │
└─────┘
>>> df.select(pl.col("y").cat.to_local().to_physical())
shape: (3, 1)
┌─────┐
│ y │
│ --- │
│ u32 │
╞═════╡
│ 0 │
│ 1 │
│ 2 │
└─────┘
"""
return wrap_expr(self._pyexpr.cat_to_local())
4 changes: 4 additions & 0 deletions py-polars/src/expr/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@ impl PyExpr {
fn cat_get_categories(&self) -> Self {
self.inner.clone().cat().get_categories().into()
}

fn cat_to_local(&self) -> Self {
self.inner.clone().cat().to_local().into()
}
}
86 changes: 82 additions & 4 deletions py-polars/tests/unit/operations/namespaces/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from collections import OrderedDict

import polars as pl
from polars import StringCache
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -81,7 +84,7 @@ def test_categorical_get_categories() -> None:
).cat.get_categories().to_list() == ["foo", "bar", "ham"]


def test_cat_to_local() -> None:
def test_cat_series_to_local() -> None:
with pl.StringCache():
s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
s2 = pl.Series(["c", "b", "d"], dtype=pl.Categorical)
Expand All @@ -103,7 +106,36 @@ def test_cat_to_local() -> None:
assert s2.to_list() == ["c", "b", "d"]


def test_cat_to_local_missing_values() -> None:
def test_cat_expr_to_local() -> None:
with pl.StringCache():
df = pl.DataFrame(
{
"s1": pl.Series(["a", "b", "a"], dtype=pl.Categorical),
"s2": pl.Series(["c", "b", "d"], dtype=pl.Categorical),
}
)

# s2 physical starts after s1
assert df["s1"].to_physical().to_list() == [0, 1, 0]
assert df["s2"].to_physical().to_list() == [2, 1, 3]

out = df.select(pl.col("s1", "s2").cat.to_local())

# Physical has changed and now starts at 0, string values are the same
assert out["s1"].cat.is_local()
assert out["s2"].cat.is_local()
assert out["s1"].to_physical().to_list() == [0, 1, 0]
assert out["s2"].to_physical().to_list() == [0, 1, 2]
assert out["s1"].to_list() == ["a", "b", "a"]
assert out["s2"].to_list() == ["c", "b", "d"]

# s2 should be unchanged after the operation
assert not df["s2"].cat.is_local()
assert df["s2"].to_physical().to_list() == [2, 1, 3]
assert df["s2"].to_list() == ["c", "b", "d"]


def test_cat_series_to_local_missing_values() -> None:
with pl.StringCache():
_ = pl.Series(["a", "b"], dtype=pl.Categorical)
s = pl.Series(["c", "b", None, "d"], dtype=pl.Categorical)
Expand All @@ -112,7 +144,20 @@ def test_cat_to_local_missing_values() -> None:
assert out.to_physical().to_list() == [0, 1, None, 2]


def test_cat_to_local_already_local() -> None:
def test_cat_expr_to_local_missing_values() -> None:
with pl.StringCache():
_ = pl.Series(["a", "b"], dtype=pl.Categorical)
df = pl.DataFrame(
{
"s": pl.Series(["c", "b", None, "d"], dtype=pl.Categorical),
}
)

out = df.select(pl.col("s").cat.to_local())
assert out["s"].to_physical().to_list() == [0, 1, None, 2]


def test_cat_series_to_local_already_local() -> None:
s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical)

assert s.cat.is_local()
Expand All @@ -122,7 +167,29 @@ def test_cat_to_local_already_local() -> None:
assert out.to_list() == ["a", "c", "a", "b"]


def test_cat_is_local() -> None:
def test_cat_expr_to_local_already_local() -> None:
df = pl.DataFrame({"s": pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical)})

assert df["s"].cat.is_local()
out = df.select(pl.col("s").cat.to_local())

assert out["s"].to_physical().to_list() == [0, 1, 0, 2]
assert out["s"].to_list() == ["a", "c", "a", "b"]


@StringCache()
def test_cat_global_to_local_schema() -> None:
_ = pl.Series(["a", "b", "c"], dtype=pl.Categorical)
schema = (
pl.LazyFrame({"s": pl.Series(["c", "b", "d"], dtype=pl.Categorical)})
.select(pl.col("s").cat.to_local())
.collect_schema()
)

assert schema == OrderedDict([("s", pl.Categorical(ordering="physical"))])


def test_cat_series_is_local() -> None:
s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical)
assert s.cat.is_local()

Expand All @@ -131,6 +198,17 @@ def test_cat_is_local() -> None:
assert not s2.cat.is_local()


def test_cat_expr_is_local() -> None:
df = pl.DataFrame({"s": pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical)})
assert df["s"].cat.is_local()

with pl.StringCache():
df = df.with_columns(
pl.Series(["a", "b", "a", "c"], dtype=pl.Categorical).alias("s2")
)
assert not df["s2"].cat.is_local()


def test_cat_uses_lexical_ordering() -> None:
s = pl.Series(["a", "b", None, "b"]).cast(pl.Categorical)
assert s.cat.uses_lexical_ordering() is False
Expand Down
Loading