PR: Update function for Polars-Rust #21055

Closed
pbower opened this issue Feb 3, 2025 · 2 comments
Labels
enhancement New feature or an improvement of an existing feature

Comments

@pbower

pbower commented Feb 3, 2025

Description

Hi team,

To my knowledge, the Rust version of Polars lacks an update function.

I wrote one, with unit tests, for my own codebase, and am sharing it here in case you would like to use it or suggest refactoring ahead of a PR. If for whatever reason you don't want it, that's absolutely fine.

Here's the function, and unit tests, that all pass.

Note that currently:

  1. The function takes two frames as arguments, whereas the Python version is a method called as df1.update(df2).
  2. The explicit sort may be at odds with the Rust library's semantics (e.g. on full joins) and could be made optional.
  3. It operates on LazyFrame only and doesn't yet include a DataFrame wrapper.
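
As a rough sketch of point 2, the sort could be switched off through a small options struct. Everything here is hypothetical: `UpdateOptions`, `sort_by_keys` and `finish` are illustrative names that exist neither in Polars nor in the function below, and plain `(key, value)` tuples stand in for a LazyFrame so the snippet compiles without Polars.

```rust
/// Hypothetical options for an `update` call; none of these names exist in Polars.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UpdateOptions {
    /// Let nulls from the right frame overwrite values in the left frame.
    pub include_nulls: bool,
    /// Sort the result by the join key(s) after the join.
    pub sort_by_keys: bool,
}

impl Default for UpdateOptions {
    fn default() -> Self {
        // Defaults mirror the function as posted: nulls do not overwrite,
        // and the result is always sorted by the join key(s).
        Self {
            include_nulls: false,
            sort_by_keys: true,
        }
    }
}

/// Stand-in for the final sort step, applied to plain (key, value) rows
/// instead of a LazyFrame (an assumption made to keep the sketch standalone).
pub fn finish(mut rows: Vec<(i32, i32)>, opts: UpdateOptions) -> Vec<(i32, i32)> {
    if opts.sort_by_keys {
        // `sort_by_key` is a stable sort, so rows with equal keys keep their order.
        rows.sort_by_key(|&(key, _)| key);
    }
    rows
}
```

With `UpdateOptions::default()` the rows come back in ascending key order; with `UpdateOptions { sort_by_keys: false, ..Default::default() }` they are returned in whatever order the join produced them.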

Anyway, I'm sure it's at least a starting point.

Cheers.
PB

use std::collections::HashSet;

use polars::prelude::*;

/// Updates `left` with values from `right`, matching rows based on:
///  1. Explicit join columns (`left_on` / `right_on`), or
///  2. A user-supplied `on` list of columns shared by both frames.
///
/// Matching on a synthetic row index when no keys are given (as the Python
/// `update` does) is not implemented yet; the function panics in that case.
///
/// The `how` parameter determines which rows are kept:
///   - `"left"`: keep all rows from the left frame
///   - `"inner"`: keep only rows that exist in both frames
///   - `"full"`: keep all rows from both frames (full outer join)
///
/// By default, a null value in the right frame does **not** overwrite the
/// corresponding value in the left frame. Set `include_nulls = true` to allow
/// null values in `right` to overwrite values in `left`.
///
/// # Parameters
/// - `left`: The left (original) `LazyFrame`.
/// - `right`: The right `LazyFrame`, whose values will update the left.
/// - `on`: Optional list of column(s) to join on (both frames share the same column names).
///   If given, it takes precedence over `left_on` / `right_on`.
/// - `how`: One of `"left"`, `"inner"`, or `"full"`.
/// - `left_on`: Join columns on the left frame (used if `on` is `None`).
/// - `right_on`: Join columns on the right frame (used if `on` is `None`).
/// - `include_nulls`: Whether to allow nulls from `right` to overwrite values in `left`.
///
/// # Returns
/// A `LazyFrame` representing the merged/updated data.
///
/// # Panics
/// Panics if `how` is not one of the three accepted values, or if no join keys are provided.
///
/// # Notes
/// Polars does not guarantee row ordering on a full join. The tests provided
/// assume that rows end up in ascending order of the join key, so the result is
/// explicitly sorted by the left join key(s) after the join.
pub fn update(
    mut left: LazyFrame,
    mut right: LazyFrame,
    on: Option<Vec<String>>,
    how: &str,
    mut left_on: Option<Vec<String>>,
    mut right_on: Option<Vec<String>>,
    include_nulls: bool,
) -> LazyFrame {
    // 1. Validate input parameters
    if how != "left" && how != "inner" && how != "full" {
        panic!("Invalid 'how' parameter. Expected one of 'left', 'inner', or 'full'.");
    }

    // 2. If `on` is specified, apply it to both left_on and right_on.
    if let Some(on_cols) = on {
        left_on = Some(on_cols.clone());
        right_on = Some(on_cols);
    }

    // 3. Require join keys on both sides; row-index matching is not implemented.
    let (left_on, right_on) = match (left_on, right_on) {
        (Some(l), Some(r)) => (l, r),
        _ => panic!("Both `left_on` and `right_on`, or `on` must be provided for the join."),
    };

    // 4. Identify which columns in `right` are available for updating.
    let left_schema = left.collect_schema().unwrap();
    let right_schema = right.collect_schema().unwrap();

    let left_col_names: HashSet<String> =
        left_schema.iter_names().map(|s| s.to_string()).collect();
    let right_col_names: Vec<String> =
        right_schema.iter_names().map(|s| s.to_string()).collect();

    // 5. For full join, we allow join keys to be updated.
    let join_col_set: HashSet<String> = if how == "full" {
        HashSet::new()
    } else {
        right_on.iter().cloned().collect()
    };

    // Determine which columns are updatable (shared but not join columns)
    let updatable_cols: Vec<String> = right_col_names
        .iter()
        .filter(|c| left_col_names.contains(*c))
        .filter(|c| !join_col_set.contains(*c))
        .cloned()
        .collect();

    // 6. If no updatable columns (and not "full"), just return the left frame
    if how != "full" && updatable_cols.is_empty() {
        return left
    }

    // 7. Optionally add a validity marker to track matched rows
    let validity_marker = "__POLARS_VALIDITY";
    if include_nulls {
        // Add a column of `true`; after the join, rows where it is null had no match in `right`.
        right = right.with_columns([lit(true).alias(validity_marker)]);
    }

    // 8. Select only necessary columns for the join
    let mut needed_right_cols: HashSet<String> = right_on.iter().cloned().collect();
    for c in &updatable_cols {
        needed_right_cols.insert(c.clone());
    }
    if include_nulls {
        needed_right_cols.insert(validity_marker.to_string());
    }
    let right_select_exprs: Vec<Expr> = needed_right_cols
        .iter()
        .map(|colname| col(colname))
        .collect();
    right = right.select(right_select_exprs);

    // 9. Perform the join, adding suffix to right-side columns
    let suffix = "_right";
    let mut joined = left
        .join_builder()
        .with(right)
        .left_on(left_on.iter().map(|c| col(c)).collect::<Vec<_>>())
        .right_on(right_on.iter().map(|c| col(c)).collect::<Vec<_>>())
        .how(match how {
            "left" => JoinType::Left,
            "inner" => JoinType::Inner,
            "full" => JoinType::Full,
            _ => unreachable!(),
        })
        .suffix(suffix)
        .finish();

    // 10. Construct coalesce expressions for updating the columns
    let mut coalesce_exprs = Vec::with_capacity(updatable_cols.len());
    for c in &updatable_cols {
        let right_name = format!("{}{}", c, suffix);
        if include_nulls {
            coalesce_exprs.push(
                when(col(validity_marker).is_not_null())
                    .then(col(&right_name))
                    .otherwise(col(c))
                    .alias(c),
            );
        } else {
            coalesce_exprs.push(
                coalesce(&[col(&right_name), col(c)]).alias(c),
            );
        }
    }

    // 11. Drop the extra columns (suffix and validity marker if needed)
    let joined_schema = joined.collect_schema().unwrap();
    let suffix_exprs: Vec<Expr> = joined_schema
        .iter_names()
        .filter(|name| name.ends_with(suffix))
        .map(|n| col(n.as_str()))
        .collect();

    let mut drop_cols = vec![];

    if include_nulls {
        drop_cols.push(validity_marker.to_string());
    }
    let drop_exprs: Vec<Expr> = drop_cols.into_iter().map(|name| col(name)).collect();

    let mut final_lf = joined
        .with_columns(coalesce_exprs)
        .drop(suffix_exprs)
        .drop(drop_exprs);

    // 12. Sort by the join key(s) so row order is deterministic (a full join does not guarantee it)
    if !left_on.is_empty() {
        let sort_exprs: Vec<Expr> = left_on.iter().map(|c| col(c)).collect();
        let sort_options = SortMultipleOptions {
            descending: vec![false; left_on.len()],
            nulls_last: vec![true; left_on.len()], // push null keys to the end
            multithreaded: true,
            maintain_order: true,
            limit: None,
        };
        final_lf = final_lf.sort_by_exprs(sort_exprs, sort_options);
    }

    final_lf
}
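
Step 1 above validates `how` as a string and panics on anything unexpected. One possible alternative, sketched here under the assumption that a typed argument would be acceptable, is to parse the string into an enum and return an error instead. `UpdateHow` and `InvalidHow` are hypothetical names, not part of Polars or of the function above.

```rust
use std::fmt;
use std::str::FromStr;

/// Hypothetical typed replacement for the `how: &str` argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateHow {
    Left,
    Inner,
    Full,
}

/// Error returned when a string is not one of "left", "inner" or "full".
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidHow(pub String);

impl fmt::Display for InvalidHow {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "invalid `how` value {:?}; expected \"left\", \"inner\" or \"full\"",
            self.0
        )
    }
}

impl std::error::Error for InvalidHow {}

impl FromStr for UpdateHow {
    type Err = InvalidHow;

    /// Accepts exactly the three strings the posted function accepts.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "left" => Ok(UpdateHow::Left),
            "inner" => Ok(UpdateHow::Inner),
            "full" => Ok(UpdateHow::Full),
            other => Err(InvalidHow(other.to_string())),
        }
    }
}
```

A caller could then write `"full".parse::<UpdateHow>()` and handle the `Err` case, and the `match` in step 9 would no longer need an `unreachable!()` arm.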



    #[test]
    fn test_update_left() {
        let left = df!(
            "A" => &[1, 2, 3, 4],
            "B" => &[400, 500, 600, 700]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[1, 2, 3], // Existing rows
            "B" => &[-66, -99, -88] // Values to update
        ).unwrap().lazy();

        let updated = update(left, right, None, "left", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), false)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3, 4],
            "B" => &[-66, -99, -88, 700] // Updated, but 4 remains untouched
        )
        .unwrap();

        assert!(updated.equals(&expected));
    }

    #[test]
    fn test_update_inner() {
        let left = df!(
            "A" => &[1, 2, 3, 4],
            "B" => &[400, 500, 600, 700]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[1, 2, 3], // Only matches first three
            "B" => &[-66, -99, -88] // Values to update
        ).unwrap().lazy();

        let updated = update(left, right, None, "inner", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), false)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3], // Only common rows remain
            "B" => &[-66, -99, -88] // Fully updated
        )
        .unwrap();

        assert!(updated.equals(&expected));
    }

    #[test]
    fn test_update_full() {
        let left = df!(
            "A" => &[1, 2, 3],
            "B" => &[400, 500, 600]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[2, 3, 4], // New row 4 exists here
            "B" => &[-99, -88, 999] // Updates + new row
        ).unwrap().lazy();

        let updated = update(left, right, None, "full", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), false)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3, 4], // Full join keeps all rows
            "B" => &[400, -99, -88, 999] // Updated where possible, new row added
        )
        .unwrap();

        assert!(updated.equals(&expected));
    }

    #[test]
    fn test_update_include_nulls() {
        let left = df!(
            "A" => &[1, 2, 3],
            "B" => &[400, 500, 600]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[1, 2, 3],
            "B" => &[None, Some(-99), None] // Some None values
        ).unwrap().lazy();

        let updated = update(left, right, None, "left", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), true)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3],
            "B" => &[None, Some(-99), None] // Left joined, now allows None overwrites
        )
        .unwrap();

        assert!(updated.equals_missing(&expected));
    }

    #[test]
    fn test_update_no_match() {
        let left = df!(
            "A" => &[1, 2, 3],
            "B" => &[400, 500, 600]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[10, 11, 12], // No overlap with left
            "B" => &[-1, -2, -3]
        ).unwrap().lazy();

        let updated = update(left, right, None, "left", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), false)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3], // No rows were matched or updated
            "B" => &[400, 500, 600]
        )
        .unwrap();

        assert!(updated.equals(&expected));
    }

    #[test]
    fn test_update_partial_update() {
        let left = df!(
            "A" => &[1, 2, 3, 4, 5],
            "B" => &[10, 20, 30, 40, 50],
            "C" => &[100, 200, 300, 400, 500]
        ).unwrap().lazy();

        let right = df!(
            "A" => &[2, 3, 5], // Only updates these rows
            "B" => &[99, 88, 77] // Update B column
        ).unwrap().lazy();

        let updated = update(left, right, None, "left", Some(vec!["A".to_string()]), Some(vec!["A".to_string()]), false)
            .collect()
            .unwrap();

        let expected = df!(
            "A" => &[1, 2, 3, 4, 5],
            "B" => &[10, 99, 88, 40, 77], // Updated only in the provided rows
            "C" => &[100, 200, 300, 400, 500] // Unchanged
        )
        .unwrap();

        assert!(updated.equals(&expected));
    }

    // Helper function to create DataFrames
    fn create_df_1() -> LazyFrame {
        let df = df![
            "id" => &[1, 2, 3],
            "value" => &["a", "b", "c"]
        ]
        .unwrap();
        df.lazy()
    }

    fn create_df_2() -> LazyFrame {
        let df = df![
            "id" => &[1, 2, 4],
            "value" => &["x", "y", "z"]
        ]
        .unwrap();
        df.lazy()
    }

    // Test inner join update
    #[test]
    fn test_update_inner_join() {
        let left = create_df_1();
        let right = create_df_2();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string()]),
            "inner",
            None,
            None,
            false,
        );

        let result = updated.collect().unwrap();
        let expected = df![
            "id" => &[1, 2],
            "value" => &["x", "y"]
        ]
        .unwrap();

        assert_eq!(result, expected);
    }

    // Test left join update
    #[test]
    fn test_update_left_join() {
        let left = create_df_1();
        let right = create_df_2();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string()]),
            "left",
            None,
            None,
            false,
        );

        let result = updated.collect().unwrap();
        let expected = df![
            "id" => &[1, 2, 3],
            "value" => &["x", "y", "c"]
        ]
        .unwrap();

        assert_eq!(result, expected);
    }

    // Test full outer join update
    #[test]
    fn test_update_full_join() {
        let left = create_df_1();
        let right = create_df_2();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string()]),
            "full",
            None,
            None,
            false,
        );

        let result = updated.collect().unwrap();
        let expected = df![
            "id" => &[1, 2, 3, 4],
            "value" => &["x", "y", "c", "z"]
        ]
        .unwrap();

        assert_eq!(result, expected);
    }

    // Test empty DataFrames case
    #[test]
    fn test_update_with_empty_dataframes() {
        let left = df![
            "id" => &[] as &[i32],
            "value" => &[] as &[i32],
        ]
        .unwrap()
        .lazy();

        let right = df![
            "id" => &[] as &[i32],
            "value" => &[] as &[i32],
        ]
        .unwrap()
        .lazy();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string()]),
            "left",
            None,
            None,
            false,
        );

        let result = updated.collect().unwrap();

        let expected = df![
            "id" => &[] as &[i32],
            "value" => &[] as &[i32],
        ]
        .unwrap();

        assert_eq!(result, expected);
    }


    // Test update with composite keys
    #[test]
    fn test_update_with_composite_keys() {
        let df_left = df![
            "id" => &[1, 2, 3],
            "group" => &["A", "B", "A"],
            "value" => &["a", "b", "c"]
        ]
        .unwrap();

        let df_right = df![
            "id" => &[1, 2, 4],
            "group" => &["A", "B", "C"],
            "value" => &["x", "y", "z"]
        ]
        .unwrap();

        let left = df_left.lazy();
        let right = df_right.lazy();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string(), "group".to_string()]),
            "left",
            None,
            None,
            false,
        );

        let result = updated.collect().unwrap();
        let expected = df![
            "id" => &[1, 2, 3],
            "group" => &["A", "B", "A"],
            "value" => &["x", "y", "c"]
        ]
        .unwrap();

        assert_eq!(result, expected);
    }

    // Test when `include_nulls` is true and null overwrites occur
    #[test]
    fn test_update_with_include_nulls() {
        let df_left = df![
            "id" => &[1, 2],
            "value" => &["a", "b"]
        ]
        .unwrap();

        let df_right = df![
            "id" => &[1, 2],
            "value" => &[None, Some("y".to_string())]
        ]
        .unwrap();

        let left = df_left.lazy();
        let right = df_right.lazy();

        let updated = update(
            left,
            right,
            Some(vec!["id".to_string()]),
            "left",
            None,
            None,
            true, // Include nulls
        );

        let result = updated.collect().unwrap();
        let expected = df![
            "id" => &[1, 2],
            "value" => &[None, Some("y")]
        ]
        .unwrap();

        assert_eq!(result, expected);
    }
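
For reference, the semantics the tests above check can be modelled in plain Rust without Polars. This is only a sketch under simplifying assumptions: one `i32` key column, one nullable `i32` value column, unique keys on each side, and output in ascending key order. `How` and `update_model` are hypothetical names used for illustration.

```rust
use std::collections::BTreeMap;

/// Join strategy, mirroring the three accepted `how` strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum How {
    Left,
    Inner,
    Full,
}

/// Plain-Rust model of `update` for one key column and one value column.
/// Assumes keys are unique on each side; returns rows in ascending key order.
pub fn update_model(
    left: &[(i32, Option<i32>)],
    right: &[(i32, Option<i32>)],
    how: How,
    include_nulls: bool,
) -> Vec<(i32, Option<i32>)> {
    let l: BTreeMap<i32, Option<i32>> = left.iter().copied().collect();
    let r: BTreeMap<i32, Option<i32>> = right.iter().copied().collect();

    let mut out = Vec::new();
    for (&key, &old) in &l {
        match r.get(&key) {
            // Matched row: the right value wins, unless it is null and
            // `include_nulls` is false, in which case the left value is kept.
            Some(&new) => {
                let value = if include_nulls { new } else { new.or(old) };
                out.push((key, value));
            }
            // Unmatched left row: kept for left and full joins, dropped for inner.
            None if how != How::Inner => out.push((key, old)),
            None => {}
        }
    }
    if how == How::Full {
        // Rows that exist only on the right are appended for a full join.
        for (&key, &new) in &r {
            if !l.contains_key(&key) {
                out.push((key, new));
            }
        }
        out.sort_by_key(|&(key, _)| key);
    }
    out
}
```

Feeding it the data from `test_update_full` (left keys 1 to 3, right keys 2 to 4) yields four rows, with key 1 unchanged, keys 2 and 3 updated, and key 4 added.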
@pbower pbower added the enhancement New feature or an improvement of an existing feature label Feb 3, 2025
@Bidek56
Contributor

Bidek56 commented Feb 9, 2025

Dude, why not put a PR with this code? :)

@pbower
Author

pbower commented Feb 9, 2025

Sure, ok I’ll do that.

@pbower pbower closed this as completed Feb 9, 2025