From 289678e0974f3b90091889fe74d3564d6cd46cc6 Mon Sep 17 00:00:00 2001 From: Max Hipperson <40496579+maxhipperson@users.noreply.github.com> Date: Wed, 28 Feb 2024 15:28:13 +0000 Subject: [PATCH] refactor: use the pandas testing functions for df and series assertions --- uhura/pandas_tools.py | 40 +++++----------------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/uhura/pandas_tools.py b/uhura/pandas_tools.py index 224acc0..8711198 100644 --- a/uhura/pandas_tools.py +++ b/uhura/pandas_tools.py @@ -1,39 +1,9 @@ -from typing import Any -import pandas as pd from collections import defaultdict -from uhura.comparison import Comparer - - -def _safe_hash(df): - hash_sum = 0 - for column in df.columns: - try: - hash_sum += pd.util.hash_pandas_object(df[column]) - except TypeError: - continue - return hash_sum - - -def _compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame): - """Compare two dataframes - - Multiple asserts before the "overall" hash comparison to allow us to narrow down - what is wrong - """ - assert df1.shape == df2.shape, f"Dataframe shapes do not match ({df1.shape} vs {df2.shape})" - assert all(df1.columns == df2.columns), "Dataframe columns do not match" - assert all(df1.dtypes == df2.dtypes), "Dataframe dtypes do not match" - assert all(_safe_hash(df1) == _safe_hash(df2)), "Dataframe hashes do not match" +from typing import Any +import pandas as pd -def _compare_series(ser1: pd.Series, ser2: pd.Series): - assert len(ser1) == len(ser2), f"Series lengths do not match ({len(ser1)} vs {len(ser2)})" - assert all(ser1.index == ser2.index), "Series indexes do not match" - assert ser1.dtype == ser2.dtype, "Series dtypes do not match" - try: - assert all(ser1 == ser2), "Series hashes do not match" - except TypeError: - pass +from uhura.comparison import Comparer def compare_data(data1: Any, data2: Any): @@ -51,8 +21,8 @@ def base_compare(self, actual, expected): COMPARISON_LOOKUP = { - pd.DataFrame: _compare_dataframes, - pd.Series: _compare_series, + pd.DataFrame: pd.testing.assert_frame_equal, + pd.Series: pd.testing.assert_series_equal, } pandas_comparator = defaultdict(PandasComparer)