Skip to content

Commit

Permalink
feat(python): Allow insert_column to take expressions
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexander-beedie committed Sep 30, 2024
1 parent ab5200d commit b72bd07
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
18 changes: 15 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4525,7 +4525,7 @@ def rename(
"""
return self.lazy().rename(mapping, strict=strict).collect(_eager=True)

def insert_column(self, index: int, column: Series) -> DataFrame:
def insert_column(self, index: int, column: IntoExprColumn) -> DataFrame:
"""
Insert a Series (or expression) at a certain column index.
Expand All @@ -4536,7 +4536,7 @@ def insert_column(self, index: int, column: Series) -> DataFrame:
index
Index at which to insert the new column.
column
`Series` to insert.
`Series` or expression to insert.
Examples
--------
Expand Down Expand Up @@ -4577,7 +4577,19 @@ def insert_column(self, index: int, column: Series) -> DataFrame:
"""
if index < 0:
index = len(self.columns) + index
self._df.insert_column(index, column._s)

if isinstance(column, pl.Series):
self._df.insert_column(index, column._s)
else:
if isinstance(column, str):
column = F.col(column)
if isinstance(column, pl.Expr):
cols = self.columns
cols.insert(index, column) # type: ignore[arg-type]
self._df = self.select(cols)._df
else:
msg = f"column must be a Series or Expr, got {column!r} (type={type(column)})"
raise TypeError(msg)
return self

def filter(
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def test_assignment() -> None:


def test_insert_column() -> None:
# insert series
df = (
pl.DataFrame({"z": [3, 4, 5]})
.insert_column(0, pl.Series("x", [1, 2, 3]))
Expand All @@ -466,6 +467,28 @@ def test_insert_column() -> None:
expected_df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]})
assert_frame_equal(expected_df, df)

# insert expressions
df = pl.DataFrame(
{
"id": ["xx", "yy", "zz"],
"v1": [5, 4, 6],
"v2": [7, 3, 3],
}
)
df.insert_column(3, (pl.col("v1") * pl.col("v2")).alias("v3"))
df.insert_column(1, (pl.col("v2") - pl.col("v1")).alias("v0"))

expected = pl.DataFrame(
{
"id": ["xx", "yy", "zz"],
"v0": [2, -1, -3],
"v1": [5, 4, 6],
"v2": [7, 3, 3],
"v3": [35, 12, 18],
}
)
assert_frame_equal(df, expected)


def test_replace_column() -> None:
df = (
Expand Down

0 comments on commit b72bd07

Please sign in to comment.