Skip to content

Commit

Permalink
Merge pull request #4 from VorTECHsa/pandas-df-assert
Browse files Browse the repository at this point in the history
refactor: use the pandas testing functions for df and series assertions
  • Loading branch information
maxhipperson authored Feb 28, 2024
2 parents 677e8e2 + 289678e commit e2ebdf0
Showing 1 changed file with 5 additions and 35 deletions.
40 changes: 5 additions & 35 deletions uhura/pandas_tools.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,9 @@
from typing import Any
import pandas as pd
from collections import defaultdict
from uhura.comparison import Comparer


def _safe_hash(df):
hash_sum = 0
for column in df.columns:
try:
hash_sum += pd.util.hash_pandas_object(df[column])
except TypeError:
continue
return hash_sum


def _compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame):
"""Compare two dataframes
Multiple asserts before the "overall" hash comparison to allow us to narrow down
what is wrong
"""
assert df1.shape == df2.shape, f"Dataframe shapes do not match ({df1.shape} vs {df2.shape})"
assert all(df1.columns == df2.columns), "Dataframe columns do not match"
assert all(df1.dtypes == df2.dtypes), "Dataframe dtypes do not match"
assert all(_safe_hash(df1) == _safe_hash(df2)), "Dataframe hashes do not match"
from typing import Any

import pandas as pd

def _compare_series(ser1: pd.Series, ser2: pd.Series):
assert len(ser1) == len(ser2), f"Series lengths do not match ({len(ser1)} vs {len(ser2)})"
assert all(ser1.index == ser2.index), "Series indexes do not match"
assert ser1.dtype == ser2.dtype, "Series dtypes do not match"
try:
assert all(ser1 == ser2), "Series hashes do not match"
except TypeError:
pass
from uhura.comparison import Comparer


def compare_data(data1: Any, data2: Any):
Expand All @@ -51,8 +21,8 @@ def base_compare(self, actual, expected):


COMPARISON_LOOKUP = {
pd.DataFrame: _compare_dataframes,
pd.Series: _compare_series,
pd.DataFrame: pd.testing.assert_frame_equal,
pd.Series: pd.testing.assert_series_equal,
}

pandas_comparator = defaultdict(PandasComparer)

0 comments on commit e2ebdf0

Please sign in to comment.