Skip to content

Commit

Permalink
[FEAT] connect: with_columns_renamed
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 11, 2024
1 parent 5238279 commit fa1b9d8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tracing::warn;
use crate::translation::logical_plan::{
aggregate::aggregate, filter::filter, local_relation::local_relation, project::project,
range::range, read::read, to_df::to_df, with_columns::with_columns,
with_columns_renamed::with_columns_renamed,
};

mod aggregate;
Expand All @@ -17,6 +18,7 @@ mod range;
mod read;
mod to_df;
mod with_columns;
mod with_columns_renamed;

pub struct Plan {
pub builder: LogicalPlanBuilder,
Expand Down Expand Up @@ -75,6 +77,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::LocalRelation(l) => {
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
}
RelType::WithColumnsRenamed(w) => with_columns_renamed(*w)
.await

Check warning on line 81 in src/daft-connect/src/translation/logical_plan.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan.rs#L81

Added line #L81 was not covered by tests
.wrap_err("Failed to apply with_columns_renamed to logical plan"),
RelType::Read(r) => read(r)
.await
.wrap_err("Failed to apply read to logical plan"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use daft_dsl::col;
use eyre::{bail, Context};

use crate::translation::Plan;

pub async fn with_columns_renamed(
with_columns_renamed: spark_connect::WithColumnsRenamed,
) -> eyre::Result<Plan> {
let spark_connect::WithColumnsRenamed {
input,
rename_columns_map,
renames,
} = with_columns_renamed;

let Some(input) = input else {
bail!("Input is required");

Check warning on line 16 in src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs#L16

Added line #L16 was not covered by tests
};

let mut plan = Box::pin(crate::translation::to_logical_plan(*input)).await?;

// todo: do we want to implement this directly into daft?

// Convert the rename mappings into expressions
let rename_exprs = if !rename_columns_map.is_empty() {
// Use rename_columns_map if provided (legacy format)
rename_columns_map
.into_iter()
.map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str()))
.collect()
} else {
// Use renames if provided (new format)
renames
.into_iter()
.map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str()))
.collect()

Check warning on line 35 in src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs#L32-L35

Added lines #L32 - L35 were not covered by tests
};

// Apply the rename expressions to the plan
plan.builder = plan
.builder
.select(rename_exprs)
.wrap_err("Failed to apply rename expressions to logical plan")?;

Ok(plan)
}
24 changes: 24 additions & 0 deletions tests/connect/test_with_columns_renamed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations


def test_with_columns_renamed(spark_session):
# Test withColumnRenamed
df = spark_session.range(5)
renamed_df = df.withColumnRenamed("id", "number")

collected = renamed_df.collect()
assert len(collected) == 5
assert "number" in renamed_df.columns
assert "id" not in renamed_df.columns
assert [row["number"] for row in collected] == list(range(5))

# todo: this fails but is this expected or no?
# # Test withColumnsRenamed
# df = spark_session.range(2)
# renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"})

# collected = renamed_df.collect()
# assert len(collected) == 2
# assert set(renamed_df.columns) == {"number", "character"}
# assert "id" not in renamed_df.columns
# assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)]

0 comments on commit fa1b9d8

Please sign in to comment.