Skip to content

Commit

Permalink
fix: Column name mapping in missing left/right (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
MariusMerkleQC authored Nov 7, 2024
1 parent 1cb5220 commit e5f2484
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 8 deletions.
1 change: 1 addition & 0 deletions sqlcompyre/analysis/schema_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def table_names(self) -> Names:
return Names(
left=set(self.left_tables.keys()),
right=set(self.right_tables.keys()),
name_mapping=None,
ignore_casing=self.ignore_casing,
)

Expand Down
1 change: 1 addition & 0 deletions sqlcompyre/analysis/table_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def column_names(self) -> Names:
return Names(
left={col.name for col in self.left_table.columns},
right={col.name for col in self.right_table.columns},
name_mapping=self.column_name_mapping,
ignore_casing=self.ignore_casing,
)

Expand Down
27 changes: 24 additions & 3 deletions sqlcompyre/results/names.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
class Names:
"""Investigate the names of database objects."""

def __init__(self, left: set[str], right: set[str], ignore_casing: bool):
def __init__(
self,
left: set[str],
right: set[str],
name_mapping: dict[str, str] | None,
ignore_casing: bool,
):
"""
Args:
left: Names from the "left" database object.
right: Names from the "right" database object.
name_mapping: Mapping from the "left" to the "right" database object.
ignore_casing: Whether to ignore casing for name equality.
"""
if ignore_casing:
Expand All @@ -20,6 +27,10 @@ def __init__(self, left: set[str], right: set[str], ignore_casing: bool):
else:
self._set_left = left
self._set_right = right
self._name_mapping = name_mapping
self._inverse_name_mapping = (
{v: k for k, v in name_mapping.items()} if name_mapping else {}
)

@cached_property
def left(self) -> list[str]:
Expand All @@ -39,12 +50,22 @@ def in_common(self) -> list[str]:
@cached_property
def missing_left(self) -> list[str]:
"""Ordered list of names provided only by the "right" database object."""
return sorted(self._set_right - self._set_left)
if self._name_mapping:
right_renamed = {
self._inverse_name_mapping.get(k, k) for k in self._set_right
}
return sorted(right_renamed - self._set_left)
else:
return sorted(self._set_right - self._set_left)

@cached_property
def missing_right(self) -> list[str]:
"""Ordered list of names provided only by the "left" database object."""
return sorted(self._set_left - self._set_right)
if self._name_mapping:
left_renamed = {self._name_mapping.get(k, k) for k in self._set_left}
return sorted(left_renamed - self._set_right)
else:
return sorted(self._set_left - self._set_right)

@cached_property
def equal(self) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions tests/analysis/table_comparison/test_column_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def test_partly_renaming(
column_name_mapping={"age": "age_v2", "gpa": "gpa_v2"},
)
assert len(comparison.column_names.in_common) == 2
assert len(comparison.column_names.missing_left) == 2
assert len(comparison.column_names.missing_right) == 2
assert len(comparison.column_names.missing_left) == 0
assert len(comparison.column_names.missing_right) == 0
assert comparison.row_counts.diff == 1
# Ensure that all columns are matched, one is primary key, 3 per table left
assert pd.read_sql(comparison.row_matches.joined_equal, con=engine).shape[1] == 7
Expand All @@ -102,8 +102,8 @@ def test_partly_renaming(
column_name_mapping={"age_v2": "age", "gpa_v2": "gpa"},
)
assert len(comparison.column_names.in_common) == 2
assert len(comparison.column_names.missing_left) == 2
assert len(comparison.column_names.missing_right) == 2
assert len(comparison.column_names.missing_left) == 0
assert len(comparison.column_names.missing_right) == 0
assert comparison.row_counts.diff == 1
# Ensure that all columns are matched, one is primary key, 3 per table left
assert pd.read_sql(comparison.row_matches.joined_equal, con=engine).shape[1] == 7
Expand Down
5 changes: 4 additions & 1 deletion tests/report/formatters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def metadata_description() -> Metadata:
@pytest.fixture()
def names() -> Names:
return Names(
left={"hello", "world"}, right={"hello", "hi", "there"}, ignore_casing=False
left={"hello", "world"},
right={"hello", "hi", "there"},
name_mapping=None,
ignore_casing=False,
)


Expand Down

0 comments on commit e5f2484

Please sign in to comment.