diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml
index a8390f0e..c7899383 100644
--- a/.github/workflows/test-package.yml
+++ b/.github/workflows/test-package.yml
@@ -19,11 +19,10 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, '3.10', '3.11']
- spark-version: [3.1.3, 3.2.4, 3.3.4, 3.4.2, 3.5.0]
+ python-version: [3.9, '3.10', '3.11']
+ spark-version: [3.2.4, 3.3.4, 3.4.2, 3.5.1]
+ pandas-version: [2.2.2, 1.5.3]
exclude:
- - python-version: '3.11'
- spark-version: 3.1.3
- python-version: '3.11'
spark-version: 3.2.4
- python-version: '3.11'
@@ -51,6 +50,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytest pytest-spark pypandoc
python -m pip install pyspark==${{ matrix.spark-version }}
+ python -m pip install pandas==${{ matrix.pandas-version }}
python -m pip install .[dev]
- name: Test with pytest
run: |
@@ -62,7 +62,8 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, '3.10', '3.11']
+ python-version: [3.9, '3.10', '3.11']
+
env:
PYTHON_VERSION: ${{ matrix.python-version }}
@@ -88,7 +89,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8, 3.9, '3.10', '3.11']
+ python-version: [3.9, '3.10', '3.11']
env:
PYTHON_VERSION: ${{ matrix.python-version }}
diff --git a/CONTRIBUTORS b/CONTRIBUTORS
index 185f3b4f..e59e6454 100644
--- a/CONTRIBUTORS
+++ b/CONTRIBUTORS
@@ -3,4 +3,5 @@
- Usman Azhar
- Mark Zhou
- Ian Whitestone
-- Faisal Dosani
\ No newline at end of file
+- Faisal Dosani
+- Lorenzo Mercado
\ No newline at end of file
diff --git a/README.md b/README.md
index dc518c94..b9abfee5 100644
--- a/README.md
+++ b/README.md
@@ -38,16 +38,44 @@ pip install datacompy[ray]
```
-### In-scope Spark versions
-Different versions of Spark play nicely with only certain versions of Python below is a matrix of what we test with
+### Legacy Spark Deprecation
+
+#### Starting with version 0.12.0
+
+The original ``SparkCompare`` implementation differs from all the other native implementations. To better align the API and keep behaviour consistent, we are deprecating ``SparkCompare`` and moving it into a new module, ``LegacySparkCompare``.
+
+If you wish to use the old ``SparkCompare`` moving forward, you can import it from the new module:
+
+```python
+from datacompy.legacy import LegacySparkCompare
+```
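+
+For example, a minimal sketch (illustrative only; assumes an active ``SparkSession``):
+
+```python
+from pyspark.sql import SparkSession
+from datacompy.legacy import LegacySparkCompare
+
+spark = SparkSession.builder.getOrCreate()
+base_df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
+compare_df = spark.createDataFrame([(1, "a"), (2, "c")], ["id", "value"])
+
+comparison = LegacySparkCompare(spark, base_df, compare_df, join_columns=["id"])
+comparison.report()  # prints the comparison report to stdout
+```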
+
+#### Supported versions and dependencies
+
+Different versions of Spark, Pandas, and Python interact differently. Below is a matrix of what we test with.
+With the move to the Pandas on Spark API, and due to compatibility issues with Pandas 2+, we will not support Pandas 2
+with the Pandas on Spark implementation for the time being. Spark plans to support Pandas 2 in [Spark 4](https://issues.apache.org/jira/browse/SPARK-44101).
+
+With version ``0.12.0``:
+- Pandas ``2.0.0`` and above are not supported for the native Spark implementation
+- Spark ``3.1`` support is dropped
+- Python ``3.8`` support is dropped
+
+
+| | Spark 3.2.4 | Spark 3.3.4 | Spark 3.4.2 | Spark 3.5.1 |
+|-------------|-------------|-------------|-------------|-------------|
+| Python 3.9 | ✅ | ✅ | ✅ | ✅ |
+| Python 3.10 | ✅ | ✅ | ✅ | ✅ |
+| Python 3.11 | ❌ | ❌ | ✅ | ✅ |
+| Python 3.12 | ❌ | ❌ | ❌ | ❌ |
+
+
+| | Pandas <= 1.5.3 | Pandas >= 2.0.0 |
+|---------------|----------------|----------------|
+| Native Pandas | ✅ | ✅ |
+| Native Spark | ✅ | ❌ |
+| Fugue | ✅ | ✅ |
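+
+For example, to set up an environment matching one of the tested combinations above (illustrative version pins):
+
+```
+pip install "datacompy==0.12.0" "pyspark==3.5.1" "pandas==1.5.3"
+```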
-| | Spark 3.1.3 | Spark 3.2.3 | Spark 3.3.4 | Spark 3.4.2 | Spark 3.5.0 |
-|-------------|--------------|-------------|-------------|-------------|-------------|
-| Python 3.8 | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Python 3.9 | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Python 3.10 | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Python 3.11 | ❌ | ❌ | ❌ | ✅ | ✅ |
-| Python 3.12 | ❌ | ❌ | ❌ | ❌ | ❌ |
> [!NOTE]
@@ -56,7 +84,7 @@ Different versions of Spark play nicely with only certain versions of Python bel
## Supported backends
- Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html))
-- Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html))
+- Spark (Pandas on Spark API): ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html))
- Polars (Experimental): ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html))
- Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow,
Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data
diff --git a/datacompy/__init__.py b/datacompy/__init__.py
index 6b1aab24..b43027ae 100644
--- a/datacompy/__init__.py
+++ b/datacompy/__init__.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.11.3"
+__version__ = "0.12.0"
from datacompy.core import *
from datacompy.fugue import (
all_columns_match,
all_rows_overlap,
+ count_matching_rows,
intersect_columns,
is_match,
report,
unq_columns,
)
from datacompy.polars import PolarsCompare
-from datacompy.spark import NUMERIC_SPARK_TYPES, SparkCompare
+from datacompy.spark import SparkCompare
diff --git a/datacompy/core.py b/datacompy/core.py
index a1730768..042dffb4 100644
--- a/datacompy/core.py
+++ b/datacompy/core.py
@@ -20,6 +20,7 @@
PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
two dataframes.
"""
+
import logging
import os
from typing import Any, Dict, List, Optional, Union, cast
@@ -283,7 +284,11 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
self.df2[column] = self.df2[column].str.strip()
outer_join = self.df1.merge(
- self.df2, how="outer", suffixes=("_df1", "_df2"), indicator=True, **params
+ self.df2,
+ how="outer",
+ suffixes=("_" + self.df1_name, "_" + self.df2_name),
+ indicator=True,
+ **params,
)
# Clean up temp columns for duplicate row matching
if self._any_dupes:
@@ -295,8 +300,8 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
self.df1.drop(order_column, axis=1, inplace=True)
self.df2.drop(order_column, axis=1, inplace=True)
- df1_cols = get_merged_columns(self.df1, outer_join, "_df1")
- df2_cols = get_merged_columns(self.df2, outer_join, "_df2")
+ df1_cols = get_merged_columns(self.df1, outer_join, self.df1_name)
+ df2_cols = get_merged_columns(self.df2, outer_join, self.df2_name)
LOG.debug("Selecting df1 unique rows")
self.df1_unq_rows = outer_join[outer_join["_merge"] == "left_only"][
@@ -334,8 +339,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
max_diff = 0.0
null_diff = 0
else:
- col_1 = column + "_df1"
- col_2 = column + "_df2"
+ col_1 = column + "_" + self.df1_name
+ col_2 = column + "_" + self.df2_name
col_match = column + "_match"
self.intersect_rows[col_match] = columns_equal(
self.intersect_rows[col_1],
@@ -484,7 +489,10 @@ def sample_mismatch(
match_cnt = col_match.sum()
sample_count = min(sample_count, row_cnt - match_cnt)
sample = self.intersect_rows[~col_match].sample(sample_count)
- return_cols = self.join_columns + [column + "_df1", column + "_df2"]
+ return_cols = self.join_columns + [
+ column + "_" + self.df1_name,
+ column + "_" + self.df2_name,
+ ]
to_return = sample[return_cols]
if for_display:
to_return.columns = pd.Index(
@@ -517,8 +525,8 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
orig_col_name = col[:-6]
col_comparison = columns_equal(
- self.intersect_rows[orig_col_name + "_df1"],
- self.intersect_rows[orig_col_name + "_df2"],
+ self.intersect_rows[orig_col_name + "_" + self.df1_name],
+ self.intersect_rows[orig_col_name + "_" + self.df2_name],
self.rel_tol,
self.abs_tol,
self.ignore_spaces,
@@ -530,7 +538,12 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
):
LOG.debug(f"Adding column {orig_col_name} to the result.")
match_list.append(col)
- return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"])
+ return_list.extend(
+ [
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ ]
+ )
elif ignore_matching_cols:
LOG.debug(
f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
@@ -613,7 +626,6 @@ def df_to_str(pdf: pd.DataFrame) -> str:
)
# Column Matching
- cnt_intersect = self.intersect_rows.shape[0]
report += render(
"column_comparison.txt",
len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
@@ -804,7 +816,7 @@ def columns_equal(
compare = pd.Series(
(col_1 == col_2) | (col_1.isnull() & col_2.isnull())
)
- except:
+ except Exception:
# Blanket exception should just return all False
compare = pd.Series(False, index=col_1.index)
compare.index = col_1.index
@@ -842,13 +854,13 @@ def compare_string_and_date_columns(
(pd.to_datetime(obj_column) == date_column)
| (obj_column.isnull() & date_column.isnull())
)
- except:
+ except Exception:
try:
return pd.Series(
(pd.to_datetime(obj_column, format="mixed") == date_column)
| (obj_column.isnull() & date_column.isnull())
)
- except:
+ except Exception:
return pd.Series(False, index=col_1.index)
@@ -871,8 +883,8 @@ def get_merged_columns(
for col in original_df.columns:
if col in merged_df.columns:
columns.append(col)
- elif col + suffix in merged_df.columns:
- columns.append(col + suffix)
+ elif col + "_" + suffix in merged_df.columns:
+ columns.append(col + "_" + suffix)
else:
raise ValueError("Column not found: %s", col)
return columns
@@ -920,7 +932,7 @@ def calculate_max_diff(col_1: "pd.Series[Any]", col_2: "pd.Series[Any]") -> floa
"""
try:
return cast(float, (col_1.astype(float) - col_2.astype(float)).abs().max())
- except:
+ except Exception:
return 0.0
diff --git a/datacompy/fugue.py b/datacompy/fugue.py
index 2ac4889a..8bc01d33 100644
--- a/datacompy/fugue.py
+++ b/datacompy/fugue.py
@@ -291,6 +291,101 @@ def all_rows_overlap(
return all(overlap)
+def count_matching_rows(
+ df1: AnyDataFrame,
+ df2: AnyDataFrame,
+ join_columns: Union[str, List[str]],
+ abs_tol: float = 0,
+ rel_tol: float = 0,
+ df1_name: str = "df1",
+ df2_name: str = "df2",
+ ignore_spaces: bool = False,
+ ignore_case: bool = False,
+ cast_column_names_lower: bool = True,
+ parallelism: Optional[int] = None,
+ strict_schema: bool = False,
+) -> int:
+ """Count the number of rows match (on overlapping fields)
+
+ Parameters
+ ----------
+ df1 : ``AnyDataFrame``
+ First dataframe to check
+ df2 : ``AnyDataFrame``
+ Second dataframe to check
+ join_columns : list or str, optional
+ Column(s) to join dataframes on. If a string is passed in, that one
+ column will be used.
+ abs_tol : float, optional
+ Absolute tolerance between two values.
+ rel_tol : float, optional
+ Relative tolerance between two values.
+ df1_name : str, optional
+ A string name for the first dataframe. This allows the reporting to
+ print out an actual name instead of "df1", and allows human users to
+ more easily track the dataframes.
+ df2_name : str, optional
+ A string name for the second dataframe
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns (including any join
+ columns)
+ ignore_case : bool, optional
+ Flag to ignore the case of string columns
+ cast_column_names_lower: bool, optional
+        Boolean indicator that controls whether column names will be cast into lower case
+ parallelism: int, optional
+ An integer representing the amount of parallelism. Entering a value for this
+        will force the use of Fugue over just vanilla Pandas
+ strict_schema: bool, optional
+ The schema must match exactly if set to ``True``. This includes the names and types. Allows for a fast fail.
+
+ Returns
+ -------
+ int
+ Number of matching rows
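+
+    Examples
+    --------
+    Illustrative only (hypothetical frames, default tolerances):
+
+    >>> import pandas as pd
+    >>> import datacompy
+    >>> df1 = pd.DataFrame({"id": [1, 2, 3], "value": [1.0, 2.0, 3.0]})
+    >>> df2 = pd.DataFrame({"id": [1, 2, 3], "value": [1.0, 2.0, 4.0]})
+    >>> datacompy.count_matching_rows(df1, df2, join_columns="id")
+    2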
+ """
+ if (
+ isinstance(df1, pd.DataFrame)
+ and isinstance(df2, pd.DataFrame)
+ and parallelism is None # user did not specify parallelism
+ and fa.get_current_parallelism() == 1 # currently on a local execution engine
+ ):
+ comp = Compare(
+ df1=df1,
+ df2=df2,
+ join_columns=join_columns,
+ abs_tol=abs_tol,
+ rel_tol=rel_tol,
+ df1_name=df1_name,
+ df2_name=df2_name,
+ ignore_spaces=ignore_spaces,
+ ignore_case=ignore_case,
+ cast_column_names_lower=cast_column_names_lower,
+ )
+ return comp.count_matching_rows()
+
+ try:
+ count_matching_rows = _distributed_compare(
+ df1=df1,
+ df2=df2,
+ join_columns=join_columns,
+ return_obj_func=lambda comp: comp.count_matching_rows(),
+ abs_tol=abs_tol,
+ rel_tol=rel_tol,
+ df1_name=df1_name,
+ df2_name=df2_name,
+ ignore_spaces=ignore_spaces,
+ ignore_case=ignore_case,
+ cast_column_names_lower=cast_column_names_lower,
+ parallelism=parallelism,
+ strict_schema=strict_schema,
+ )
+ except _StrictSchemaError:
+        return 0
+
+ return sum(count_matching_rows)
+
+
def report(
df1: AnyDataFrame,
df2: AnyDataFrame,
@@ -460,7 +555,6 @@ def _any(col: str) -> int:
any_mismatch = len(match_sample) > 0
# Column Matching
- cnt_intersect = shape0("intersect_rows_shape")
rpt += render(
"column_comparison.txt",
len([col for col in column_stats if col["unequal_cnt"] > 0]),
diff --git a/datacompy/legacy.py b/datacompy/legacy.py
new file mode 100644
index 00000000..b23b9cb2
--- /dev/null
+++ b/datacompy/legacy.py
@@ -0,0 +1,928 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from enum import Enum
+from itertools import chain
+from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
+from warnings import warn
+
+try:
+ import pyspark
+ from pyspark.sql import functions as F
+except ImportError:
+ pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
+
+
+warn(
+ f"The module {__name__} is deprecated. In future versions LegacySparkCompare will be completely removed.",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
+
+class MatchType(Enum):
+ MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3)
+
+
+# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal".
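+# e.g. decimal_comparator() == "decimal(38, 2)" evaluates to True, while a plain "decimal" string comparison would not.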
+def decimal_comparator() -> str:
+ class DecimalComparator(str):
+ def __eq__(self, other: str) -> bool: # type: ignore[override]
+ return len(other) >= 7 and other[0:7] == "decimal"
+
+ return DecimalComparator("decimal")
+
+
+NUMERIC_SPARK_TYPES = [
+ "tinyint",
+ "smallint",
+ "int",
+ "bigint",
+ "float",
+ "double",
+ decimal_comparator(),
+]
+
+
+def _is_comparable(type1: str, type2: str) -> bool:
+ """Checks if two Spark data types can be safely compared.
+ Two data types are considered comparable if any of the following apply:
+ 1. Both data types are the same
+ 2. Both data types are numeric
+
+ Parameters
+ ----------
+ type1 : str
+ A string representation of a Spark data type
+ type2 : str
+ A string representation of a Spark data type
+
+ Returns
+ -------
+ bool
+ True if both data types are comparable
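+
+    Examples
+    --------
+    For example:
+
+    >>> _is_comparable("int", "double")
+    True
+    >>> _is_comparable("string", "int")
+    False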
+ """
+
+ return type1 == type2 or (
+ type1 in NUMERIC_SPARK_TYPES and type2 in NUMERIC_SPARK_TYPES
+ )
+
+
+class LegacySparkCompare:
+ """Comparison class used to compare two Spark Dataframes.
+
+ Extends the ``Compare`` functionality to the wide world of Spark and
+ out-of-memory data.
+
+ Parameters
+ ----------
+ spark_session : ``pyspark.sql.SparkSession``
+ A ``SparkSession`` to be used to execute Spark commands in the
+ comparison.
+ base_df : ``pyspark.sql.DataFrame``
+ The dataframe to serve as a basis for comparison. While you will
+ ultimately get the same results comparing A to B as you will comparing
+ B to A, by convention ``base_df`` should be the canonical, gold
+ standard reference dataframe in the comparison.
+ compare_df : ``pyspark.sql.DataFrame``
+ The dataframe to be compared against ``base_df``.
+ join_columns : list
+ A list of columns comprising the join key(s) of the two dataframes.
+ If the column names are the same in the two dataframes, the names of
+ the columns can be given as strings. If the names differ, the
+ ``join_columns`` list should include tuples of the form
+ (base_column_name, compare_column_name).
+ column_mapping : list[tuple], optional
+ If columns to be compared have different names in the base and compare
+        dataframes, a list should be provided in ``column_mapping`` consisting
+ of tuples of the form (base_column_name, compare_column_name) for each
+ set of differently-named columns to be compared against each other.
+ cache_intermediates : bool, optional
+        Whether or not ``LegacySparkCompare`` will cache intermediate dataframes
+ (such as the deduplicated version of dataframes, or the joined
+ comparison). This will take a large amount of cache, proportional to
+ the size of your dataframes, but will significantly speed up
+ performance, as multiple steps will not have to recompute
+ transformations. False by default.
+ known_differences : list[dict], optional
+ A list of dictionaries that define transformations to apply to the
+ compare dataframe to match values when there are known differences
+ between base and compare. The dictionaries should contain:
+
+ * name: A name that describes the transformation
+ * types: The types that the transformation should be applied to.
+ This prevents certain transformations from being applied to
+ types that don't make sense and would cause exceptions.
+ * transformation: A Spark SQL statement to apply to the column
+ in the compare dataset. The string "{input}" will be replaced
+          by the variable in question. See the Examples section below.
+ abs_tol : float, optional
+ Absolute tolerance between two values.
+ rel_tol : float, optional
+ Relative tolerance between two values.
+ show_all_columns : bool, optional
+ If true, all columns will be shown in the report including columns
+ with a 100% match rate.
+ match_rates : bool, optional
+ If true, match rates by column will be shown in the column summary.
+
+ Returns
+ -------
+    LegacySparkCompare
+        Instance of a ``LegacySparkCompare`` object, ready to do some comparin'.
+ Note that if ``cache_intermediates=True``, this instance will already
+ have done some work deduping the input dataframes. If
+ ``cache_intermediates=False``, the instantiation of this object is lazy.
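+
+    Examples
+    --------
+    A hypothetical ``known_differences`` entry (illustrative only; the
+    transformation is applied to the compare column before re-checking
+    equality against the base column)::
+
+        known_differences = [
+            {
+                "name": "Compare values are stored in cents",
+                "types": ["int", "bigint", "double"],
+                "transformation": "{input} / 100",
+            }
+        ]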
+ """
+
+ def __init__(
+ self,
+ spark_session: "pyspark.sql.SparkSession",
+ base_df: "pyspark.sql.DataFrame",
+ compare_df: "pyspark.sql.DataFrame",
+ join_columns: List[Union[str, Tuple[str, str]]],
+ column_mapping: Optional[List[Tuple[str, str]]] = None,
+ cache_intermediates: bool = False,
+ known_differences: Optional[List[Dict[str, Any]]] = None,
+ rel_tol: float = 0,
+ abs_tol: float = 0,
+ show_all_columns: bool = False,
+ match_rates: bool = False,
+ ):
+ self.rel_tol = rel_tol
+ self.abs_tol = abs_tol
+ if self.rel_tol < 0 or self.abs_tol < 0:
+ raise ValueError("Please enter positive valued tolerances")
+ self.show_all_columns = show_all_columns
+ self.match_rates = match_rates
+
+ self._original_base_df = base_df
+ self._original_compare_df = compare_df
+ self.cache_intermediates = cache_intermediates
+
+ self.join_columns = self._tuplizer(input_list=join_columns)
+ self._join_column_names = [name[0] for name in self.join_columns]
+
+ self._known_differences = known_differences
+
+ if column_mapping:
+ for mapping in column_mapping:
+ compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0])
+ self.column_mapping = dict(column_mapping)
+ else:
+ self.column_mapping = {}
+
+ for mapping in self.join_columns:
+ if mapping[1] != mapping[0]:
+ compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0])
+
+ self.spark = spark_session
+ self.base_unq_rows = self.compare_unq_rows = None
+ self._base_row_count: Optional[int] = None
+ self._compare_row_count: Optional[int] = None
+ self._common_row_count: Optional[int] = None
+ self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None
+ self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None
+ self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None
+ self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None
+ self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None
+ self.columns_match_dict: Dict[str, Any] = {}
+
+ # drop the duplicates before actual comparison made.
+ self.base_df = base_df.dropDuplicates(self._join_column_names)
+ self.compare_df = compare_df.dropDuplicates(self._join_column_names)
+
+ if cache_intermediates:
+ self.base_df.cache()
+ self._base_row_count = self.base_df.count()
+ self.compare_df.cache()
+ self._compare_row_count = self.compare_df.count()
+
+ def _tuplizer(
+ self, input_list: List[Union[str, Tuple[str, str]]]
+ ) -> List[Tuple[str, str]]:
+ join_columns: List[Tuple[str, str]] = []
+ for val in input_list:
+ if isinstance(val, str):
+ join_columns.append((val, val))
+ else:
+ join_columns.append(val)
+
+ return join_columns
+
+ @property
+ def columns_in_both(self) -> Set[str]:
+ """set[str]: Get columns in both dataframes"""
+ return set(self.base_df.columns) & set(self.compare_df.columns)
+
+ @property
+ def columns_compared(self) -> List[str]:
+ """list[str]: Get columns to be compared in both dataframes (all
+        columns in both, excluding the join key(s))"""
+ return [
+ column
+ for column in list(self.columns_in_both)
+ if column not in self._join_column_names
+ ]
+
+ @property
+ def columns_only_base(self) -> Set[str]:
+ """set[str]: Get columns that are unique to the base dataframe"""
+ return set(self.base_df.columns) - set(self.compare_df.columns)
+
+ @property
+ def columns_only_compare(self) -> Set[str]:
+ """set[str]: Get columns that are unique to the compare dataframe"""
+ return set(self.compare_df.columns) - set(self.base_df.columns)
+
+ @property
+ def base_row_count(self) -> int:
+ """int: Get the count of rows in the de-duped base dataframe"""
+ if self._base_row_count is None:
+ self._base_row_count = self.base_df.count()
+
+ return self._base_row_count
+
+ @property
+ def compare_row_count(self) -> int:
+ """int: Get the count of rows in the de-duped compare dataframe"""
+ if self._compare_row_count is None:
+ self._compare_row_count = self.compare_df.count()
+
+ return self._compare_row_count
+
+ @property
+ def common_row_count(self) -> int:
+ """int: Get the count of rows in common between base and compare dataframes"""
+ if self._common_row_count is None:
+ common_rows = self._get_or_create_joined_dataframe()
+ self._common_row_count = common_rows.count()
+
+ return self._common_row_count
+
+ def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
+ """Get the rows only from base data frame"""
+ return self.base_df.select(self._join_column_names).subtract(
+ self.compare_df.select(self._join_column_names)
+ )
+
+ def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
+ """Get the rows only from compare data frame"""
+ return self.compare_df.select(self._join_column_names).subtract(
+ self.base_df.select(self._join_column_names)
+ )
+
+ def _print_columns_summary(self, myfile: TextIO) -> None:
+ """Prints the column summary details"""
+ print("\n****** Column Summary ******", file=myfile)
+ print(
+ f"Number of columns in common with matching schemas: {len(self._columns_with_matching_schema())}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns in common with schema differences: {len(self._columns_with_schemadiff())}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns in base but not compare: {len(self.columns_only_base)}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns in compare but not base: {len(self.columns_only_compare)}",
+ file=myfile,
+ )
+
+ def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
+ """Prints the columns and data types only in either the base or compare datasets"""
+
+ if base_or_compare.upper() == "BASE":
+ columns = self.columns_only_base
+ df = self.base_df
+ elif base_or_compare.upper() == "COMPARE":
+ columns = self.columns_only_compare
+ df = self.compare_df
+ else:
+ raise ValueError(
+ f'base_or_compare must be BASE or COMPARE, but was "{base_or_compare}"'
+ )
+
+ # If there are no columns only in this dataframe, don't display this section
+ if not columns:
+ return
+
+ max_length = max([len(col) for col in columns] + [11])
+ format_pattern = f"{{:{max_length}s}}"
+
+ print(f"\n****** Columns In {base_or_compare.title()} Only ******", file=myfile)
+ print((format_pattern + " Dtype").format("Column Name"), file=myfile)
+ print("-" * max_length + " -------------", file=myfile)
+
+ for column in columns:
+ col_type = df.select(column).dtypes[0][1]
+ print((format_pattern + " {:13s}").format(column, col_type), file=myfile)
+
+ def _columns_with_matching_schema(self) -> Dict[str, str]:
+ """This function will identify the columns which has matching schema"""
+ col_schema_match = {}
+ base_columns_dict = dict(self.base_df.dtypes)
+ compare_columns_dict = dict(self.compare_df.dtypes)
+
+ for base_row, base_type in base_columns_dict.items():
+ if base_row in compare_columns_dict:
+ compare_column_type = compare_columns_dict.get(base_row)
+ if compare_column_type is not None and base_type in compare_column_type:
+ col_schema_match[base_row] = compare_column_type
+
+ return col_schema_match
+
+ def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
+ """This function will identify the columns which has different schema"""
+ col_schema_diff = {}
+ base_columns_dict = dict(self.base_df.dtypes)
+ compare_columns_dict = dict(self.compare_df.dtypes)
+
+ for base_row, base_type in base_columns_dict.items():
+ if base_row in compare_columns_dict:
+ compare_column_type = compare_columns_dict.get(base_row)
+ if (
+ compare_column_type is not None
+ and base_type not in compare_column_type
+ ):
+ col_schema_diff[base_row] = dict(
+ base_type=base_type,
+ compare_type=compare_column_type,
+ )
+ return col_schema_diff
+
+ @property
+ def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
+ """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
+ if self._all_rows_mismatched is None:
+ self._merge_dataframes()
+
+ return self._all_rows_mismatched
+
+ @property
+ def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
+ """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
+ if self._all_matched_rows is None:
+ self._merge_dataframes()
+
+ return self._all_matched_rows
+
+ @property
+ def rows_only_base(self) -> "pyspark.sql.DataFrame":
+ """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
+ if not self._rows_only_base:
+ base_rows = self._get_unq_base_rows()
+ base_rows.createOrReplaceTempView("baseRows")
+ self.base_df.createOrReplaceTempView("baseTable")
+ join_condition = " AND ".join(
+ [
+ "A.`" + name + "`<=>B.`" + name + "`"
+ for name in self._join_column_names
+ ]
+ )
+ sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
+ join_condition
+ )
+ self._rows_only_base = self.spark.sql(sql_query)
+
+ if self.cache_intermediates:
+ self._rows_only_base.cache().count()
+
+ return self._rows_only_base
+
+ @property
+ def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
+ """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
+ if not self._rows_only_compare:
+ compare_rows = self._get_compare_rows()
+ compare_rows.createOrReplaceTempView("compareRows")
+ self.compare_df.createOrReplaceTempView("compareTable")
+ where_condition = " AND ".join(
+ [
+ "A.`" + name + "`<=>B.`" + name + "`"
+ for name in self._join_column_names
+ ]
+ )
+ sql_query = (
+ "select A.* from compareTable as A, compareRows as B where {}".format(
+ where_condition
+ )
+ )
+ self._rows_only_compare = self.spark.sql(sql_query)
+
+ if self.cache_intermediates:
+ self._rows_only_compare.cache().count()
+
+ return self._rows_only_compare
+
+ def _generate_select_statement(self, match_data: bool = True) -> str:
+ """This function is to generate the select statement to be used later in the query."""
+ base_only = list(set(self.base_df.columns) - set(self.compare_df.columns))
+ compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns))
+ sorted_list = sorted(list(chain(base_only, compare_only, self.columns_in_both)))
+ select_statement = ""
+
+ for column_name in sorted_list:
+ if column_name in self.columns_compared:
+ if match_data:
+ select_statement = select_statement + ",".join(
+ [self._create_case_statement(name=column_name)]
+ )
+ else:
+ select_statement = select_statement + ",".join(
+ [self._create_select_statement(name=column_name)]
+ )
+ elif column_name in base_only:
+ select_statement = select_statement + ",".join(
+ ["A.`" + column_name + "`"]
+ )
+
+ elif column_name in compare_only:
+ if match_data:
+ select_statement = select_statement + ",".join(
+ ["B.`" + column_name + "`"]
+ )
+ else:
+ select_statement = select_statement + ",".join(
+ ["A.`" + column_name + "`"]
+ )
+ elif column_name in self._join_column_names:
+ select_statement = select_statement + ",".join(
+ ["A.`" + column_name + "`"]
+ )
+
+ if column_name != sorted_list[-1]:
+ select_statement = select_statement + " , "
+
+ return select_statement
+
+ def _merge_dataframes(self) -> None:
+ """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched."""
+ full_joined_dataframe = self._get_or_create_joined_dataframe()
+ full_joined_dataframe.createOrReplaceTempView("full_matched_table")
+
+ select_statement = self._generate_select_statement(False)
+ select_query = """SELECT {} FROM full_matched_table A""".format(
+ select_statement
+ )
+ self._all_matched_rows = self.spark.sql(select_query).orderBy(
+ self._join_column_names # type: ignore[arg-type]
+ )
+ self._all_matched_rows.createOrReplaceTempView("matched_table")
+
+ where_cond = " OR ".join(
+ ["A.`" + name + "_match`= False" for name in self.columns_compared]
+ )
+ mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
+ self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
+ self._join_column_names # type: ignore[arg-type]
+ )
+
+ def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
+ if self._joined_dataframe is None:
+ join_condition = " AND ".join(
+ [
+ "A.`" + name + "`<=>B.`" + name + "`"
+ for name in self._join_column_names
+ ]
+ )
+ select_statement = self._generate_select_statement(match_data=True)
+
+ self.base_df.createOrReplaceTempView("base_table")
+ self.compare_df.createOrReplaceTempView("compare_table")
+
+ join_query = r"""
+ SELECT {}
+ FROM base_table A
+ JOIN compare_table B
+ ON {}""".format(
+ select_statement, join_condition
+ )
+
+ self._joined_dataframe = self.spark.sql(join_query)
+ if self.cache_intermediates:
+ self._joined_dataframe.cache()
+ self._common_row_count = self._joined_dataframe.count()
+
+ return self._joined_dataframe
+
+ def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None:
+ # match_dataframe contains columns from both dataframes with flag to indicate if columns matched
+ match_dataframe = self._get_or_create_joined_dataframe().select(
+ *self.columns_compared
+ )
+ match_dataframe.createOrReplaceTempView("matched_df")
+
+ where_cond = " AND ".join(
+ [
+ "A.`" + name + "`=" + str(MatchType.MATCH.value)
+ for name in self.columns_compared
+ ]
+ )
+ match_query = (
+ r"""SELECT count(*) AS row_count FROM matched_df A WHERE {}""".format(
+ where_cond
+ )
+ )
+ all_rows_matched = self.spark.sql(match_query)
+ all_rows_matched_head = all_rows_matched.head()
+ matched_rows = (
+ all_rows_matched_head[0] if all_rows_matched_head is not None else 0
+ )
+
+ print("\n****** Row Comparison ******", file=myfile)
+ print(
+ f"Number of rows with some columns unequal: {self.common_row_count - matched_rows}",
+ file=myfile,
+ )
+ print(f"Number of rows with all columns equal: {matched_rows}", file=myfile)
+
+ def _populate_columns_match_dict(self) -> None:
+ """
+ side effects:
+ columns_match_dict assigned to { column -> match_type_counts }
+ where:
+ column (string): Name of a column that exists in both the base and comparison columns
+ match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values)
+
+ returns: None
+ """
+
+ match_dataframe = self._get_or_create_joined_dataframe().select(
+ *self.columns_compared
+ )
+
+ def helper(c: str) -> "pyspark.sql.Column":
+ # Create a predicate for each match type, comparing column values to the match type value
+ predicates = [F.col(c) == k.value for k in MatchType]
+ # Create a tuple(number of match types found for each match type in this column)
+ return F.struct(
+ [F.lit(F.sum(pred.cast("integer"))) for pred in predicates]
+ ).alias(c)
+
+ # For each column, create a single tuple. This tuple's values correspond to the number of times
+ # each match type appears in that column
+ match_data_agg = match_dataframe.agg(
+ *[helper(col) for col in self.columns_compared]
+ ).collect()
+ match_data = match_data_agg[0]
+
+ for c in self.columns_compared:
+ self.columns_match_dict[c] = match_data[c]
+
+ def _create_select_statement(self, name: str) -> str:
+ if self._known_differences:
+ match_type_comparison = ""
+ for k in MatchType:
+ match_type_comparison += (
+ " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
+ name=name, match_value=str(k.value), match_name=k.name
+ )
+ )
+ return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
+ name=name,
+ match_failure=MatchType.MISMATCH.value,
+ match_type_comparison=match_type_comparison,
+ )
+ else:
+ return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
+ name=name, match_failure=MatchType.MISMATCH.value
+ )
+
+ def _create_case_statement(self, name: str) -> str:
+ equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
+ known_diff_comparisons = ["(FALSE)"]
+
+ base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
+ compare_dtype = [d[1] for d in self.compare_df.dtypes if d[0] == name][0]
+
+ if _is_comparable(base_dtype, compare_dtype):
+ if (base_dtype in NUMERIC_SPARK_TYPES) and (
+ compare_dtype in NUMERIC_SPARK_TYPES
+ ): # numeric tolerance comparison
+ equal_comparisons.append(
+ "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
+ + str(self.abs_tol)
+ + "+("
+ + str(self.rel_tol)
+ + "*abs(A.`{name}`)))))"
+ )
+ else: # non-numeric comparison
+ equal_comparisons.append("((A.`{name}`=B.`{name}`))")
+
+ if self._known_differences:
+ new_input = "B.`{name}`"
+ for kd in self._known_differences:
+ if compare_dtype in kd["types"]:
+ if "flags" in kd and "nullcheck" in kd["flags"]:
+ known_diff_comparisons.append(
+ "(("
+ + kd["transformation"].format(new_input, input=new_input)
+ + ") is null AND A.`{name}` is null)"
+ )
+ else:
+ known_diff_comparisons.append(
+ "(("
+ + kd["transformation"].format(new_input, input=new_input)
+ + ") = A.`{name}`)"
+ )
+
+ case_string = (
+ "( CASE WHEN ("
+ + " OR ".join(equal_comparisons)
+ + ") THEN {match_success} WHEN ("
+ + " OR ".join(known_diff_comparisons)
+ + ") THEN {match_known_difference} ELSE {match_failure} END) "
+ + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
+ )
+
+ return case_string.format(
+ name=name,
+ match_success=MatchType.MATCH.value,
+ match_known_difference=MatchType.KNOWN_DIFFERENCE.value,
+ match_failure=MatchType.MISMATCH.value,
+ )
+
+ def _print_row_summary(self, myfile: TextIO) -> None:
+ base_df_cnt = self.base_df.count()
+ compare_df_cnt = self.compare_df.count()
+ base_df_with_dup_cnt = self._original_base_df.count()
+ compare_df_with_dup_cnt = self._original_compare_df.count()
+
+ print("\n****** Row Summary ******", file=myfile)
+ print(f"Number of rows in common: {self.common_row_count}", file=myfile)
+ print(
+ f"Number of rows in base but not compare: {base_df_cnt - self.common_row_count}",
+ file=myfile,
+ )
+ print(
+ f"Number of rows in compare but not base: {compare_df_cnt - self.common_row_count}",
+ file=myfile,
+ )
+ print(
+ f"Number of duplicate rows found in base: {base_df_with_dup_cnt - base_df_cnt}",
+ file=myfile,
+ )
+ print(
+ f"Number of duplicate rows found in compare: {compare_df_with_dup_cnt - compare_df_cnt}",
+ file=myfile,
+ )
+
+ def _print_schema_diff_details(self, myfile: TextIO) -> None:
+ schema_diff_dict = self._columns_with_schemadiff()
+
+ if not schema_diff_dict: # If there are no differences, don't print the section
+ return
+
+ # For columns with mismatches, what are the longest base and compare column name lengths (with minimums)?
+ base_name_max = max([len(key) for key in schema_diff_dict] + [16])
+ compare_name_max = max(
+ [len(self._base_to_compare_name(key)) for key in schema_diff_dict] + [19]
+ )
+
+ format_pattern = "{{:{base}s}} {{:{compare}s}}".format(
+ base=base_name_max, compare=compare_name_max
+ )
+
+ print("\n****** Schema Differences ******", file=myfile)
+ print(
+ (format_pattern + " Base Dtype Compare Dtype").format(
+ "Base Column Name", "Compare Column Name"
+ ),
+ file=myfile,
+ )
+ print(
+ "-" * base_name_max
+ + " "
+ + "-" * compare_name_max
+ + " ------------- -------------",
+ file=myfile,
+ )
+
+ for base_column, types in schema_diff_dict.items():
+ compare_column = self._base_to_compare_name(base_column)
+
+ print(
+ (format_pattern + " {:13s} {:13s}").format(
+ base_column,
+ compare_column,
+ types["base_type"],
+ types["compare_type"],
+ ),
+ file=myfile,
+ )
+
+ def _base_to_compare_name(self, base_name: str) -> str:
+ """Translates a column name in the base dataframe to its counterpart in the
+ compare dataframe, if they are different."""
+
+ if base_name in self.column_mapping:
+ return self.column_mapping[base_name]
+ else:
+ for name in self.join_columns:
+ if base_name == name[0]:
+ return name[1]
+ return base_name
+
+ def _print_row_matches_by_column(self, myfile: TextIO) -> None:
+ self._populate_columns_match_dict()
+ columns_with_mismatches = {
+ key: self.columns_match_dict[key]
+ for key in self.columns_match_dict
+ if self.columns_match_dict[key][MatchType.MISMATCH.value]
+ }
+
+ # corner case: when all columns match but no rows match
+ # issue: #276
+ try:
+ columns_fully_matching = {
+ key: self.columns_match_dict[key]
+ for key in self.columns_match_dict
+ if sum(self.columns_match_dict[key])
+ == self.columns_match_dict[key][MatchType.MATCH.value]
+ }
+ except TypeError:
+ columns_fully_matching = {}
+
+ try:
+ columns_with_any_diffs = {
+ key: self.columns_match_dict[key]
+ for key in self.columns_match_dict
+ if sum(self.columns_match_dict[key])
+ != self.columns_match_dict[key][MatchType.MATCH.value]
+ }
+ except TypeError:
+ columns_with_any_diffs = {}
+ #
+
+ base_types = {x[0]: x[1] for x in self.base_df.dtypes}
+ compare_types = {x[0]: x[1] for x in self.compare_df.dtypes}
+
+ print("\n****** Column Comparison ******", file=myfile)
+
+ if self._known_differences:
+ print(
+ f"Number of columns compared with unexpected differences in some values: {len(columns_with_mismatches)}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns compared with all values equal but known differences found: {len(self.columns_compared) - len(columns_with_mismatches) - len(columns_fully_matching)}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns compared with all values completely equal: {len(columns_fully_matching)}",
+ file=myfile,
+ )
+ else:
+ print(
+ f"Number of columns compared with some values unequal: {len(columns_with_mismatches)}",
+ file=myfile,
+ )
+ print(
+ f"Number of columns compared with all values equal: {len(columns_fully_matching)}",
+ file=myfile,
+ )
+
+ # If all columns matched, don't print columns with unequal values
+ if (not self.show_all_columns) and (
+ len(columns_fully_matching) == len(self.columns_compared)
+ ):
+ return
+
+ # if show_all_columns is set, set column name length maximum to max of ALL columns(with minimum)
+ if self.show_all_columns:
+ base_name_max = max([len(key) for key in self.columns_match_dict] + [16])
+ compare_name_max = max(
+ [
+ len(self._base_to_compare_name(key))
+ for key in self.columns_match_dict
+ ]
+ + [19]
+ )
+
+ # For columns with any differences, what are the longest base and compare column name lengths (with minimums)?
+ else:
+ base_name_max = max([len(key) for key in columns_with_any_diffs] + [16])
+ compare_name_max = max(
+ [len(self._base_to_compare_name(key)) for key in columns_with_any_diffs]
+ + [19]
+ )
+
+ """ list of (header, condition, width, align)
+ where
+ header (String) : output header for a column
+ condition (Bool): true if this header should be displayed
+ width (Int) : width of the column
+ align (Bool) : true if right-aligned
+ """
+ headers_columns_unequal = [
+ ("Base Column Name", True, base_name_max, False),
+ ("Compare Column Name", True, compare_name_max, False),
+ ("Base Dtype ", True, 13, False),
+ ("Compare Dtype", True, 13, False),
+ ("# Matches", True, 9, True),
+ ("# Known Diffs", self._known_differences is not None, 13, True),
+ ("# Mismatches", True, 12, True),
+ ]
+ if self.match_rates:
+ headers_columns_unequal.append(("Match Rate %", True, 12, True))
+ headers_columns_unequal_valid = [h for h in headers_columns_unequal if h[1]]
+ padding = 2 # spaces to add to left and right of each column
+
+ if self.show_all_columns:
+ print("\n****** Columns with Equal/Unequal Values ******", file=myfile)
+ else:
+ print("\n****** Columns with Unequal Values ******", file=myfile)
+
+ format_pattern = (" " * padding).join(
+ [
+ ("{:" + (">" if h[3] else "") + str(h[2]) + "}")
+ for h in headers_columns_unequal_valid
+ ]
+ )
+ print(
+ format_pattern.format(*[h[0] for h in headers_columns_unequal_valid]),
+ file=myfile,
+ )
+ print(
+ format_pattern.format(
+ *["-" * len(h[0]) for h in headers_columns_unequal_valid]
+ ),
+ file=myfile,
+ )
+
+ for column_name, column_values in sorted(
+ self.columns_match_dict.items(), key=lambda i: i[0]
+ ):
+ num_matches = column_values[MatchType.MATCH.value]
+ num_known_diffs = (
+ None
+ if self._known_differences is None
+ else column_values[MatchType.KNOWN_DIFFERENCE.value]
+ )
+ num_mismatches = column_values[MatchType.MISMATCH.value]
+ compare_column = self._base_to_compare_name(column_name)
+
+ if num_mismatches or num_known_diffs or self.show_all_columns:
+ output_row = [
+ column_name,
+ compare_column,
+ base_types.get(column_name),
+ compare_types.get(column_name),
+ str(num_matches),
+ str(num_mismatches),
+ ]
+ if self.match_rates:
+ match_rate = 100 * (
+ 1
+ - (column_values[MatchType.MISMATCH.value] + 0.0)
+ / self.common_row_count
+ + 0.0
+ )
+ output_row.append("{:02.5f}".format(match_rate))
+ if num_known_diffs is not None:
+ output_row.insert(len(output_row) - 1, str(num_known_diffs))
+ print(format_pattern.format(*output_row), file=myfile)
+
+ # noinspection PyUnresolvedReferences
+ def report(self, file: TextIO = sys.stdout) -> None:
+ """Creates a comparison report and prints it to the file specified
+ (stdout by default).
+
+ Parameters
+ ----------
+ file : ``file``, optional
+ A filehandle to write the report to. By default, this is
+ sys.stdout, printing the report to stdout. You can also redirect
+ this to an output file, as in the example.
+
+ Examples
+ --------
+ >>> with open('my_report.txt', 'w') as report_file:
+ ... comparison.report(file=report_file)
+ """
+
+ self._print_columns_summary(file)
+ self._print_schema_diff_details(file)
+ self._print_only_columns("BASE", file)
+ self._print_only_columns("COMPARE", file)
+ self._print_row_summary(file)
+ self._merge_dataframes()
+ self._print_num_of_rows_with_column_equality(file)
+ self._print_row_matches_by_column(file)
diff --git a/datacompy/polars.py b/datacompy/polars.py
index 814a7cd6..aca96296 100644
--- a/datacompy/polars.py
+++ b/datacompy/polars.py
@@ -20,6 +20,7 @@
PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
two dataframes.
"""
+
import logging
import os
from copy import deepcopy
@@ -265,9 +266,9 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
df2_non_join_columns = OrderedSet(df2.columns) - OrderedSet(temp_join_columns)
for c in df1_non_join_columns:
- df1 = df1.rename({c: c + "_df1"})
+ df1 = df1.rename({c: c + "_" + self.df1_name})
for c in df2_non_join_columns:
- df2 = df2.rename({c: c + "_df2"})
+ df2 = df2.rename({c: c + "_" + self.df2_name})
# generate merge indicator
df1 = df1.with_columns(_merge_left=pl.lit(True))
@@ -290,8 +291,8 @@ def _dataframe_merge(self, ignore_spaces: bool) -> None:
if self._any_dupes:
outer_join = outer_join.drop(order_column)
- df1_cols = get_merged_columns(self.df1, outer_join, "_df1")
- df2_cols = get_merged_columns(self.df2, outer_join, "_df2")
+ df1_cols = get_merged_columns(self.df1, outer_join, self.df1_name)
+ df2_cols = get_merged_columns(self.df2, outer_join, self.df2_name)
LOG.debug("Selecting df1 unique rows")
self.df1_unq_rows = outer_join.filter(
@@ -333,8 +334,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
max_diff = 0.0
null_diff = 0
else:
- col_1 = column + "_df1"
- col_2 = column + "_df2"
+ col_1 = column + "_" + self.df1_name
+ col_2 = column + "_" + self.df2_name
col_match = column + "_match"
self.intersect_rows = self.intersect_rows.with_columns(
columns_equal(
@@ -499,7 +500,10 @@ def sample_mismatch(
sample = self.intersect_rows.filter(pl.col(column + "_match") != True).sample(
sample_count
)
- return_cols = self.join_columns + [column + "_df1", column + "_df2"]
+ return_cols = self.join_columns + [
+ column + "_" + self.df1_name,
+ column + "_" + self.df2_name,
+ ]
to_return = sample[return_cols]
if for_display:
to_return.columns = self.join_columns + [
@@ -529,8 +533,8 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
orig_col_name = col[:-6]
col_comparison = columns_equal(
- self.intersect_rows[orig_col_name + "_df1"],
- self.intersect_rows[orig_col_name + "_df2"],
+ self.intersect_rows[orig_col_name + "_" + self.df1_name],
+ self.intersect_rows[orig_col_name + "_" + self.df2_name],
self.rel_tol,
self.abs_tol,
self.ignore_spaces,
@@ -542,7 +546,12 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> "pl.DataFrame":
):
LOG.debug(f"Adding column {orig_col_name} to the result.")
match_list.append(col)
- return_list.extend([orig_col_name + "_df1", orig_col_name + "_df2"])
+ return_list.extend(
+ [
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ ]
+ )
elif ignore_matching_cols:
LOG.debug(
f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
@@ -622,7 +631,6 @@ def df_to_str(pdf: "pl.DataFrame") -> str:
)
# Column Matching
- cnt_intersect = self.intersect_rows.shape[0]
report += render(
"column_comparison.txt",
len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
@@ -824,7 +832,7 @@ def columns_equal(
compare = pl.Series(
(col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())
)
- except:
+ except Exception:
# Blanket exception should just return all False
compare = pl.Series(False * col_1.shape[0])
return compare
@@ -861,7 +869,7 @@ def compare_string_and_date_columns(
(str_column.str.to_datetime().eq_missing(date_column))
| (str_column.is_null() & date_column.is_null())
)
- except:
+ except Exception:
return pl.Series([False] * col_1.shape[0])
@@ -884,8 +892,8 @@ def get_merged_columns(
for col in original_df.columns:
if col in merged_df.columns:
columns.append(col)
- elif col + suffix in merged_df.columns:
- columns.append(col + suffix)
+ elif col + "_" + suffix in merged_df.columns:
+ columns.append(col + "_" + suffix)
else:
raise ValueError("Column not found: %s", col)
return columns
@@ -935,7 +943,7 @@ def calculate_max_diff(col_1: "pl.Series", col_2: "pl.Series") -> float:
return cast(
float, (col_1.cast(pl.Float64) - col_2.cast(pl.Float64)).abs().max()
)
- except:
+ except Exception:
return 0.0
diff --git a/datacompy/spark.py b/datacompy/spark.py
index 9fdc2093..070a58e5 100644
--- a/datacompy/spark.py
+++ b/datacompy/spark.py
@@ -13,916 +13,978 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
-from enum import Enum
-from itertools import chain
-from typing import Any, Dict, List, Optional, Set, TextIO, Tuple, Union
-from warnings import warn
-
-try:
- import pyspark
- from pyspark.sql import functions as F
-except ImportError:
- pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
-
-
-warn(
- f"The module {__name__} is deprecated. In future versions (0.12.0 and above) SparkCompare will be refactored and the legacy logic will move to LegacySparkCompare ",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MatchType(Enum):
- MISMATCH, MATCH, KNOWN_DIFFERENCE = range(3)
+"""
+Compare two Pandas on Spark DataFrames
+
+Originally this package was meant to provide similar functionality to
+PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
+two dataframes.
+"""
-# Used for checking equality with decimal(X, Y) types. Otherwise treated as the string "decimal".
-def decimal_comparator() -> str:
- class DecimalComparator(str):
- def __eq__(self, other: str) -> bool: # type: ignore[override]
- return len(other) >= 7 and other[0:7] == "decimal"
+import logging
+import os
- return DecimalComparator("decimal")
+import pandas as pd
+from ordered_set import OrderedSet
+from datacompy.base import BaseCompare
-NUMERIC_SPARK_TYPES = [
- "tinyint",
- "smallint",
- "int",
- "bigint",
- "float",
- "double",
- decimal_comparator(),
-]
-
-
-def _is_comparable(type1: str, type2: str) -> bool:
- """Checks if two Spark data types can be safely compared.
- Two data types are considered comparable if any of the following apply:
- 1. Both data types are the same
- 2. Both data types are numeric
-
- Parameters
- ----------
- type1 : str
- A string representation of a Spark data type
- type2 : str
- A string representation of a Spark data type
+try:
+ import pyspark.pandas as ps
+ from pandas.api.types import is_numeric_dtype
+except ImportError:
+ pass # Let non-Spark people at least enjoy the loveliness of the pandas datacompy functionality
- Returns
- -------
- bool
- True if both data types are comparable
- """
- return type1 == type2 or (
- type1 in NUMERIC_SPARK_TYPES and type2 in NUMERIC_SPARK_TYPES
- )
+LOG = logging.getLogger(__name__)
-class SparkCompare:
- """Comparison class used to compare two Spark Dataframes.
+class SparkCompare(BaseCompare):
+ """Comparison class to be used to compare whether two Pandas on Spark dataframes are equal.
- Extends the ``Compare`` functionality to the wide world of Spark and
- out-of-memory data.
+ Both df1 and df2 should be dataframes containing all of the join_columns,
+ with unique column names. Differences between values are compared to
+ abs_tol + rel_tol * abs(df2['value']).
Parameters
----------
- spark_session : ``pyspark.sql.SparkSession``
- A ``SparkSession`` to be used to execute Spark commands in the
- comparison.
- base_df : ``pyspark.sql.DataFrame``
- The dataframe to serve as a basis for comparison. While you will
- ultimately get the same results comparing A to B as you will comparing
- B to A, by convention ``base_df`` should be the canonical, gold
- standard reference dataframe in the comparison.
- compare_df : ``pyspark.sql.DataFrame``
- The dataframe to be compared against ``base_df``.
- join_columns : list
- A list of columns comprising the join key(s) of the two dataframes.
- If the column names are the same in the two dataframes, the names of
- the columns can be given as strings. If the names differ, the
- ``join_columns`` list should include tuples of the form
- (base_column_name, compare_column_name).
- column_mapping : list[tuple], optional
- If columns to be compared have different names in the base and compare
- dataframes, a list should be provided in ``columns_mapping`` consisting
- of tuples of the form (base_column_name, compare_column_name) for each
- set of differently-named columns to be compared against each other.
- cache_intermediates : bool, optional
- Whether or not ``SparkCompare`` will cache intermediate dataframes
- (such as the deduplicated version of dataframes, or the joined
- comparison). This will take a large amount of cache, proportional to
- the size of your dataframes, but will significantly speed up
- performance, as multiple steps will not have to recompute
- transformations. False by default.
- known_differences : list[dict], optional
- A list of dictionaries that define transformations to apply to the
- compare dataframe to match values when there are known differences
- between base and compare. The dictionaries should contain:
-
- * name: A name that describes the transformation
- * types: The types that the transformation should be applied to.
- This prevents certain transformations from being applied to
- types that don't make sense and would cause exceptions.
- * transformation: A Spark SQL statement to apply to the column
- in the compare dataset. The string "{input}" will be replaced
- by the variable in question.
+ df1 : pyspark.pandas.frame.DataFrame
+ First dataframe to check
+ df2 : pyspark.pandas.frame.DataFrame
+ Second dataframe to check
+ join_columns : list or str, optional
+ Column(s) to join dataframes on. If a string is passed in, that one
+ column will be used.
abs_tol : float, optional
Absolute tolerance between two values.
rel_tol : float, optional
Relative tolerance between two values.
- show_all_columns : bool, optional
- If true, all columns will be shown in the report including columns
- with a 100% match rate.
- match_rates : bool, optional
- If true, match rates by column will be shown in the column summary.
-
- Returns
- -------
- SparkCompare
- Instance of a ``SparkCompare`` object, ready to do some comparin'.
- Note that if ``cache_intermediates=True``, this instance will already
- have done some work deduping the input dataframes. If
- ``cache_intermediates=False``, the instantiation of this object is lazy.
+ df1_name : str, optional
+ A string name for the first dataframe. This allows the reporting to
+ print out an actual name instead of "df1", and allows human users to
+ more easily track the dataframes.
+ df2_name : str, optional
+ A string name for the second dataframe
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns (including any join
+ columns)
+ ignore_case : bool, optional
+ Flag to ignore the case of string columns
+ cast_column_names_lower: bool, optional
+        Boolean indicator that controls whether column names will be cast into lower case
+
+ Attributes
+ ----------
+ df1_unq_rows : pyspark.pandas.frame.DataFrame
+ All records that are only in df1 (based on a join on join_columns)
+ df2_unq_rows : pyspark.pandas.frame.DataFrame
+ All records that are only in df2 (based on a join on join_columns)
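+
+    Examples
+    --------
+    A minimal illustrative sketch (assumes an active ``SparkSession`` and
+    Pandas < 2.0.0)::
+
+        import pyspark.pandas as ps
+        from datacompy.spark import SparkCompare
+
+        df1 = ps.DataFrame({"id": [1, 2, 3], "value": [1.0, 2.0, 3.0]})
+        df2 = ps.DataFrame({"id": [1, 2, 3], "value": [1.0, 2.0, 3.1]})
+        compare = SparkCompare(df1, df2, join_columns="id")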
"""
def __init__(
self,
- spark_session: "pyspark.sql.SparkSession",
- base_df: "pyspark.sql.DataFrame",
- compare_df: "pyspark.sql.DataFrame",
- join_columns: List[Union[str, Tuple[str, str]]],
- column_mapping: Optional[List[Tuple[str, str]]] = None,
- cache_intermediates: bool = False,
- known_differences: Optional[List[Dict[str, Any]]] = None,
- rel_tol: float = 0,
- abs_tol: float = 0,
- show_all_columns: bool = False,
- match_rates: bool = False,
+ df1,
+ df2,
+ join_columns,
+ abs_tol=0,
+ rel_tol=0,
+ df1_name="df1",
+ df2_name="df2",
+ ignore_spaces=False,
+ ignore_case=False,
+ cast_column_names_lower=True,
):
- self.rel_tol = rel_tol
- self.abs_tol = abs_tol
- if self.rel_tol < 0 or self.abs_tol < 0:
- raise ValueError("Please enter positive valued tolerances")
- self.show_all_columns = show_all_columns
- self.match_rates = match_rates
-
- self._original_base_df = base_df
- self._original_compare_df = compare_df
- self.cache_intermediates = cache_intermediates
-
- self.join_columns = self._tuplizer(input_list=join_columns)
- self._join_column_names = [name[0] for name in self.join_columns]
-
- self._known_differences = known_differences
+ if pd.__version__ >= "2.0.0":
+ raise Exception(
+ "It seems like you are running Pandas 2+. Please note that Pandas 2+ will only be supported in Spark 4+. See: https://issues.apache.org/jira/browse/SPARK-44101. If you need to use Spark DataFrame with Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3"
+ )
- if column_mapping:
- for mapping in column_mapping:
- compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0])
- self.column_mapping = dict(column_mapping)
+ ps.set_option("compute.ops_on_diff_frames", True)
+ self.cast_column_names_lower = cast_column_names_lower
+ if isinstance(join_columns, (str, int, float)):
+ self.join_columns = [
+ (
+ str(join_columns).lower()
+ if self.cast_column_names_lower
+ else str(join_columns)
+ )
+ ]
else:
- self.column_mapping = {}
-
- for mapping in self.join_columns:
- if mapping[1] != mapping[0]:
- compare_df = compare_df.withColumnRenamed(mapping[1], mapping[0])
-
- self.spark = spark_session
- self.base_unq_rows = self.compare_unq_rows = None
- self._base_row_count: Optional[int] = None
- self._compare_row_count: Optional[int] = None
- self._common_row_count: Optional[int] = None
- self._joined_dataframe: Optional["pyspark.sql.DataFrame"] = None
- self._rows_only_base: Optional["pyspark.sql.DataFrame"] = None
- self._rows_only_compare: Optional["pyspark.sql.DataFrame"] = None
- self._all_matched_rows: Optional["pyspark.sql.DataFrame"] = None
- self._all_rows_mismatched: Optional["pyspark.sql.DataFrame"] = None
- self.columns_match_dict: Dict[str, Any] = {}
-
- # drop the duplicates before actual comparison made.
- self.base_df = base_df.dropDuplicates(self._join_column_names)
- self.compare_df = compare_df.dropDuplicates(self._join_column_names)
-
- if cache_intermediates:
- self.base_df.cache()
- self._base_row_count = self.base_df.count()
- self.compare_df.cache()
- self._compare_row_count = self.compare_df.count()
-
- def _tuplizer(
- self, input_list: List[Union[str, Tuple[str, str]]]
- ) -> List[Tuple[str, str]]:
- join_columns: List[Tuple[str, str]] = []
- for val in input_list:
- if isinstance(val, str):
- join_columns.append((val, val))
- else:
- join_columns.append(val)
+ self.join_columns = [
+ str(col).lower() if self.cast_column_names_lower else str(col)
+ for col in join_columns
+ ]
- return join_columns
+ self._any_dupes = False
+ self.df1 = df1
+ self.df2 = df2
+ self.df1_name = df1_name
+ self.df2_name = df2_name
+ self.abs_tol = abs_tol
+ self.rel_tol = rel_tol
+ self.ignore_spaces = ignore_spaces
+ self.ignore_case = ignore_case
+ self.df1_unq_rows = self.df2_unq_rows = self.intersect_rows = None
+ self.column_stats = []
+ self._compare(ignore_spaces, ignore_case)
@property
- def columns_in_both(self) -> Set[str]:
- """set[str]: Get columns in both dataframes"""
- return set(self.base_df.columns) & set(self.compare_df.columns)
+ def df1(self):
+ return self._df1
+
+ @df1.setter
+ def df1(self, df1):
+ """Check that it is a dataframe and has the join columns"""
+ self._df1 = df1
+ self._validate_dataframe(
+ "df1", cast_column_names_lower=self.cast_column_names_lower
+ )
@property
- def columns_compared(self) -> List[str]:
- """list[str]: Get columns to be compared in both dataframes (all
- columns in both excluding the join key(s)"""
- return [
- column
- for column in list(self.columns_in_both)
- if column not in self._join_column_names
- ]
+ def df2(self):
+ return self._df2
+
+ @df2.setter
+ def df2(self, df2):
+ """Check that it is a dataframe and has the join columns"""
+ self._df2 = df2
+ self._validate_dataframe(
+ "df2", cast_column_names_lower=self.cast_column_names_lower
+ )
- @property
- def columns_only_base(self) -> Set[str]:
- """set[str]: Get columns that are unique to the base dataframe"""
- return set(self.base_df.columns) - set(self.compare_df.columns)
+ def _validate_dataframe(self, index, cast_column_names_lower=True):
+ """Check that it is a dataframe and has the join columns
- @property
- def columns_only_compare(self) -> Set[str]:
- """set[str]: Get columns that are unique to the compare dataframe"""
- return set(self.compare_df.columns) - set(self.base_df.columns)
+ Parameters
+ ----------
+ index : str
+ The "index" of the dataframe - df1 or df2.
+ cast_column_names_lower: bool, optional
+            Boolean indicator that controls whether column names will be cast into lower case
+ """
+ dataframe = getattr(self, index)
+ if not isinstance(dataframe, (ps.DataFrame)):
+ raise TypeError(f"{index} must be a pyspark.pandas.frame.DataFrame")
- @property
- def base_row_count(self) -> int:
- """int: Get the count of rows in the de-duped base dataframe"""
- if self._base_row_count is None:
- self._base_row_count = self.base_df.count()
+ if cast_column_names_lower:
+ dataframe.columns = [str(col).lower() for col in dataframe.columns]
+ else:
+ dataframe.columns = [str(col) for col in dataframe.columns]
+ # Check if join_columns are present in the dataframe
+ if not set(self.join_columns).issubset(set(dataframe.columns)):
+ raise ValueError(f"{index} must have all columns from join_columns")
- return self._base_row_count
+ if len(set(dataframe.columns)) < len(dataframe.columns):
+ raise ValueError(f"{index} must have unique column names")
- @property
- def compare_row_count(self) -> int:
- """int: Get the count of rows in the de-duped compare dataframe"""
- if self._compare_row_count is None:
- self._compare_row_count = self.compare_df.count()
+ if len(dataframe.drop_duplicates(subset=self.join_columns)) < len(dataframe):
+ self._any_dupes = True
- return self._compare_row_count
+ def _compare(self, ignore_spaces, ignore_case):
+ """Actually run the comparison. This tries to run df1.equals(df2)
+ first so that if they're truly equal we can tell.
- @property
- def common_row_count(self) -> int:
- """int: Get the count of rows in common between base and compare dataframes"""
- if self._common_row_count is None:
- common_rows = self._get_or_create_joined_dataframe()
- self._common_row_count = common_rows.count()
-
- return self._common_row_count
-
- def _get_unq_base_rows(self) -> "pyspark.sql.DataFrame":
- """Get the rows only from base data frame"""
- return self.base_df.select(self._join_column_names).subtract(
- self.compare_df.select(self._join_column_names)
+        This method will log out information about what is different between
+        the two dataframes.
+ """
+ LOG.debug("Checking equality")
+ if self.df1.equals(self.df2).all().all():
+ LOG.info("df1 pyspark.pandas.frame.DataFrame.equals df2")
+ else:
+ LOG.info("df1 does not pyspark.pandas.frame.DataFrame.equals df2")
+ LOG.info(f"Number of columns in common: {len(self.intersect_columns())}")
+ LOG.debug("Checking column overlap")
+ for col in self.df1_unq_columns():
+ LOG.info(f"Column in df1 and not in df2: {col}")
+ LOG.info(
+ f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}"
)
-
- def _get_compare_rows(self) -> "pyspark.sql.DataFrame":
- """Get the rows only from compare data frame"""
- return self.compare_df.select(self._join_column_names).subtract(
- self.base_df.select(self._join_column_names)
+ for col in self.df2_unq_columns():
+ LOG.info(f"Column in df2 and not in df1: {col}")
+ LOG.info(
+ f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}"
)
+ # cache
+ self.df1.spark.cache()
+ self.df2.spark.cache()
+
+ LOG.debug("Merging dataframes")
+ self._dataframe_merge(ignore_spaces)
+ self._intersect_compare(ignore_spaces, ignore_case)
+ if self.matches():
+ LOG.info("df1 matches df2")
+ else:
+ LOG.info("df1 does not match df2")
- def _print_columns_summary(self, myfile: TextIO) -> None:
- """Prints the column summary details"""
- print("\n****** Column Summary ******", file=myfile)
- print(
- f"Number of columns in common with matching schemas: {len(self._columns_with_matching_schema())}",
- file=myfile,
- )
- print(
- f"Number of columns in common with schema differences: {len(self._columns_with_schemadiff())}",
- file=myfile,
- )
- print(
- f"Number of columns in base but not compare: {len(self.columns_only_base)}",
- file=myfile,
- )
- print(
- f"Number of columns in compare but not base: {len(self.columns_only_compare)}",
- file=myfile,
- )
+ def df1_unq_columns(self):
+ """Get columns that are unique to df1"""
+ return OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+
+ def df2_unq_columns(self):
+ """Get columns that are unique to df2"""
+ return OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+
+ def intersect_columns(self):
+ """Get columns that are shared between the two dataframes"""
+ return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)
+
+ def _dataframe_merge(self, ignore_spaces):
+ """Merge df1 to df2 on the join columns, to get df1 - df2, df2 - df1
+ and df1 & df2
+ """
+
+ LOG.debug("Outer joining")
+
+ df1 = self.df1.copy()
+ df2 = self.df2.copy()
+
+ if self._any_dupes:
+ LOG.debug("Duplicate rows found, deduping by order of remaining fields")
+ temp_join_columns = list(self.join_columns)
- def _print_only_columns(self, base_or_compare: str, myfile: TextIO) -> None:
- """Prints the columns and data types only in either the base or compare datasets"""
+ # Create order column for uniqueness of match
+ order_column = temp_column_name(df1, df2)
+ df1[order_column] = generate_id_within_group(df1, temp_join_columns)
+ df2[order_column] = generate_id_within_group(df2, temp_join_columns)
+ temp_join_columns.append(order_column)
- if base_or_compare.upper() == "BASE":
- columns = self.columns_only_base
- df = self.base_df
- elif base_or_compare.upper() == "COMPARE":
- columns = self.columns_only_compare
- df = self.compare_df
+ params = {"on": temp_join_columns}
else:
- raise ValueError(
- f'base_or_compare must be BASE or COMPARE, but was "{base_or_compare}"'
- )
+ params = {"on": self.join_columns}
- # If there are no columns only in this dataframe, don't display this section
- if not columns:
- return
-
- max_length = max([len(col) for col in columns] + [11])
- format_pattern = f"{{:{max_length}s}}"
-
- print(f"\n****** Columns In {base_or_compare.title()} Only ******", file=myfile)
- print((format_pattern + " Dtype").format("Column Name"), file=myfile)
- print("-" * max_length + " -------------", file=myfile)
-
- for column in columns:
- col_type = df.select(column).dtypes[0][1]
- print((format_pattern + " {:13s}").format(column, col_type), file=myfile)
-
- def _columns_with_matching_schema(self) -> Dict[str, str]:
- """This function will identify the columns which has matching schema"""
- col_schema_match = {}
- base_columns_dict = dict(self.base_df.dtypes)
- compare_columns_dict = dict(self.compare_df.dtypes)
-
- for base_row, base_type in base_columns_dict.items():
- if base_row in compare_columns_dict:
- compare_column_type = compare_columns_dict.get(base_row)
- if compare_column_type is not None and base_type in compare_column_type:
- col_schema_match[base_row] = compare_column_type
-
- return col_schema_match
-
- def _columns_with_schemadiff(self) -> Dict[str, Dict[str, str]]:
- """This function will identify the columns which has different schema"""
- col_schema_diff = {}
- base_columns_dict = dict(self.base_df.dtypes)
- compare_columns_dict = dict(self.compare_df.dtypes)
-
- for base_row, base_type in base_columns_dict.items():
- if base_row in compare_columns_dict:
- compare_column_type = compare_columns_dict.get(base_row)
- if (
- compare_column_type is not None
- and base_type not in compare_column_type
- ):
- col_schema_diff[base_row] = dict(
- base_type=base_type,
- compare_type=compare_column_type,
- )
- return col_schema_diff
+ if ignore_spaces:
+ for column in self.join_columns:
+ if df1[column].dtype.kind == "O":
+ df1[column] = df1[column].str.strip()
+ if df2[column].dtype.kind == "O":
+ df2[column] = df2[column].str.strip()
- @property
- def rows_both_mismatch(self) -> Optional["pyspark.sql.DataFrame"]:
- """pyspark.sql.DataFrame: Returns all rows in both dataframes that have mismatches"""
- if self._all_rows_mismatched is None:
- self._merge_dataframes()
+ non_join_columns = (
+ OrderedSet(df1.columns) | OrderedSet(df2.columns)
+ ) - OrderedSet(self.join_columns)
- return self._all_rows_mismatched
+ for c in non_join_columns:
+ df1.rename(columns={c: c + "_" + self.df1_name}, inplace=True)
+ df2.rename(columns={c: c + "_" + self.df2_name}, inplace=True)
- @property
- def rows_both_all(self) -> Optional["pyspark.sql.DataFrame"]:
- """pyspark.sql.DataFrame: Returns all rows in both dataframes"""
- if self._all_matched_rows is None:
- self._merge_dataframes()
+ # generate merge indicator
+ df1["_merge_left"] = True
+ df2["_merge_right"] = True
- return self._all_matched_rows
+ for c in self.join_columns:
+ df1.rename(columns={c: c + "_" + self.df1_name}, inplace=True)
+ df2.rename(columns={c: c + "_" + self.df2_name}, inplace=True)
- @property
- def rows_only_base(self) -> "pyspark.sql.DataFrame":
- """pyspark.sql.DataFrame: Returns rows only in the base dataframe"""
- if not self._rows_only_base:
- base_rows = self._get_unq_base_rows()
- base_rows.createOrReplaceTempView("baseRows")
- self.base_df.createOrReplaceTempView("baseTable")
- join_condition = " AND ".join(
- [
- "A.`" + name + "`<=>B.`" + name + "`"
- for name in self._join_column_names
- ]
- )
- sql_query = "select A.* from baseTable as A, baseRows as B where {}".format(
- join_condition
- )
- self._rows_only_base = self.spark.sql(sql_query)
+ # cache
+ df1.spark.cache()
+ df2.spark.cache()
- if self.cache_intermediates:
- self._rows_only_base.cache().count()
+ # NULL SAFE Outer join using ON
+ on = " and ".join(
+ [
+ f"df1.`{c}_{self.df1_name}` <=> df2.`{c}_{self.df2_name}`"
+ for c in params["on"]
+ ]
+ )
+ outer_join = ps.sql(
+ """
+ SELECT * FROM
+ {df1} df1 FULL OUTER JOIN {df2} df2
+ ON
+ """
+ + on,
+ df1=df1,
+ df2=df2,
+ )
- return self._rows_only_base
+ outer_join["_merge"] = None # initialize col
- @property
- def rows_only_compare(self) -> Optional["pyspark.sql.DataFrame"]:
- """pyspark.sql.DataFrame: Returns rows only in the compare dataframe"""
- if not self._rows_only_compare:
- compare_rows = self._get_compare_rows()
- compare_rows.createOrReplaceTempView("compareRows")
- self.compare_df.createOrReplaceTempView("compareTable")
- where_condition = " AND ".join(
+ # process merge indicator
+ outer_join["_merge"] = outer_join._merge.mask(
+ (outer_join["_merge_left"] == True) & (outer_join["_merge_right"] == True),
+ "both",
+ )
+ outer_join["_merge"] = outer_join._merge.mask(
+ (outer_join["_merge_left"] == True) & (outer_join["_merge_right"] != True),
+ "left_only",
+ )
+ outer_join["_merge"] = outer_join._merge.mask(
+ (outer_join["_merge_left"] != True) & (outer_join["_merge_right"] == True),
+ "right_only",
+ )
+
+ # Clean up temp columns for duplicate row matching
+ if self._any_dupes:
+ outer_join = outer_join.drop(
[
- "A.`" + name + "`<=>B.`" + name + "`"
- for name in self._join_column_names
- ]
+ order_column + "_" + self.df1_name,
+ order_column + "_" + self.df2_name,
+ ],
+ axis=1,
)
- sql_query = (
- "select A.* from compareTable as A, compareRows as B where {}".format(
- where_condition
- )
+ df1 = df1.drop(
+ [
+ order_column + "_" + self.df1_name,
+ order_column + "_" + self.df2_name,
+ ],
+ axis=1,
+ )
+ df2 = df2.drop(
+ [
+ order_column + "_" + self.df1_name,
+ order_column + "_" + self.df2_name,
+ ],
+ axis=1,
)
- self._rows_only_compare = self.spark.sql(sql_query)
- if self.cache_intermediates:
- self._rows_only_compare.cache().count()
+ df1_cols = get_merged_columns(df1, outer_join, self.df1_name)
+ df2_cols = get_merged_columns(df2, outer_join, self.df2_name)
- return self._rows_only_compare
+ LOG.debug("Selecting df1 unique rows")
+ self.df1_unq_rows = outer_join[outer_join["_merge"] == "left_only"][
+ df1_cols
+ ].copy()
- def _generate_select_statement(self, match_data: bool = True) -> str:
- """This function is to generate the select statement to be used later in the query."""
- base_only = list(set(self.base_df.columns) - set(self.compare_df.columns))
- compare_only = list(set(self.compare_df.columns) - set(self.base_df.columns))
- sorted_list = sorted(list(chain(base_only, compare_only, self.columns_in_both)))
- select_statement = ""
+ LOG.debug("Selecting df2 unique rows")
+ self.df2_unq_rows = outer_join[outer_join["_merge"] == "right_only"][
+ df2_cols
+ ].copy()
- for column_name in sorted_list:
- if column_name in self.columns_compared:
- if match_data:
- select_statement = select_statement + ",".join(
- [self._create_case_statement(name=column_name)]
- )
- else:
- select_statement = select_statement + ",".join(
- [self._create_select_statement(name=column_name)]
- )
- elif column_name in base_only:
- select_statement = select_statement + ",".join(
- ["A.`" + column_name + "`"]
- )
+ LOG.info(f"Number of rows in df1 and not in df2: {len(self.df1_unq_rows)}")
+ LOG.info(f"Number of rows in df2 and not in df1: {len(self.df2_unq_rows)}")
- elif column_name in compare_only:
- if match_data:
- select_statement = select_statement + ",".join(
- ["B.`" + column_name + "`"]
- )
- else:
- select_statement = select_statement + ",".join(
- ["A.`" + column_name + "`"]
- )
- elif column_name in self._join_column_names:
- select_statement = select_statement + ",".join(
- ["A.`" + column_name + "`"]
+ LOG.debug("Selecting intersecting rows")
+ self.intersect_rows = outer_join[outer_join["_merge"] == "both"].copy()
+ LOG.info(
+ "Number of rows in df1 and df2 (not necessarily equal): {len(self.intersect_rows)}"
+ )
+ # cache
+ self.intersect_rows.spark.cache()
+
+ def _intersect_compare(self, ignore_spaces, ignore_case):
+ """Run the comparison on the intersect dataframe
+
+ This loops through all columns that are shared between df1 and df2, and
+ creates a column column_match which is True for matches, False
+ otherwise.
+ """
+ LOG.debug("Comparing intersection")
+ row_cnt = len(self.intersect_rows)
+ for column in self.intersect_columns():
+ if column in self.join_columns:
+ match_cnt = row_cnt
+ col_match = ""
+ max_diff = 0
+ null_diff = 0
+ else:
+ col_1 = column + "_" + self.df1_name
+ col_2 = column + "_" + self.df2_name
+ col_match = column + "_match"
+ self.intersect_rows[col_match] = columns_equal(
+ self.intersect_rows[col_1],
+ self.intersect_rows[col_2],
+ self.rel_tol,
+ self.abs_tol,
+ ignore_spaces,
+ ignore_case,
+ )
+ match_cnt = self.intersect_rows[col_match].sum()
+ max_diff = calculate_max_diff(
+ self.intersect_rows[col_1], self.intersect_rows[col_2]
)
- if column_name != sorted_list[-1]:
- select_statement = select_statement + " , "
+ try:
+ null_diff = (
+ (self.intersect_rows[col_1].isnull())
+ ^ (self.intersect_rows[col_2].isnull())
+ ).sum()
+ except TypeError: # older pyspark compatibility
+ temp_null_diff = self.intersect_rows[[col_1, col_2]].isnull()
+ null_diff = (temp_null_diff[col_1] != temp_null_diff[col_2]).sum()
+
+ if row_cnt > 0:
+ match_rate = float(match_cnt) / row_cnt
+ else:
+ match_rate = 0
+ LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")
+
+ self.column_stats.append(
+ {
+ "column": column,
+ "match_column": col_match,
+ "match_cnt": match_cnt,
+ "unequal_cnt": row_cnt - match_cnt,
+ "dtype1": str(self.df1[column].dtype),
+ "dtype2": str(self.df2[column].dtype),
+ "all_match": all(
+ (
+ self.df1[column].dtype == self.df2[column].dtype,
+ row_cnt == match_cnt,
+ )
+ ),
+ "max_diff": max_diff,
+ "null_diff": null_diff,
+ }
+ )
- return select_statement
+ def all_columns_match(self):
+ """Whether the columns all match in the dataframes"""
+ return self.df1_unq_columns() == self.df2_unq_columns() == set()
- def _merge_dataframes(self) -> None:
- """Merges the two dataframes and creates self._all_matched_rows and self._all_rows_mismatched."""
- full_joined_dataframe = self._get_or_create_joined_dataframe()
- full_joined_dataframe.createOrReplaceTempView("full_matched_table")
+ def all_rows_overlap(self):
+ """Whether the rows are all present in both dataframes
- select_statement = self._generate_select_statement(False)
- select_query = """SELECT {} FROM full_matched_table A""".format(
- select_statement
- )
- self._all_matched_rows = self.spark.sql(select_query).orderBy(
- self._join_column_names # type: ignore[arg-type]
- )
- self._all_matched_rows.createOrReplaceTempView("matched_table")
+ Returns
+ -------
+ bool
+ True if all rows in df1 are in df2 and vice versa (based on
+ existence for join option)
+ """
+ return len(self.df1_unq_rows) == len(self.df2_unq_rows) == 0
- where_cond = " OR ".join(
- ["A.`" + name + "_match`= False" for name in self.columns_compared]
- )
- mismatch_query = """SELECT * FROM matched_table A WHERE {}""".format(where_cond)
- self._all_rows_mismatched = self.spark.sql(mismatch_query).orderBy(
- self._join_column_names # type: ignore[arg-type]
- )
+ def count_matching_rows(self):
+ """Count the number of rows match (on overlapping fields)
- def _get_or_create_joined_dataframe(self) -> "pyspark.sql.DataFrame":
- if self._joined_dataframe is None:
- join_condition = " AND ".join(
- [
- "A.`" + name + "`<=>B.`" + name + "`"
- for name in self._join_column_names
- ]
+ Returns
+ -------
+ int
+ Number of matching rows
+ """
+ conditions = []
+ match_columns = []
+ for column in self.intersect_columns():
+ if column not in self.join_columns:
+ match_columns.append(column + "_match")
+ conditions.append(f"`{column}_match` == True")
+ if len(conditions) > 0:
+ match_columns_count = (
+ self.intersect_rows[match_columns]
+ .query(" and ".join(conditions))
+ .shape[0]
)
- select_statement = self._generate_select_statement(match_data=True)
+ else:
+ match_columns_count = 0
+ return match_columns_count
- self.base_df.createOrReplaceTempView("base_table")
- self.compare_df.createOrReplaceTempView("compare_table")
+ def intersect_rows_match(self):
+ """Check whether the intersect rows all match"""
+ actual_length = self.intersect_rows.shape[0]
+ return self.count_matching_rows() == actual_length
- join_query = r"""
- SELECT {}
- FROM base_table A
- JOIN compare_table B
- ON {}""".format(
- select_statement, join_condition
- )
+ def matches(self, ignore_extra_columns=False):
+ """Return True or False if the dataframes match.
- self._joined_dataframe = self.spark.sql(join_query)
- if self.cache_intermediates:
- self._joined_dataframe.cache()
- self._common_row_count = self._joined_dataframe.count()
+ Parameters
+ ----------
+ ignore_extra_columns : bool
+ Ignores any columns in one dataframe and not in the other.
+ """
+ if not ignore_extra_columns and not self.all_columns_match():
+ return False
+ elif not self.all_rows_overlap():
+ return False
+ elif not self.intersect_rows_match():
+ return False
+ else:
+ return True
- return self._joined_dataframe
+ def subset(self):
+ """Return True if dataframe 2 is a subset of dataframe 1.
- def _print_num_of_rows_with_column_equality(self, myfile: TextIO) -> None:
- # match_dataframe contains columns from both dataframes with flag to indicate if columns matched
- match_dataframe = self._get_or_create_joined_dataframe().select(
- *self.columns_compared
- )
- match_dataframe.createOrReplaceTempView("matched_df")
+ Dataframe 2 is considered a subset if all of its columns are in
+ dataframe 1, and all of its rows match rows in dataframe 1 for the
+ shared columns.
+ """
+ if not self.df2_unq_columns() == set():
+ return False
+ elif not len(self.df2_unq_rows) == 0:
+ return False
+ elif not self.intersect_rows_match():
+ return False
+ else:
+ return True
- where_cond = " AND ".join(
- [
- "A.`" + name + "`=" + str(MatchType.MATCH.value)
- for name in self.columns_compared
+ def sample_mismatch(self, column, sample_count=10, for_display=False):
+ """Returns a sample sub-dataframe which contains the identifying
+ columns, and df1 and df2 versions of the column.
+
+ Parameters
+ ----------
+ column : str
+ The raw column name (i.e. without ``_df1`` appended)
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+ for_display : bool, optional
+ Whether this is just going to be used for display (overwrite the
+ column names)
+
+ Returns
+ -------
+ pyspark.pandas.frame.DataFrame
+ A sample of the intersection dataframe, containing only the
+ "pertinent" columns, for rows that don't match on the provided
+ column.
+ """
+ row_cnt = self.intersect_rows.shape[0]
+ col_match = self.intersect_rows[column + "_match"]
+ match_cnt = col_match.sum()
+ sample_count = min(sample_count, row_cnt - match_cnt)
+ sample = self.intersect_rows[~col_match].head(sample_count)
+
+ for c in self.join_columns:
+ sample[c] = sample[c + "_" + self.df1_name]
+
+ return_cols = self.join_columns + [
+ column + "_" + self.df1_name,
+ column + "_" + self.df2_name,
+ ]
+ to_return = sample[return_cols]
+ if for_display:
+ to_return.columns = self.join_columns + [
+ column + " (" + self.df1_name + ")",
+ column + " (" + self.df2_name + ")",
]
- )
- match_query = (
- r"""SELECT count(*) AS row_count FROM matched_df A WHERE {}""".format(
- where_cond
- )
- )
- all_rows_matched = self.spark.sql(match_query)
- all_rows_matched_head = all_rows_matched.head()
- matched_rows = (
- all_rows_matched_head[0] if all_rows_matched_head is not None else 0
- )
+ return to_return
- print("\n****** Row Comparison ******", file=myfile)
- print(
- f"Number of rows with some columns unequal: {self.common_row_count - matched_rows}",
- file=myfile,
- )
- print(f"Number of rows with all columns equal: {matched_rows}", file=myfile)
+ def all_mismatch(self, ignore_matching_cols=False):
+ """All rows with any columns that have a mismatch. Returns all df1 and df2 versions of the columns and join
+ columns.
+
+ Parameters
+ ----------
+ ignore_matching_cols : bool, optional
+            If True, columns whose values all match are excluded from the output. The default is False.
- def _populate_columns_match_dict(self) -> None:
+ Returns
+ -------
+ pyspark.pandas.frame.DataFrame
+            All rows of the intersection dataframe that have at least one mismatched column.
"""
- side effects:
- columns_match_dict assigned to { column -> match_type_counts }
- where:
- column (string): Name of a column that exists in both the base and comparison columns
- match_type_counts (list of int with size = len(MatchType)): The number of each match type seen for this column (in order of the MatchType enum values)
+ match_list = []
+ return_list = []
+ for col in self.intersect_rows.columns:
+ if col.endswith("_match"):
+ orig_col_name = col[:-6]
+
+ col_comparison = columns_equal(
+ self.intersect_rows[orig_col_name + "_" + self.df1_name],
+ self.intersect_rows[orig_col_name + "_" + self.df2_name],
+ self.rel_tol,
+ self.abs_tol,
+ self.ignore_spaces,
+ self.ignore_case,
+ )
+
+ if not ignore_matching_cols or (
+ ignore_matching_cols and not col_comparison.all()
+ ):
+ LOG.debug(f"Adding column {orig_col_name} to the result.")
+ match_list.append(col)
+ return_list.extend(
+ [
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ ]
+ )
+ elif ignore_matching_cols:
+ LOG.debug(
+ f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
+ )
+
+ mm_bool = self.intersect_rows[match_list].T.all()
- returns: None
+ updated_join_columns = []
+ for c in self.join_columns:
+ updated_join_columns.append(c + "_" + self.df1_name)
+ updated_join_columns.append(c + "_" + self.df2_name)
+
+ return self.intersect_rows[~mm_bool][updated_join_columns + return_list]
+
+ def report(self, sample_count=10, column_count=10, html_file=None):
+ """Returns a string representation of a report. The representation can
+ then be printed or saved to a file.
+
+ Parameters
+ ----------
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+
+ column_count : int, optional
+ The number of columns to display in the sample records output. Defaults to 10.
+
+ html_file : str, optional
+ HTML file name to save report output to. If ``None`` the file creation will be skipped.
+
+ Returns
+ -------
+ str
+ The report, formatted kinda nicely.
"""
+ # Header
+ report = render("header.txt")
+ df_header = ps.DataFrame(
+ {
+ "DataFrame": [self.df1_name, self.df2_name],
+ "Columns": [self.df1.shape[1], self.df2.shape[1]],
+ "Rows": [self.df1.shape[0], self.df2.shape[0]],
+ }
+ )
+ report += df_header[["DataFrame", "Columns", "Rows"]].to_string()
+ report += "\n\n"
+
+ # Column Summary
+ report += render(
+ "column_summary.txt",
+ len(self.intersect_columns()),
+ len(self.df1_unq_columns()),
+ len(self.df2_unq_columns()),
+ self.df1_name,
+ self.df2_name,
+ )
- match_dataframe = self._get_or_create_joined_dataframe().select(
- *self.columns_compared
+ # Row Summary
+ match_on = ", ".join(self.join_columns)
+ report += render(
+ "row_summary.txt",
+ match_on,
+ self.abs_tol,
+ self.rel_tol,
+ self.intersect_rows.shape[0],
+ self.df1_unq_rows.shape[0],
+ self.df2_unq_rows.shape[0],
+ self.intersect_rows.shape[0] - self.count_matching_rows(),
+ self.count_matching_rows(),
+ self.df1_name,
+ self.df2_name,
+ "Yes" if self._any_dupes else "No",
)
- def helper(c: str) -> "pyspark.sql.Column":
- # Create a predicate for each match type, comparing column values to the match type value
- predicates = [F.col(c) == k.value for k in MatchType]
- # Create a tuple(number of match types found for each match type in this column)
- return F.struct(
- [F.lit(F.sum(pred.cast("integer"))) for pred in predicates]
- ).alias(c)
-
- # For each column, create a single tuple. This tuple's values correspond to the number of times
- # each match type appears in that column
- match_data_agg = match_dataframe.agg(
- *[helper(col) for col in self.columns_compared]
- ).collect()
- match_data = match_data_agg[0]
-
- for c in self.columns_compared:
- self.columns_match_dict[c] = match_data[c]
-
- def _create_select_statement(self, name: str) -> str:
- if self._known_differences:
- match_type_comparison = ""
- for k in MatchType:
- match_type_comparison += (
- " WHEN (A.`{name}`={match_value}) THEN '{match_name}'".format(
- name=name, match_value=str(k.value), match_name=k.name
- )
+ # Column Matching
+ report += render(
+ "column_comparison.txt",
+ len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
+ len([col for col in self.column_stats if col["unequal_cnt"] == 0]),
+ sum([col["unequal_cnt"] for col in self.column_stats]),
+ )
+
+ match_stats = []
+ match_sample = []
+ any_mismatch = False
+ for column in self.column_stats:
+ if not column["all_match"]:
+ any_mismatch = True
+ match_stats.append(
+ {
+ "Column": column["column"],
+ f"{self.df1_name} dtype": column["dtype1"],
+ f"{self.df2_name} dtype": column["dtype2"],
+ "# Unequal": column["unequal_cnt"],
+ "Max Diff": column["max_diff"],
+ "# Null Diff": column["null_diff"],
+ }
)
- return "A.`{name}_base`, A.`{name}_compare`, (CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END) AS `{name}_match`, (CASE {match_type_comparison} ELSE 'UNDEFINED' END) AS `{name}_match_type` ".format(
- name=name,
- match_failure=MatchType.MISMATCH.value,
- match_type_comparison=match_type_comparison,
+ if column["unequal_cnt"] > 0:
+ match_sample.append(
+ self.sample_mismatch(
+ column["column"], sample_count, for_display=True
+ )
+ )
+
+ if any_mismatch:
+ report += "Columns with Unequal Values or Types\n"
+ report += "------------------------------------\n"
+ report += "\n"
+ df_match_stats = ps.DataFrame(match_stats)
+ df_match_stats.sort_values("Column", inplace=True)
+ # Have to specify again for sorting
+ report += df_match_stats[
+ [
+ "Column",
+ f"{self.df1_name} dtype",
+ f"{self.df2_name} dtype",
+ "# Unequal",
+ "Max Diff",
+ "# Null Diff",
+ ]
+ ].to_string()
+ report += "\n\n"
+
+ if sample_count > 0:
+ report += "Sample Rows with Unequal Values\n"
+ report += "-------------------------------\n"
+ report += "\n"
+ for sample in match_sample:
+ report += sample.to_string()
+ report += "\n\n"
+
+ if min(sample_count, self.df1_unq_rows.shape[0]) > 0:
+ report += (
+ f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n"
)
- else:
- return "A.`{name}_base`, A.`{name}_compare`, CASE WHEN (A.`{name}`={match_failure}) THEN False ELSE True END AS `{name}_match` ".format(
- name=name, match_failure=MatchType.MISMATCH.value
+ report += (
+ f"---------------------------------------{'-' * len(self.df1_name)}\n"
)
+ report += "\n"
+ columns = self.df1_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df1_unq_rows.shape[0])
+ report += self.df1_unq_rows.head(unq_count)[columns].to_string()
+ report += "\n\n"
+
+ if min(sample_count, self.df2_unq_rows.shape[0]) > 0:
+ report += (
+ f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df2_name)}\n"
+ )
+ report += "\n"
+ columns = self.df2_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df2_unq_rows.shape[0])
+ report += self.df2_unq_rows.head(unq_count)[columns].to_string()
+ report += "\n\n"
- def _create_case_statement(self, name: str) -> str:
- equal_comparisons = ["(A.`{name}` IS NULL AND B.`{name}` IS NULL)"]
- known_diff_comparisons = ["(FALSE)"]
-
- base_dtype = [d[1] for d in self.base_df.dtypes if d[0] == name][0]
- compare_dtype = [d[1] for d in self.compare_df.dtypes if d[0] == name][0]
-
- if _is_comparable(base_dtype, compare_dtype):
- if (base_dtype in NUMERIC_SPARK_TYPES) and (
- compare_dtype in NUMERIC_SPARK_TYPES
- ): # numeric tolerance comparison
- equal_comparisons.append(
- "((A.`{name}`=B.`{name}`) OR ((abs(A.`{name}`-B.`{name}`))<=("
- + str(self.abs_tol)
- + "+("
- + str(self.rel_tol)
- + "*abs(A.`{name}`)))))"
- )
- else: # non-numeric comparison
- equal_comparisons.append("((A.`{name}`=B.`{name}`))")
-
- if self._known_differences:
- new_input = "B.`{name}`"
- for kd in self._known_differences:
- if compare_dtype in kd["types"]:
- if "flags" in kd and "nullcheck" in kd["flags"]:
- known_diff_comparisons.append(
- "(("
- + kd["transformation"].format(new_input, input=new_input)
- + ") is null AND A.`{name}` is null)"
- )
- else:
- known_diff_comparisons.append(
- "(("
- + kd["transformation"].format(new_input, input=new_input)
- + ") = A.`{name}`)"
- )
+ if html_file:
+            html_report = report.replace("\n", "<br>").replace(" ", "&nbsp;")
+            html_report = f"<pre>{html_report}</pre>"
+ with open(html_file, "w") as f:
+ f.write(html_report)
- case_string = (
- "( CASE WHEN ("
- + " OR ".join(equal_comparisons)
- + ") THEN {match_success} WHEN ("
- + " OR ".join(known_diff_comparisons)
- + ") THEN {match_known_difference} ELSE {match_failure} END) "
- + "AS `{name}`, A.`{name}` AS `{name}_base`, B.`{name}` AS `{name}_compare`"
- )
+ return report
- return case_string.format(
- name=name,
- match_success=MatchType.MATCH.value,
- match_known_difference=MatchType.KNOWN_DIFFERENCE.value,
- match_failure=MatchType.MISMATCH.value,
- )
- def _print_row_summary(self, myfile: TextIO) -> None:
- base_df_cnt = self.base_df.count()
- compare_df_cnt = self.compare_df.count()
- base_df_with_dup_cnt = self._original_base_df.count()
- compare_df_with_dup_cnt = self._original_compare_df.count()
-
- print("\n****** Row Summary ******", file=myfile)
- print(f"Number of rows in common: {self.common_row_count}", file=myfile)
- print(
- f"Number of rows in base but not compare: {base_df_cnt - self.common_row_count}",
- file=myfile,
- )
- print(
- f"Number of rows in compare but not base: {compare_df_cnt - self.common_row_count}",
- file=myfile,
- )
- print(
- f"Number of duplicate rows found in base: {base_df_with_dup_cnt - base_df_cnt}",
- file=myfile,
- )
- print(
- f"Number of duplicate rows found in compare: {compare_df_with_dup_cnt - compare_df_cnt}",
- file=myfile,
- )
+def render(filename, *fields):
+ """Renders out an individual template. This basically just reads in a
+ template file, and applies ``.format()`` on the fields.
- def _print_schema_diff_details(self, myfile: TextIO) -> None:
- schema_diff_dict = self._columns_with_schemadiff()
+ Parameters
+ ----------
+ filename : str
+ The file that contains the template. Will automagically prepend the
+ templates directory before opening
+ fields : list
+ Fields to be rendered out in the template
+
+ Returns
+ -------
+ str
+ The fully rendered out file.
+ """
+ this_dir = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(this_dir, "templates", filename)) as file_open:
+ return file_open.read().format(*fields)
- if not schema_diff_dict: # If there are no differences, don't print the section
- return
- # For columns with mismatches, what are the longest base and compare column name lengths (with minimums)?
- base_name_max = max([len(key) for key in schema_diff_dict] + [16])
- compare_name_max = max(
- [len(self._base_to_compare_name(key)) for key in schema_diff_dict] + [19]
- )
+def columns_equal(
+ col_1, col_2, rel_tol=0, abs_tol=0, ignore_spaces=False, ignore_case=False
+):
+ """Compares two columns from a dataframe, returning a True/False series,
+ with the same index as column 1.
- format_pattern = "{{:{base}s}} {{:{compare}s}}".format(
- base=base_name_max, compare=compare_name_max
- )
+ - Two nulls (np.nan) will evaluate to True.
+ - A null and a non-null value will evaluate to False.
+ - Numeric values will use the relative and absolute tolerances.
+    - Decimal values (decimal.Decimal) will be converted to floats before
+      comparing, where possible
+ - Non-numeric values (i.e. where np.isclose can't be used) will just
+ trigger True on two nulls or exact matches.
- print("\n****** Schema Differences ******", file=myfile)
- print(
- (format_pattern + " Base Dtype Compare Dtype").format(
- "Base Column Name", "Compare Column Name"
- ),
- file=myfile,
- )
- print(
- "-" * base_name_max
- + " "
- + "-" * compare_name_max
- + " ------------- -------------",
- file=myfile,
+ Parameters
+ ----------
+ col_1 : pyspark.pandas.series.Series
+ The first column to look at
+ col_2 : pyspark.pandas.series.Series
+ The second column
+ rel_tol : float, optional
+ Relative tolerance
+ abs_tol : float, optional
+ Absolute tolerance
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns
+ ignore_case : bool, optional
+ Flag to ignore the case of string columns
+
+ Returns
+ -------
+ pyspark.pandas.series.Series
+ A series of Boolean values. True == the values match, False == the
+ values don't match.
+ """
+ try:
+ compare = ((col_1 - col_2).abs() <= abs_tol + (rel_tol * col_2.abs())) | (
+ col_1.isnull() & col_2.isnull()
)
+ except TypeError:
+ if (
+ is_numeric_dtype(col_1.dtype.kind) and is_numeric_dtype(col_2.dtype.kind)
+ ) or (
+ col_1.spark.data_type.typeName() == "decimal"
+ and col_2.spark.data_type.typeName() == "decimal"
+ ):
+ compare = (
+ (col_1.astype(float) - col_2.astype(float)).abs()
+ <= abs_tol + (rel_tol * col_2.astype(float).abs())
+ ) | (col_1.astype(float).isnull() & col_2.astype(float).isnull())
+ else:
+ try:
+ col_1_temp = col_1.copy()
+ col_2_temp = col_2.copy()
+ if ignore_spaces:
+ if col_1.dtype.kind == "O":
+ col_1_temp = col_1_temp.str.strip()
+ if col_2.dtype.kind == "O":
+ col_2_temp = col_2_temp.str.strip()
+
+ if ignore_case:
+ if col_1.dtype.kind == "O":
+ col_1_temp = col_1_temp.str.upper()
+ if col_2.dtype.kind == "O":
+ col_2_temp = col_2_temp.str.upper()
+
+ if {col_1.dtype.kind, col_2.dtype.kind} == {"M", "O"}:
+ compare = compare_string_and_date_columns(col_1_temp, col_2_temp)
+ else:
+ compare = (col_1_temp == col_2_temp) | (
+ col_1_temp.isnull() & col_2_temp.isnull()
+ )
- for base_column, types in schema_diff_dict.items():
- compare_column = self._base_to_compare_name(base_column)
-
- print(
- (format_pattern + " {:13s} {:13s}").format(
- base_column,
- compare_column,
- types["base_type"],
- types["compare_type"],
- ),
- file=myfile,
- )
+ except Exception:
+ # Blanket exception should just return all False
+ compare = ps.Series(False, index=col_1.index.to_numpy())
+ return compare
- def _base_to_compare_name(self, base_name: str) -> str:
- """Translates a column name in the base dataframe to its counterpart in the
- compare dataframe, if they are different."""
- if base_name in self.column_mapping:
- return self.column_mapping[base_name]
- else:
- for name in self.join_columns:
- if base_name == name[0]:
- return name[1]
- return base_name
-
- def _print_row_matches_by_column(self, myfile: TextIO) -> None:
- self._populate_columns_match_dict()
- columns_with_mismatches = {
- key: self.columns_match_dict[key]
- for key in self.columns_match_dict
- if self.columns_match_dict[key][MatchType.MISMATCH.value]
- }
-
- # corner case: when all columns match but no rows match
- # issue: #276
- try:
- columns_fully_matching = {
- key: self.columns_match_dict[key]
- for key in self.columns_match_dict
- if sum(self.columns_match_dict[key])
- == self.columns_match_dict[key][MatchType.MATCH.value]
- }
- except TypeError:
- columns_fully_matching = {}
-
- try:
- columns_with_any_diffs = {
- key: self.columns_match_dict[key]
- for key in self.columns_match_dict
- if sum(self.columns_match_dict[key])
- != self.columns_match_dict[key][MatchType.MATCH.value]
- }
- except TypeError:
- columns_with_any_diffs = {}
- #
+def compare_string_and_date_columns(col_1, col_2):
+ """Compare a string column and date column, value-wise. This tries to
+ convert a string column to a date column and compare that way.
- base_types = {x[0]: x[1] for x in self.base_df.dtypes}
- compare_types = {x[0]: x[1] for x in self.compare_df.dtypes}
+ Parameters
+ ----------
+ col_1 : pyspark.pandas.series.Series
+ The first column to look at
+ col_2 : pyspark.pandas.series.Series
+ The second column
- print("\n****** Column Comparison ******", file=myfile)
+ Returns
+ -------
+ pyspark.pandas.series.Series
+ A series of Boolean values. True == the values match, False == the
+ values don't match.
+ """
+ if col_1.dtype.kind == "O":
+ obj_column = col_1
+ date_column = col_2
+ else:
+ obj_column = col_2
+ date_column = col_1
+
+ try:
+ compare = ps.Series(
+ (
+ (ps.to_datetime(obj_column) == date_column)
+ | (obj_column.isnull() & date_column.isnull())
+ ).to_numpy()
+ ) # force compute
+ except Exception:
+ compare = ps.Series(False, index=col_1.index.to_numpy())
+ return compare
+
+
+def get_merged_columns(original_df, merged_df, suffix):
+ """Gets the columns from an original dataframe, in the new merged dataframe
- if self._known_differences:
- print(
- f"Number of columns compared with unexpected differences in some values: {len(columns_with_mismatches)}",
- file=myfile,
- )
- print(
- f"Number of columns compared with all values equal but known differences found: {len(self.columns_compared) - len(columns_with_mismatches) - len(columns_fully_matching)}",
- file=myfile,
- )
- print(
- f"Number of columns compared with all values completely equal: {len(columns_fully_matching)}",
- file=myfile,
- )
+ Parameters
+ ----------
+ original_df : pyspark.pandas.frame.DataFrame
+ The original, pre-merge dataframe
+ merged_df : pyspark.pandas.frame.DataFrame
+ Post-merge with another dataframe, with suffixes added in.
+ suffix : str
+ What suffix was used to distinguish when the original dataframe was
+ overlapping with the other merged dataframe.
+ """
+ columns = []
+ for col in original_df.columns:
+ if col in merged_df.columns:
+ columns.append(col)
+ elif col + "_" + suffix in merged_df.columns:
+ columns.append(col + "_" + suffix)
else:
- print(
- f"Number of columns compared with some values unequal: {len(columns_with_mismatches)}",
- file=myfile,
- )
- print(
- f"Number of columns compared with all values equal: {len(columns_fully_matching)}",
- file=myfile,
- )
+ raise ValueError("Column not found: %s", col)
+ return columns
- # If all columns matched, don't print columns with unequal values
- if (not self.show_all_columns) and (
- len(columns_fully_matching) == len(self.columns_compared)
- ):
- return
- # if show_all_columns is set, set column name length maximum to max of ALL columns(with minimum)
- if self.show_all_columns:
- base_name_max = max([len(key) for key in self.columns_match_dict] + [16])
- compare_name_max = max(
- [
- len(self._base_to_compare_name(key))
- for key in self.columns_match_dict
- ]
- + [19]
- )
+def temp_column_name(*dataframes):
+ """Gets a temp column name that isn't included in columns of any dataframes
- # For columns with any differences, what are the longest base and compare column name lengths (with minimums)?
- else:
- base_name_max = max([len(key) for key in columns_with_any_diffs] + [16])
- compare_name_max = max(
- [len(self._base_to_compare_name(key)) for key in columns_with_any_diffs]
- + [19]
- )
+ Parameters
+ ----------
+ dataframes : list of pyspark.pandas.frame.DataFrame
+ The DataFrames to create a temporary column name for
- """ list of (header, condition, width, align)
- where
- header (String) : output header for a column
- condition (Bool): true if this header should be displayed
- width (Int) : width of the column
- align (Bool) : true if right-aligned
- """
- headers_columns_unequal = [
- ("Base Column Name", True, base_name_max, False),
- ("Compare Column Name", True, compare_name_max, False),
- ("Base Dtype ", True, 13, False),
- ("Compare Dtype", True, 13, False),
- ("# Matches", True, 9, True),
- ("# Known Diffs", self._known_differences is not None, 13, True),
- ("# Mismatches", True, 12, True),
- ]
- if self.match_rates:
- headers_columns_unequal.append(("Match Rate %", True, 12, True))
- headers_columns_unequal_valid = [h for h in headers_columns_unequal if h[1]]
- padding = 2 # spaces to add to left and right of each column
+ Returns
+ -------
+ str
+ String column name that looks like '_temp_x' for some integer x
+ """
+ i = 0
+ while True:
+ temp_column = f"_temp_{i}"
+ unique = True
+ for dataframe in dataframes:
+ if temp_column in dataframe.columns:
+ i += 1
+ unique = False
+ if unique:
+ return temp_column
- if self.show_all_columns:
- print("\n****** Columns with Equal/Unequal Values ******", file=myfile)
- else:
- print("\n****** Columns with Unequal Values ******", file=myfile)
- format_pattern = (" " * padding).join(
- [
- ("{:" + (">" if h[3] else "") + str(h[2]) + "}")
- for h in headers_columns_unequal_valid
- ]
- )
- print(
- format_pattern.format(*[h[0] for h in headers_columns_unequal_valid]),
- file=myfile,
- )
- print(
- format_pattern.format(
- *["-" * len(h[0]) for h in headers_columns_unequal_valid]
- ),
- file=myfile,
- )
+def calculate_max_diff(col_1, col_2):
+ """Get a maximum difference between two columns
- for column_name, column_values in sorted(
- self.columns_match_dict.items(), key=lambda i: i[0]
- ):
- num_matches = column_values[MatchType.MATCH.value]
- num_known_diffs = (
- None
- if self._known_differences is None
- else column_values[MatchType.KNOWN_DIFFERENCE.value]
- )
- num_mismatches = column_values[MatchType.MISMATCH.value]
- compare_column = self._base_to_compare_name(column_name)
-
- if num_mismatches or num_known_diffs or self.show_all_columns:
- output_row = [
- column_name,
- compare_column,
- base_types.get(column_name),
- compare_types.get(column_name),
- str(num_matches),
- str(num_mismatches),
- ]
- if self.match_rates:
- match_rate = 100 * (
- 1
- - (column_values[MatchType.MISMATCH.value] + 0.0)
- / self.common_row_count
- + 0.0
- )
- output_row.append("{:02.5f}".format(match_rate))
- if num_known_diffs is not None:
- output_row.insert(len(output_row) - 1, str(num_known_diffs))
- print(format_pattern.format(*output_row), file=myfile)
+ Parameters
+ ----------
+ col_1 : pyspark.pandas.series.Series
+ The first column
+ col_2 : pyspark.pandas.series.Series
+ The second column
- # noinspection PyUnresolvedReferences
- def report(self, file: TextIO = sys.stdout) -> None:
- """Creates a comparison report and prints it to the file specified
- (stdout by default).
+ Returns
+ -------
+ Numeric
+ Numeric field, or zero.
+ """
+ try:
+ return (col_1.astype(float) - col_2.astype(float)).abs().max()
+ except Exception:
+ return 0
- Parameters
- ----------
- file : ``file``, optional
- A filehandle to write the report to. By default, this is
- sys.stdout, printing the report to stdout. You can also redirect
- this to an output file, as in the example.
-
- Examples
- --------
- >>> with open('my_report.txt', 'w') as report_file:
- ... comparison.report(file=report_file)
- """
- self._print_columns_summary(file)
- self._print_schema_diff_details(file)
- self._print_only_columns("BASE", file)
- self._print_only_columns("COMPARE", file)
- self._print_row_summary(file)
- self._merge_dataframes()
- self._print_num_of_rows_with_column_equality(file)
- self._print_row_matches_by_column(file)
+def generate_id_within_group(dataframe, join_columns):
+ """Generate an ID column that can be used to deduplicate identical rows. The series generated
+ is the order within a unique group, and it handles nulls.
+
+ Parameters
+ ----------
+ dataframe : pyspark.pandas.frame.DataFrame
+ The dataframe to operate on
+ join_columns : list
+ List of strings which are the join columns
+
+ Returns
+ -------
+ pyspark.pandas.series.Series
+ The ID column that's unique in each group.
+ """
+ default_value = "DATACOMPY_NULL"
+ if dataframe[join_columns].isnull().any().any():
+ if (dataframe[join_columns] == default_value).any().any():
+ raise ValueError(f"{default_value} was found in your join columns")
+ return (
+ dataframe[join_columns]
+ .astype(str)
+ .fillna(default_value)
+ .groupby(join_columns)
+ .cumcount()
+ )
+ else:
+ return dataframe[join_columns].groupby(join_columns).cumcount()
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 0ac03d6d..1d25d11c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -10,7 +10,7 @@ Contents
Installation
Pandas Usage
- Spark Usage
+ Spark (Pandas on Spark) Usage
Polars Usage
Fugue Usage
Developer Instructions
diff --git a/docs/source/spark_usage.rst b/docs/source/spark_usage.rst
index 82c62722..a532316e 100644
--- a/docs/source/spark_usage.rst
+++ b/docs/source/spark_usage.rst
@@ -1,243 +1,252 @@
-Spark Usage
-===========
+Spark (Pandas on Spark) Usage
+=============================
.. important::
- With version ``v0.9.0`` SparkCompare now uses Null Safe (``<=>``) comparisons
+    With version ``v0.12.0`` the original ``SparkCompare`` is now replaced with a
+    Pandas on Spark implementation. The original ``SparkCompare`` implementation
+    differs from all the other native implementations. To align the API better
+    and keep behaviour consistent, we are deprecating the original ``SparkCompare``
+    into a new module, ``LegacySparkCompare``.
+
+    If you wish to keep using the old ``SparkCompare`` you can import it as follows:
-DataComPy's ``SparkCompare`` class will join two dataframes either on a list of join
-columns. It has the capability to map column names that may be different in each
-dataframe, including in the join columns. You are responsible for creating the
-dataframes from any source which Spark can handle and specifying a unique join
-key. If there are duplicates in either dataframe by join key, the match process
-will remove the duplicates before joining (and tell you how many duplicates were
-found).
+ .. code-block:: python
-As with the Pandas-based ``Compare`` class, comparisons will be attempted even
-if dtypes don't match. Any schema differences will be reported in the output
-as well as in any mismatch reports, so that you can assess whether or not a
-type mismatch is a problem or not.
+ import datacompy.legacy.LegacySparkCompare
+
-The main reasons why you would choose to use ``SparkCompare`` over ``Compare``
-are that your data is too large to fit into memory, or you're comparing data
-that works well in a Spark environment, like partitioned Parquet, CSV, or JSON
-files, or Cerebro tables.
+DataComPy's Pandas on Spark implementation ``SparkCompare`` (new in ``v0.12.0``)
+is a close port of the Pandas version, with a few notable differences:
+
-Basic Usage
------------
+- ``on_index`` is NOT supported (unlike in ``PandasCompare``)
+- Joining is done using ``<=>``, the null-safe equality test (see the short
+  illustration after this list).
+- In the backend we are using the Pandas on Spark API. This might be less optimal than
+ native Spark code but allows for better maintainability and readability.
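+
+The null-safe operator matters because a plain ``=`` comparison of two SQL nulls
+yields null rather than true, so rows with null join keys would never line up.
+A minimal illustration (plain Spark SQL, not part of the library, run against
+any active session):
+
+.. code-block:: python
+
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+
+    # NULL <=> NULL evaluates to true, while NULL = NULL evaluates to NULL,
+    # which is why SparkCompare joins on <=> over the join columns.
+    spark.sql(
+        "SELECT (NULL <=> NULL) AS null_safe, (NULL = NULL) AS plain_equal"
+    ).show()
+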
-.. code-block:: python
- import datetime
- import datacompy
- from pyspark.sql import Row
-
- # This example assumes you have a SparkSession named "spark" in your environment, as you
- # do when running `pyspark` from the terminal or in a Databricks notebook (Spark v2.0 and higher)
-
- data1 = [
- Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12,
- date_fld=datetime.date(2017, 1, 1)),
- Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None,
- date_fld=datetime.date(2017, 1, 1))
- ]
-
- data2 = [
- Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155),
- Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None),
- Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=1.0),
- Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12),
- Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0)
- ]
-
- base_df = spark.createDataFrame(data1)
- compare_df = spark.createDataFrame(data2)
-
- comparison = datacompy.SparkCompare(spark, base_df, compare_df, join_columns=['acct_id'])
-
- # This prints out a human-readable report summarizing differences
- comparison.report()
-
-
-Using SparkCompare on EMR or standalone Spark
----------------------------------------------
-
-1. Set proxy variables
-2. Create a virtual environment, if desired (``virtualenv venv; source venv/bin/activate``)
-3. Pip install datacompy and requirements
-4. Ensure your SPARK_HOME environment variable is set (this is probably ``/usr/lib/spark`` but may
- differ based on your installation)
-5. Augment your PYTHONPATH environment variable with
- ``export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.10.4-src.zip:$SPARK_HOME/python:$PYTHONPATH``
- (note that your version of py4j may differ depending on the version of Spark you're using)
-
-
-Using SparkCompare on Databricks
---------------------------------
-
-1. Clone this repository locally
-2. Create a datacompy egg by running ``python setup.py bdist_egg`` from the repo root directory.
-3. From the Databricks front page, click the "Library" link under the "New" section.
-4. On the New library page:
- a. Change source to "Upload Python Egg or PyPi"
- b. Under "Upload Egg", Library Name should be "datacompy"
- c. Drag the egg file in datacompy/dist/ to the "Drop library egg here to upload" box
- d. Click the "Create Library" button
-5. Once the library has been created, from the library page (which you can find in your /Users/{login} workspace),
- you can choose clusters to attach the library to.
-6. ``import datacompy`` in a notebook attached to the cluster that the library is attached to and enjoy!
-
-
-Performance Implications
-------------------------
-
-Spark scales incredibly well, so you can use ``SparkCompare`` to compare
-billions of rows of data, provided you spin up a big enough cluster. Still,
-joining billions of rows of data is an inherently large task, so there are a
-couple of things you may want to take into consideration when getting into the
-cliched realm of "big data":
-
-* ``SparkCompare`` will compare all columns in common in the dataframes and
- report on the rest. If there are columns in the data that you don't care to
- compare, use a ``select`` statement/method on the dataframe(s) to filter
- those out. Particularly when reading from wide Parquet files, this can make
- a huge difference when the columns you don't care about don't have to be
- read into memory and included in the joined dataframe.
-* For large datasets, adding ``cache_intermediates=True`` to the ``SparkCompare``
- call can help optimize performance by caching certain intermediate dataframes
- in memory, like the de-duped version of each input dataset, or the joined
- dataframe. Otherwise, Spark's lazy evaluation will recompute those each time
- it needs the data in a report or as you access instance attributes. This may
- be fine for smaller dataframes, but will be costly for larger ones. You do
- need to ensure that you have enough free cache memory before you do this, so
- this parameter is set to False by default.
-
-
-Known Differences
------------------
-
-For cases when two dataframes are expected to differ, it can be helpful to cluster detected
-differences into three categories: matches, known differences, and true mismatches. Known
-differences can be specified through an optional parameter:
+Supported Versions
+------------------
+
+.. important::
+
+    Spark will not officially support Pandas 2 until Spark 4: https://issues.apache.org/jira/browse/SPARK-44101
+
+
+Until then we will not be supporting Pandas 2 for the Pandas on Spark API implementation.
+Pandas 2 is supported for the Fugue and native Pandas implementations. If you need to use Spark DataFrames with
+Pandas 2+, consider using Fugue; otherwise downgrade to Pandas 1.5.3.
+
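+For example, to keep comparing Spark DataFrames while staying on Pandas 2, you can call the
+Fugue-backed helpers directly (a minimal sketch; it assumes ``df1`` and ``df2`` are existing
+``pyspark.sql.DataFrame`` objects joined on ``acct_id``):
+
+.. code-block:: python
+
+    from datacompy import is_match, report
+
+    # The Fugue helpers accept Spark DataFrames directly, so this path does not
+    # rely on the Pandas-on-Spark API.
+    if not is_match(df1, df2, join_columns="acct_id"):
+        print(report(df1, df2, join_columns="acct_id"))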
+
+SparkCompare Object Setup
+-------------------------
+
+There is currently only one supported method for joining your dataframes - by
+join column(s).
+
.. code-block:: python
- SparkCompare(spark, base_df, compare_df, join_columns=[...], column_mapping=[...],
- known_differences = [
- {
- 'name': "My Known Difference Name",
- 'types': ['int', 'bigint'],
- 'flags': ['nullcheck'],
- 'transformation': "case when {input}=0 then null else {input} end"
- },
- ...
- ]
+ from io import StringIO
+ import pandas as pd
+ import pyspark.pandas as ps
+ from datacompy import SparkCompare
+ from pyspark.sql import SparkSession
+
+ spark = SparkSession.builder.getOrCreate()
+
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ """
+
+ data2 = """acct_id,dollar_amt,name,float_fld
+ 10000001234,123.4,George Michael Bluth,14530.155
+ 10000001235,0.45,Michael Bluth,
+ 10000001236,1345,George Bluth,1
+ 10000001237,123456,Robert Loblaw,345.12
+ 10000001238,1.05,Loose Seal Bluth,111
+ """
+
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1)))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2)))
+
+ compare = SparkCompare(
+ df1,
+ df2,
+ join_columns='acct_id', # You can also specify a list of columns
+ abs_tol=0, # Optional, defaults to 0
+ rel_tol=0, # Optional, defaults to 0
+ df1_name='Original', # Optional, defaults to 'df1'
+ df2_name='New' # Optional, defaults to 'df2'
)
+ compare.matches(ignore_extra_columns=False)
+ # False
+ # This method prints out a human-readable report summarizing and sampling differences
+ print(compare.report())
+
+
+Reports
+-------
+
+A report is generated by calling ``SparkCompare.report()``, which returns a string.
+Here is a sample report generated by ``datacompy`` for the two tables above,
+joined on ``acct_id`` (Note: if you don't specify ``df1_name`` and/or ``df2_name``,
+then any instance of "original" or "new" in the report is replaced with "df1"
+and/or "df2".)::
+
+ DataComPy Comparison
+ --------------------
+
+ DataFrame Summary
+ -----------------
+
+ DataFrame Columns Rows
+ 0 Original 5 5
+ 1 New 4 5
+
+ Column Summary
+ --------------
+
+ Number of columns in common: 4
+ Number of columns in Original but not in New: 1
+ Number of columns in New but not in Original: 0
+
+ Row Summary
+ -----------
-The 'known_differences' parameter is a list of Python dicts with the following fields:
+ Matched on: acct_id
+ Any duplicates on match values: No
+ Absolute Tolerance: 0
+ Relative Tolerance: 0
+ Number of rows in common: 4
+ Number of rows in Original but not in New: 1
+ Number of rows in New but not in Original: 1
-============== ========= ======================================================================
-Field Required? Description
-============== ========= ======================================================================
-name yes A user-readable title for this known difference
-types yes A list of Spark data types on which this transformation can be applied
-flags no Special flags used for computing known differences
-transformation yes Spark SQL function to apply, where {input} is a cell in the comparison
-============== ========= ======================================================================
+ Number of rows with some compared columns unequal: 4
+ Number of rows with all compared columns equal: 0
-Valid flags are:
+ Column Comparison
+ -----------------
-========= =============================================================
-Flag Description
-========= =============================================================
-nullcheck Must be set when the output of the transformation can be null
-========= =============================================================
+ Number of columns compared with some values unequal: 3
+ Number of columns compared with all values equal: 1
+ Total number of values which compare unequal: 6
-Transformations are applied to the compare side only. A known difference is found when transformation(compare.cell) equals base.cell. An example comparison is shown below.
+ Columns with Unequal Values or Types
+ ------------------------------------
+
+ Column Original dtype New dtype # Unequal Max Diff # Null Diff
+ 0 dollar_amt float64 float64 1 0.0500 0
+ 2 float_fld float64 float64 3 0.0005 2
+ 1 name object object 2 NaN 0
+
+ Sample Rows with Unequal Values
+ -------------------------------
+
+ acct_id dollar_amt (Original) dollar_amt (New)
+ 0 10000001234 123.45 123.4
+
+ acct_id name (Original) name (New)
+ 0 10000001234 George Maharis George Michael Bluth
+ 3 10000001237 Bob Loblaw Robert Loblaw
+
+ acct_id float_fld (Original) float_fld (New)
+ 0 10000001234 14530.1555 14530.155
+ 1 10000001235 1.0000 NaN
+ 2 10000001236 NaN 1.000
+
+ Sample Rows Only in Original (First 10 Columns)
+ -----------------------------------------------
+
+ acct_id_df1 dollar_amt_df1 name_df1 float_fld_df1 date_fld_df1 _merge_left
+ 5 10000001239 1.05 Lucille Bluth NaN 2017-01-01 True
+
+ Sample Rows Only in New (First 10 Columns)
+ ------------------------------------------
+
+ acct_id_df2 dollar_amt_df2 name_df2 float_fld_df2 _merge_right
+ 4 10000001238 1.05 Loose Seal Bluth 111.0 True
+
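+Because ``report()`` returns a plain string, it can also be written out for later review
+(a minimal sketch; the file name is just an example):
+
+.. code-block:: python
+
+    # Persist the rendered report alongside other job output
+    with open("comparison_report.txt", "w") as report_file:
+        report_file.write(compare.report())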
+
+Convenience Methods
+-------------------
+
+There are a few convenience methods available after the comparison has been run:
+
.. code-block:: python
- import datetime
- import datacompy
- from pyspark.sql import Row
-
- base_data = [
- Row(acct_id=10000001234, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 1), tbal_cd='0001'),
- Row(acct_id=10000001235, acct_sfx_num=0, clsd_reas_cd='V1', open_dt=datetime.date(2017, 5, 2), tbal_cd='0002'),
- Row(acct_id=10000001236, acct_sfx_num=0, clsd_reas_cd='V2', open_dt=datetime.date(2017, 5, 3), tbal_cd='0003'),
- Row(acct_id=10000001237, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 4), tbal_cd='0004'),
- Row(acct_id=10000001238, acct_sfx_num=0, clsd_reas_cd='*2', open_dt=datetime.date(2017, 5, 5), tbal_cd='0005')
- ]
- base_df = spark.createDataFrame(base_data)
-
- compare_data = [
- Row(ACCOUNT_IDENTIFIER=10000001234, SUFFIX_NUMBER=0, AM00_STATC_CLOSED=None, AM00_DATE_ACCOUNT_OPEN=2017121, AM0B_FC_TBAL=1.0),
- Row(ACCOUNT_IDENTIFIER=10000001235, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V1', AM00_DATE_ACCOUNT_OPEN=2017122, AM0B_FC_TBAL=2.0),
- Row(ACCOUNT_IDENTIFIER=10000001236, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V2', AM00_DATE_ACCOUNT_OPEN=2017123, AM0B_FC_TBAL=3.0),
- Row(ACCOUNT_IDENTIFIER=10000001237, SUFFIX_NUMBER=0, AM00_STATC_CLOSED='V3', AM00_DATE_ACCOUNT_OPEN=2017124, AM0B_FC_TBAL=4.0),
- Row(ACCOUNT_IDENTIFIER=10000001238, SUFFIX_NUMBER=0, AM00_STATC_CLOSED=None, AM00_DATE_ACCOUNT_OPEN=2017125, AM0B_FC_TBAL=5.0)
- ]
- compare_df = spark.createDataFrame(compare_data)
-
- comparison = datacompy.SparkCompare(spark, base_df, compare_df,
- join_columns = [('acct_id', 'ACCOUNT_IDENTIFIER'), ('acct_sfx_num', 'SUFFIX_NUMBER')],
- column_mapping = [('clsd_reas_cd', 'AM00_STATC_CLOSED'),
- ('open_dt', 'AM00_DATE_ACCOUNT_OPEN'),
- ('tbal_cd', 'AM0B_FC_TBAL')],
- known_differences= [
- {'name': 'Left-padded, four-digit numeric code',
- 'types': ['tinyint', 'smallint', 'int', 'bigint', 'float', 'double', 'decimal'],
- 'transformation': "lpad(cast({input} AS bigint), 4, '0')"},
- {'name': 'Null to *2',
- 'types': ['string'],
- 'transformation': "case when {input} is null then '*2' else {input} end"},
- {'name': 'Julian date -> date',
- 'types': ['bigint'],
- 'transformation': "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))"}
- ])
- comparison.report()
-
-Corresponding output::
-
- ****** Column Summary ******
- Number of columns in common with matching schemas: 3
- Number of columns in common with schema differences: 2
- Number of columns in base but not compare: 0
- Number of columns in compare but not base: 0
-
- ****** Schema Differences ******
- Base Column Name Compare Column Name Base Dtype Compare Dtype
- ---------------- ---------------------- ------------- -------------
- open_dt AM00_DATE_ACCOUNT_OPEN date bigint
- tbal_cd AM0B_FC_TBAL string double
-
- ****** Row Summary ******
- Number of rows in common: 5
- Number of rows in base but not compare: 0
- Number of rows in compare but not base: 0
- Number of duplicate rows found in base: 0
- Number of duplicate rows found in compare: 0
-
- ****** Row Comparison ******
- Number of rows with some columns unequal: 5
- Number of rows with all columns equal: 0
-
- ****** Column Comparison ******
- Number of columns compared with unexpected differences in some values: 1
- Number of columns compared with all values equal but known differences found: 2
- Number of columns compared with all values completely equal: 0
-
- ****** Columns with Unequal Values ******
- Base Column Name Compare Column Name Base Dtype Compare Dtype # Matches # Known Diffs # Mismatches
- ---------------- ------------------- ------------- ------------- --------- ------------- ------------
- clsd_reas_cd AM00_STATC_CLOSED string string 2 2 1
- open_dt AM00_DATE_ACCOUNT_OPEN date bigint 0 5 0
- tbal_cd AM0B_FC_TBAL string double 0 5 0
\ No newline at end of file
+ print(compare.intersect_rows[['name_df1', 'name_df2', 'name_match']])
+ # name_df1 name_df2 name_match
+ # 0 George Maharis George Michael Bluth False
+ # 1 Michael Bluth Michael Bluth True
+ # 2 George Bluth George Bluth True
+ # 3 Bob Loblaw Robert Loblaw False
+
+ print(compare.df1_unq_rows)
+ # acct_id_df1 dollar_amt_df1 name_df1 float_fld_df1 date_fld_df1 _merge_left
+ # 5 10000001239 1.05 Lucille Bluth NaN 2017-01-01 True
+
+ print(compare.df2_unq_rows)
+ # acct_id_df2 dollar_amt_df2 name_df2 float_fld_df2 _merge_right
+ # 4 10000001238 1.05 Loose Seal Bluth 111.0 True
+
+ print(compare.intersect_columns())
+ # OrderedSet(['acct_id', 'dollar_amt', 'name', 'float_fld'])
+
+ print(compare.df1_unq_columns())
+ # OrderedSet(['date_fld'])
+
+ print(compare.df2_unq_columns())
+ # OrderedSet()
+
+Duplicate rows
+--------------
+
+Datacompy will try to handle rows that are duplicated on the join columns. It does this behind the
+scenes by generating a unique ID within each group of join-column values. For example, if you
+have two dataframes you're trying to join on acct_id:
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+2 George Bluth
+=========== ================
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+1 Tony Wonder
+2 George Bluth
+=========== ================
+
+Datacompy will generate a unique temporary ID for joining:
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+2 George Bluth 0
+=========== ================ ========
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+1 Tony Wonder 2
+2 George Bluth 0
+=========== ================ ========
+
+Datacompy then merges the two dataframes on a combination of the join_columns you specified and the
+temporary ID, before dropping the temp_id again. So the first two rows in the first dataframe will
+match the first two rows in the second dataframe, and the third row in the second dataframe will be
+recognized as existing only in the second.
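+
+The temporary ID is essentially a row number within each group of join-column values. A minimal
+pandas sketch of the idea (not the exact internal implementation) looks like this:
+
+.. code-block:: python
+
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "acct_id": [1, 1, 2],
+            "name": ["George Maharis", "Michael Bluth", "George Bluth"],
+        }
+    )
+    # Number rows within each acct_id group: 0, 1, 0
+    df["temp_id"] = df.groupby("acct_id").cumcount()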
diff --git a/pyproject.toml b/pyproject.toml
index 8a6f73ea..3bde2d70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,8 +11,8 @@ maintainers = [
{ name="Faisal Dosani", email="faisal.dosani@capitalone.com" }
]
license = {text = "Apache Software License"}
-dependencies = ["pandas<=2.2.1,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"]
-requires-python = ">=3.8.0"
+dependencies = ["pandas<=2.2.2,>=0.25.0", "numpy<=1.26.4,>=1.22.0", "ordered-set<=4.1.0,>=4.0.2", "fugue<=0.8.7,>=0.8.7"]
+requires-python = ">=3.9.0"
classifiers = [
"Intended Audience :: Developers",
"Natural Language :: English",
@@ -20,7 +20,6 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
diff --git a/tests/test_fugue/conftest.py b/tests/test_fugue/conftest.py
index 6a5683d2..a2ca99b1 100644
--- a/tests/test_fugue/conftest.py
+++ b/tests/test_fugue/conftest.py
@@ -1,6 +1,6 @@
-import pytest
import numpy as np
import pandas as pd
+import pytest
@pytest.fixture
@@ -24,7 +24,8 @@ def ref_df():
c=np.random.choice(["aaa", "b_c", "csd"], 100),
)
)
- return [df1, df1_copy, df2, df3, df4]
+ df5 = df1.sample(frac=0.1)
+ return [df1, df1_copy, df2, df3, df4, df5]
@pytest.fixture
@@ -87,3 +88,16 @@ def large_diff_df2():
np.random.seed(0)
data = np.random.randint(6, 11, size=10000)
return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes()
+
+
+@pytest.fixture
+def count_matching_rows_df():
+ np.random.seed(0)
+ df1 = pd.DataFrame(
+ dict(
+ a=np.arange(0, 100),
+ b=np.arange(0, 100),
+ )
+ )
+ df2 = df1.sample(frac=0.1)
+ return [df1, df2]
diff --git a/tests/test_fugue/test_duckdb.py b/tests/test_fugue/test_duckdb.py
index daed1edd..3643f22d 100644
--- a/tests/test_fugue/test_duckdb.py
+++ b/tests/test_fugue/test_duckdb.py
@@ -20,6 +20,7 @@
from datacompy import (
all_columns_match,
all_rows_overlap,
+ count_matching_rows,
intersect_columns,
is_match,
unq_columns,
@@ -138,3 +139,40 @@ def test_all_rows_overlap_duckdb(
duckdb.sql("SELECT 'a' AS a, 'b' AS b"),
join_columns="a",
)
+
+
+def test_count_matching_rows_duckdb(count_matching_rows_df):
+ with duckdb.connect():
+ df1 = duckdb.from_df(count_matching_rows_df[0])
+ df1_copy = duckdb.from_df(count_matching_rows_df[0])
+ df2 = duckdb.from_df(count_matching_rows_df[1])
+
+ assert (
+ count_matching_rows(
+ df1,
+ df1_copy,
+ join_columns="a",
+ )
+ == 100
+ )
+ assert count_matching_rows(df1, df2, join_columns="a") == 10
+ # Fugue
+
+ assert (
+ count_matching_rows(
+ df1,
+ df1_copy,
+ join_columns="a",
+ parallelism=2,
+ )
+ == 100
+ )
+ assert (
+ count_matching_rows(
+ df1,
+ df2,
+ join_columns="a",
+ parallelism=2,
+ )
+ == 10
+ )
diff --git a/tests/test_fugue/test_fugue_pandas.py b/tests/test_fugue/test_fugue_pandas.py
index 77884c2c..4fd74ce7 100644
--- a/tests/test_fugue/test_fugue_pandas.py
+++ b/tests/test_fugue/test_fugue_pandas.py
@@ -24,6 +24,7 @@
Compare,
all_columns_match,
all_rows_overlap,
+ count_matching_rows,
intersect_columns,
is_match,
report,
@@ -144,7 +145,6 @@ def test_report_pandas(
def test_unique_columns_native(ref_df):
df1 = ref_df[0]
- df1_copy = ref_df[1]
df2 = ref_df[2]
df3 = ref_df[3]
@@ -192,3 +192,41 @@ def test_all_rows_overlap_native(
# Fugue
assert all_rows_overlap(ref_df[0], shuffle_df, join_columns="a", parallelism=2)
assert not all_rows_overlap(ref_df[0], ref_df[4], join_columns="a", parallelism=2)
+
+
+def test_count_matching_rows_native(count_matching_rows_df):
+ # defaults to Compare class
+ assert (
+ count_matching_rows(
+ count_matching_rows_df[0],
+ count_matching_rows_df[0].copy(),
+ join_columns="a",
+ )
+ == 100
+ )
+ assert (
+ count_matching_rows(
+ count_matching_rows_df[0], count_matching_rows_df[1], join_columns="a"
+ )
+ == 10
+ )
+ # Fugue
+
+ assert (
+ count_matching_rows(
+ count_matching_rows_df[0],
+ count_matching_rows_df[0].copy(),
+ join_columns="a",
+ parallelism=2,
+ )
+ == 100
+ )
+ assert (
+ count_matching_rows(
+ count_matching_rows_df[0],
+ count_matching_rows_df[1],
+ join_columns="a",
+ parallelism=2,
+ )
+ == 10
+ )
diff --git a/tests/test_fugue/test_fugue_polars.py b/tests/test_fugue/test_fugue_polars.py
index fdb2212a..dcd19a94 100644
--- a/tests/test_fugue/test_fugue_polars.py
+++ b/tests/test_fugue/test_fugue_polars.py
@@ -20,6 +20,7 @@
from datacompy import (
all_columns_match,
all_rows_overlap,
+ count_matching_rows,
intersect_columns,
is_match,
unq_columns,
@@ -122,3 +123,37 @@ def test_all_rows_overlap_polars(
assert all_rows_overlap(rdf, rdf_copy, join_columns="a")
assert all_rows_overlap(rdf, sdf, join_columns="a")
assert not all_rows_overlap(rdf, rdf4, join_columns="a")
+
+
+def test_count_matching_rows_polars(count_matching_rows_df):
+ df1 = pl.from_pandas(count_matching_rows_df[0])
+ df2 = pl.from_pandas(count_matching_rows_df[1])
+ assert (
+ count_matching_rows(
+ df1,
+ df1.clone(),
+ join_columns="a",
+ )
+ == 100
+ )
+ assert count_matching_rows(df1, df2, join_columns="a") == 10
+ # Fugue
+
+ assert (
+ count_matching_rows(
+ df1,
+ df1.clone(),
+ join_columns="a",
+ parallelism=2,
+ )
+ == 100
+ )
+ assert (
+ count_matching_rows(
+ df1,
+ df2,
+ join_columns="a",
+ parallelism=2,
+ )
+ == 10
+ )
diff --git a/tests/test_fugue/test_fugue_spark.py b/tests/test_fugue/test_fugue_spark.py
index 99da708b..efc895ff 100644
--- a/tests/test_fugue/test_fugue_spark.py
+++ b/tests/test_fugue/test_fugue_spark.py
@@ -22,6 +22,7 @@
Compare,
all_columns_match,
all_rows_overlap,
+ count_matching_rows,
intersect_columns,
is_match,
report,
@@ -200,3 +201,44 @@ def test_all_rows_overlap_spark(
spark_session.sql("SELECT 'a' AS a, 'b' AS b"),
join_columns="a",
)
+
+
+def test_count_matching_rows_spark(spark_session, count_matching_rows_df):
+ count_matching_rows_df[0].iteritems = count_matching_rows_df[
+ 0
+ ].items # pandas 2 compatibility
+ count_matching_rows_df[1].iteritems = count_matching_rows_df[
+ 1
+ ].items # pandas 2 compatibility
+ df1 = spark_session.createDataFrame(count_matching_rows_df[0])
+ df1_copy = spark_session.createDataFrame(count_matching_rows_df[0])
+ df2 = spark_session.createDataFrame(count_matching_rows_df[1])
+ assert (
+ count_matching_rows(
+ df1,
+ df1_copy,
+ join_columns="a",
+ )
+ == 100
+ )
+ assert count_matching_rows(df1, df2, join_columns="a") == 10
+ # Fugue
+
+ assert (
+ count_matching_rows(
+ df1,
+ df1_copy,
+ join_columns="a",
+ parallelism=2,
+ )
+ == 100
+ )
+ assert (
+ count_matching_rows(
+ df1,
+ df2,
+ join_columns="a",
+ parallelism=2,
+ )
+ == 10
+ )
diff --git a/tests/test_legacy_spark.py b/tests/test_legacy_spark.py
new file mode 100644
index 00000000..30ec1500
--- /dev/null
+++ b/tests/test_legacy_spark.py
@@ -0,0 +1,2109 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import io
+import logging
+import re
+from decimal import Decimal
+
+import pytest
+
+pytest.importorskip("pyspark")
+
+from pyspark.sql import Row # noqa: E402
+from pyspark.sql.types import ( # noqa: E402
+ DateType,
+ DecimalType,
+ DoubleType,
+ LongType,
+ StringType,
+ StructField,
+ StructType,
+)
+
+from datacompy.legacy import ( # noqa: E402
+ NUMERIC_SPARK_TYPES,
+ LegacySparkCompare,
+ _is_comparable,
+)
+
+# Turn off py4j debug messages for all tests in this module
+logging.getLogger("py4j").setLevel(logging.INFO)
+
+CACHE_INTERMEDIATES = True
+
+
+@pytest.fixture(scope="module", name="base_df1")
+def base_df1_fixture(spark_session):
+ mock_data = [
+ Row(
+ acct=10000001234,
+ dollar_amt=123,
+ name="George Maharis",
+ float_fld=14530.1555,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001235,
+ dollar_amt=0,
+ name="Michael Bluth",
+ float_fld=1.0,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001236,
+ dollar_amt=1345,
+ name="George Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001237,
+ dollar_amt=123456,
+ name="Bob Loblaw",
+ float_fld=345.12,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001239,
+ dollar_amt=1,
+ name="Lucille Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="base_df2")
+def base_df2_fixture(spark_session):
+ mock_data = [
+ Row(
+ acct=10000001234,
+ dollar_amt=123,
+ super_duper_big_long_name="George Maharis",
+ float_fld=14530.1555,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001235,
+ dollar_amt=0,
+ super_duper_big_long_name="Michael Bluth",
+ float_fld=1.0,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001236,
+ dollar_amt=1345,
+ super_duper_big_long_name="George Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001237,
+ dollar_amt=123456,
+ super_duper_big_long_name="Bob Loblaw",
+ float_fld=345.12,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001239,
+ dollar_amt=1,
+ super_duper_big_long_name="Lucille Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="compare_df1")
+def compare_df1_fixture(spark_session):
+ mock_data2 = [
+ Row(
+ acct=10000001234,
+ dollar_amt=123.4,
+ name="George Michael Bluth",
+ float_fld=14530.155,
+ accnt_purge=False,
+ ),
+ Row(
+ acct=10000001235,
+ dollar_amt=0.45,
+ name="Michael Bluth",
+ float_fld=None,
+ accnt_purge=False,
+ ),
+ Row(
+ acct=10000001236,
+ dollar_amt=1345.0,
+ name="George Bluth",
+ float_fld=1.0,
+ accnt_purge=False,
+ ),
+ Row(
+ acct=10000001237,
+ dollar_amt=123456.0,
+ name="Bob Loblaw",
+ float_fld=345.12,
+ accnt_purge=False,
+ ),
+ Row(
+ acct=10000001238,
+ dollar_amt=1.05,
+ name="Loose Seal Bluth",
+ float_fld=111.0,
+ accnt_purge=True,
+ ),
+ Row(
+ acct=10000001238,
+ dollar_amt=1.05,
+ name="Loose Seal Bluth",
+ float_fld=111.0,
+ accnt_purge=True,
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data2)
+
+
+@pytest.fixture(scope="module", name="compare_df2")
+def compare_df2_fixture(spark_session):
+ mock_data = [
+ Row(
+ acct=10000001234,
+ dollar_amt=123,
+ name="George Maharis",
+ float_fld=14530.1555,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001235,
+ dollar_amt=0,
+ name="Michael Bluth",
+ float_fld=1.0,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001236,
+ dollar_amt=1345,
+ name="George Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001237,
+ dollar_amt=123456,
+ name="Bob Loblaw",
+ float_fld=345.12,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ Row(
+ acct=10000001239,
+ dollar_amt=1,
+ name="Lucille Bluth",
+ float_fld=None,
+ date_fld=datetime.date(2017, 1, 1),
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="compare_df3")
+def compare_df3_fixture(spark_session):
+ mock_data2 = [
+ Row(
+ account_identifier=10000001234,
+ dollar_amount=123.4,
+ name="George Michael Bluth",
+ float_field=14530.155,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001235,
+ dollar_amount=0.45,
+ name="Michael Bluth",
+ float_field=1.0,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001236,
+ dollar_amount=1345.0,
+ name="George Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001237,
+ dollar_amount=123456.0,
+ name="Bob Loblaw",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001239,
+ dollar_amount=1.05,
+ name="Lucille Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data2)
+
+
+@pytest.fixture(scope="module", name="base_tol")
+def base_tol_fixture(spark_session):
+ tol_data1 = [
+ Row(
+ account_identifier=10000001234,
+ dollar_amount=123.4,
+ name="Franklin Delano Bluth",
+ float_field=14530.155,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001235,
+ dollar_amount=500.0,
+ name="Surely Funke",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001236,
+ dollar_amount=-1100.0,
+ name="Nichael Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001237,
+ dollar_amount=0.45,
+ name="Mr. F",
+ float_field=1.0,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001238,
+ dollar_amount=1345.0,
+ name="Steve Holt!",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001239,
+ dollar_amount=123456.0,
+ name="Blue Man Group",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001240,
+ dollar_amount=1.1,
+ name="Her?",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001241,
+ dollar_amount=0.0,
+ name="Mrs. Featherbottom",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001242,
+ dollar_amount=0.0,
+ name="Ice",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ),
+ Row(
+ account_identifier=10000001243,
+ dollar_amount=-10.0,
+ name="Frank Wrench",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001244,
+ dollar_amount=None,
+ name="Lucille 2",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001245,
+ dollar_amount=0.009999,
+ name="Gene Parmesan",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ Row(
+ account_identifier=10000001246,
+ dollar_amount=None,
+ name="Motherboy",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ),
+ ]
+
+ return spark_session.createDataFrame(tol_data1)
+
+
+@pytest.fixture(scope="module", name="compare_abs_tol")
+def compare_tol2_fixture(spark_session):
+ tol_data2 = [
+ Row(
+ account_identifier=10000001234,
+ dollar_amount=123.4,
+ name="Franklin Delano Bluth",
+ float_field=14530.155,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # full match
+ Row(
+ account_identifier=10000001235,
+ dollar_amount=500.01,
+ name="Surely Funke",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by 0.01
+ Row(
+ account_identifier=10000001236,
+ dollar_amount=-1100.01,
+ name="Nichael Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by -0.01
+ Row(
+ account_identifier=10000001237,
+ dollar_amount=0.46000000001,
+ name="Mr. F",
+ float_field=1.0,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by 0.01000000001
+ Row(
+ account_identifier=10000001238,
+ dollar_amount=1344.8999999999,
+ name="Steve Holt!",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by -0.01000000001
+ Row(
+ account_identifier=10000001239,
+ dollar_amount=123456.0099999999,
+ name="Blue Man Group",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by 0.00999999999
+ Row(
+ account_identifier=10000001240,
+ dollar_amount=1.090000001,
+ name="Her?",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by -0.00999999999
+ Row(
+ account_identifier=10000001241,
+ dollar_amount=0.0,
+ name="Mrs. Featherbottom",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both zero
+ Row(
+ account_identifier=10000001242,
+ dollar_amount=1.0,
+ name="Ice",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # base 0, compare 1
+ Row(
+ account_identifier=10000001243,
+ dollar_amount=0.0,
+ name="Frank Wrench",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base -10, compare 0
+ Row(
+ account_identifier=10000001244,
+ dollar_amount=-1.0,
+ name="Lucille 2",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base NULL, compare -1
+ Row(
+ account_identifier=10000001245,
+ dollar_amount=None,
+ name="Gene Parmesan",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base 0.009999, compare NULL
+ Row(
+ account_identifier=10000001246,
+ dollar_amount=None,
+ name="Motherboy",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both NULL
+ ]
+
+ return spark_session.createDataFrame(tol_data2)
+
+
+@pytest.fixture(scope="module", name="compare_rel_tol")
+def compare_tol3_fixture(spark_session):
+ tol_data3 = [
+ Row(
+ account_identifier=10000001234,
+ dollar_amount=123.4,
+ name="Franklin Delano Bluth",
+ float_field=14530.155,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # full match #MATCH
+ Row(
+ account_identifier=10000001235,
+ dollar_amount=550.0,
+ name="Surely Funke",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by 10% #MATCH
+ Row(
+ account_identifier=10000001236,
+ dollar_amount=-1000.0,
+ name="Nichael Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by -10% #MATCH
+ Row(
+ account_identifier=10000001237,
+ dollar_amount=0.49501,
+ name="Mr. F",
+ float_field=1.0,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by greater than 10%
+ Row(
+ account_identifier=10000001238,
+ dollar_amount=1210.001,
+ name="Steve Holt!",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by greater than -10%
+ Row(
+ account_identifier=10000001239,
+ dollar_amount=135801.59999,
+ name="Blue Man Group",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by just under 10% #MATCH
+ Row(
+ account_identifier=10000001240,
+ dollar_amount=1.000001,
+ name="Her?",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by just under -10% #MATCH
+ Row(
+ account_identifier=10000001241,
+ dollar_amount=0.0,
+ name="Mrs. Featherbottom",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both zero #MATCH
+ Row(
+ account_identifier=10000001242,
+ dollar_amount=1.0,
+ name="Ice",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # base 0, compare 1
+ Row(
+ account_identifier=10000001243,
+ dollar_amount=0.0,
+ name="Frank Wrench",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base -10, compare 0
+ Row(
+ account_identifier=10000001244,
+ dollar_amount=-1.0,
+ name="Lucille 2",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base NULL, compare -1
+ Row(
+ account_identifier=10000001245,
+ dollar_amount=None,
+ name="Gene Parmesan",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base 0.009999, compare NULL
+ Row(
+ account_identifier=10000001246,
+ dollar_amount=None,
+ name="Motherboy",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both NULL #MATCH
+ ]
+
+ return spark_session.createDataFrame(tol_data3)
+
+
+@pytest.fixture(scope="module", name="compare_both_tol")
+def compare_tol4_fixture(spark_session):
+ tol_data4 = [
+ Row(
+ account_identifier=10000001234,
+ dollar_amount=123.4,
+ name="Franklin Delano Bluth",
+ float_field=14530.155,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # full match
+ Row(
+ account_identifier=10000001235,
+ dollar_amount=550.01,
+ name="Surely Funke",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by 10% and +0.01
+ Row(
+ account_identifier=10000001236,
+ dollar_amount=-1000.01,
+ name="Nichael Bluth",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by -10% and -0.01
+ Row(
+ account_identifier=10000001237,
+ dollar_amount=0.505000000001,
+ name="Mr. F",
+ float_field=1.0,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by greater than 10% and +0.01
+ Row(
+ account_identifier=10000001238,
+ dollar_amount=1209.98999,
+ name="Steve Holt!",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by greater than -10% and -0.01
+ Row(
+ account_identifier=10000001239,
+ dollar_amount=135801.609999,
+ name="Blue Man Group",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # off by just under 10% and just under +0.01
+ Row(
+ account_identifier=10000001240,
+ dollar_amount=0.99000001,
+ name="Her?",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # off by just under -10% and just under -0.01
+ Row(
+ account_identifier=10000001241,
+ dollar_amount=0.0,
+ name="Mrs. Featherbottom",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both zero
+ Row(
+ account_identifier=10000001242,
+ dollar_amount=1.0,
+ name="Ice",
+ float_field=345.12,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=False,
+ ), # base 0, compare 1
+ Row(
+ account_identifier=10000001243,
+ dollar_amount=0.0,
+ name="Frank Wrench",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base -10, compare 0
+ Row(
+ account_identifier=10000001244,
+ dollar_amount=-1.0,
+ name="Lucille 2",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base NULL, compare -1
+ Row(
+ account_identifier=10000001245,
+ dollar_amount=None,
+ name="Gene Parmesan",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # base 0.009999, compare NULL
+ Row(
+ account_identifier=10000001246,
+ dollar_amount=None,
+ name="Motherboy",
+ float_field=None,
+ date_field=datetime.date(2017, 1, 1),
+ accnt_purge=True,
+ ), # both NULL
+ ]
+
+ return spark_session.createDataFrame(tol_data4)
+
+
+@pytest.fixture(scope="module", name="base_td")
+def base_td_fixture(spark_session):
+ mock_data = [
+ Row(
+ acct=10000001234,
+ acct_seq=0,
+ stat_cd="*2",
+ open_dt=datetime.date(2017, 5, 1),
+ cd="0001",
+ ),
+ Row(
+ acct=10000001235,
+ acct_seq=0,
+ stat_cd="V1",
+ open_dt=datetime.date(2017, 5, 2),
+ cd="0002",
+ ),
+ Row(
+ acct=10000001236,
+ acct_seq=0,
+ stat_cd="V2",
+ open_dt=datetime.date(2017, 5, 3),
+ cd="0003",
+ ),
+ Row(
+ acct=10000001237,
+ acct_seq=0,
+ stat_cd="*2",
+ open_dt=datetime.date(2017, 5, 4),
+ cd="0004",
+ ),
+ Row(
+ acct=10000001238,
+ acct_seq=0,
+ stat_cd="*2",
+ open_dt=datetime.date(2017, 5, 5),
+ cd="0005",
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="compare_source")
+def compare_source_fixture(spark_session):
+ mock_data = [
+ Row(
+ ACCOUNT_IDENTIFIER=10000001234,
+ SEQ_NUMBER=0,
+ STATC=None,
+ ACCOUNT_OPEN=2017121,
+ CODE=1.0,
+ ),
+ Row(
+ ACCOUNT_IDENTIFIER=10000001235,
+ SEQ_NUMBER=0,
+ STATC="V1",
+ ACCOUNT_OPEN=2017122,
+ CODE=2.0,
+ ),
+ Row(
+ ACCOUNT_IDENTIFIER=10000001236,
+ SEQ_NUMBER=0,
+ STATC="V2",
+ ACCOUNT_OPEN=2017123,
+ CODE=3.0,
+ ),
+ Row(
+ ACCOUNT_IDENTIFIER=10000001237,
+ SEQ_NUMBER=0,
+ STATC="V3",
+ ACCOUNT_OPEN=2017124,
+ CODE=4.0,
+ ),
+ Row(
+ ACCOUNT_IDENTIFIER=10000001238,
+ SEQ_NUMBER=0,
+ STATC=None,
+ ACCOUNT_OPEN=2017125,
+ CODE=5.0,
+ ),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="base_decimal")
+def base_decimal_fixture(spark_session):
+ mock_data = [
+ Row(acct=10000001234, dollar_amt=Decimal(123.4)),
+ Row(acct=10000001235, dollar_amt=Decimal(0.45)),
+ ]
+
+ return spark_session.createDataFrame(
+ mock_data,
+ schema=StructType(
+ [
+ StructField("acct", LongType(), True),
+ StructField("dollar_amt", DecimalType(8, 2), True),
+ ]
+ ),
+ )
+
+
+@pytest.fixture(scope="module", name="compare_decimal")
+def compare_decimal_fixture(spark_session):
+ mock_data = [
+ Row(acct=10000001234, dollar_amt=123.4),
+ Row(acct=10000001235, dollar_amt=0.456),
+ ]
+
+ return spark_session.createDataFrame(mock_data)
+
+
+@pytest.fixture(scope="module", name="comparison_abs_tol")
+def comparison_abs_tol_fixture(base_tol, compare_abs_tol, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_abs_tol,
+ join_columns=["account_identifier"],
+ abs_tol=0.01,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_rel_tol")
+def comparison_rel_tol_fixture(base_tol, compare_rel_tol, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_rel_tol,
+ join_columns=["account_identifier"],
+ rel_tol=0.1,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_both_tol")
+def comparison_both_tol_fixture(base_tol, compare_both_tol, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_both_tol,
+ join_columns=["account_identifier"],
+ rel_tol=0.1,
+ abs_tol=0.01,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_neg_tol")
+def comparison_neg_tol_fixture(base_tol, compare_both_tol, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_both_tol,
+ join_columns=["account_identifier"],
+ rel_tol=-0.2,
+ abs_tol=0.01,
+ )
+
+
+@pytest.fixture(scope="module", name="show_all_columns_and_match_rate")
+def show_all_columns_and_match_rate_fixture(base_tol, compare_both_tol, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_both_tol,
+ join_columns=["account_identifier"],
+ show_all_columns=True,
+ match_rates=True,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_kd1")
+def comparison_known_diffs1(base_td, compare_source, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_td,
+ compare_source,
+ join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")],
+ column_mapping=[
+ ("stat_cd", "STATC"),
+ ("open_dt", "ACCOUNT_OPEN"),
+ ("cd", "CODE"),
+ ],
+ known_differences=[
+ {
+ "name": "Left-padded, four-digit numeric code",
+ "types": NUMERIC_SPARK_TYPES,
+ "transformation": "lpad(cast({input} AS bigint), 4, '0')",
+ },
+ {
+ "name": "Null to *2",
+ "types": ["string"],
+ "transformation": "case when {input} is null then '*2' else {input} end",
+ },
+ {
+ "name": "Julian date -> date",
+ "types": ["bigint"],
+ "transformation": "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))",
+ },
+ ],
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_kd2")
+def comparison_known_diffs2(base_td, compare_source, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_td,
+ compare_source,
+ join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")],
+ column_mapping=[
+ ("stat_cd", "STATC"),
+ ("open_dt", "ACCOUNT_OPEN"),
+ ("cd", "CODE"),
+ ],
+ known_differences=[
+ {
+ "name": "Left-padded, four-digit numeric code",
+ "types": NUMERIC_SPARK_TYPES,
+ "transformation": "lpad(cast({input} AS bigint), 4, '0')",
+ },
+ {
+ "name": "Null to *2",
+ "types": ["string"],
+ "transformation": "case when {input} is null then '*2' else {input} end",
+ },
+ ],
+ )
+
+
+@pytest.fixture(scope="module", name="comparison1")
+def comparison1_fixture(base_df1, compare_df1, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_df1,
+ compare_df1,
+ join_columns=["acct"],
+ cache_intermediates=CACHE_INTERMEDIATES,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison2")
+def comparison2_fixture(base_df1, compare_df2, spark_session):
+ return LegacySparkCompare(
+ spark_session, base_df1, compare_df2, join_columns=["acct"]
+ )
+
+
+@pytest.fixture(scope="module", name="comparison3")
+def comparison3_fixture(base_df1, compare_df3, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_df1,
+ compare_df3,
+ join_columns=[("acct", "account_identifier")],
+ column_mapping=[
+ ("dollar_amt", "dollar_amount"),
+ ("float_fld", "float_field"),
+ ("date_fld", "date_field"),
+ ],
+ cache_intermediates=CACHE_INTERMEDIATES,
+ )
+
+
+@pytest.fixture(scope="module", name="comparison4")
+def comparison4_fixture(base_df2, compare_df1, spark_session):
+ return LegacySparkCompare(
+ spark_session,
+ base_df2,
+ compare_df1,
+ join_columns=["acct"],
+ column_mapping=[("super_duper_big_long_name", "name")],
+ )
+
+
+@pytest.fixture(scope="module", name="comparison_decimal")
+def comparison_decimal_fixture(base_decimal, compare_decimal, spark_session):
+ return LegacySparkCompare(
+ spark_session, base_decimal, compare_decimal, join_columns=["acct"]
+ )
+
+
+def test_absolute_tolerances(comparison_abs_tol):
+ stdout = io.StringIO()
+
+ comparison_abs_tol.report(file=stdout)
+ stdout.seek(0)
+ assert "****** Row Comparison ******" in stdout.getvalue()
+ assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
+ assert "Number of rows with all columns equal: 7" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
+
+
+def test_relative_tolerances(comparison_rel_tol):
+ stdout = io.StringIO()
+
+ comparison_rel_tol.report(file=stdout)
+ stdout.seek(0)
+ assert "****** Row Comparison ******" in stdout.getvalue()
+ assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
+ assert "Number of rows with all columns equal: 7" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
+
+
+def test_both_tolerances(comparison_both_tol):
+ stdout = io.StringIO()
+
+ comparison_both_tol.report(file=stdout)
+ stdout.seek(0)
+ assert "****** Row Comparison ******" in stdout.getvalue()
+ assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
+ assert "Number of rows with all columns equal: 7" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
+
+
+def test_negative_tolerances(spark_session, base_tol, compare_both_tol):
+ with pytest.raises(ValueError, match="Please enter positive valued tolerances"):
+ comp = LegacySparkCompare(
+ spark_session,
+ base_tol,
+ compare_both_tol,
+ join_columns=["account_identifier"],
+ rel_tol=-0.2,
+ abs_tol=0.01,
+ )
+ comp.report()
+ pass
+
+
+def test_show_all_columns_and_match_rate(show_all_columns_and_match_rate):
+ stdout = io.StringIO()
+
+ show_all_columns_and_match_rate.report(file=stdout)
+
+ assert "****** Columns with Equal/Unequal Values ******" in stdout.getvalue()
+ assert (
+ "accnt_purge accnt_purge boolean boolean 13 0 100.00000"
+ in stdout.getvalue()
+ )
+ assert (
+ "date_field date_field date date 13 0 100.00000"
+ in stdout.getvalue()
+ )
+ assert (
+ "dollar_amount dollar_amount double double 3 10 23.07692"
+ in stdout.getvalue()
+ )
+ assert (
+ "float_field float_field double double 13 0 100.00000"
+ in stdout.getvalue()
+ )
+ assert (
+ "name name string string 13 0 100.00000"
+ in stdout.getvalue()
+ )
+
+
+def test_decimal_comparisons():
+ true_decimals = ["decimal", "decimal()", "decimal(20, 10)"]
+ assert all(v in NUMERIC_SPARK_TYPES for v in true_decimals)
+
+
+def test_decimal_comparator_acts_like_string():
+ acc = False
+ for t in NUMERIC_SPARK_TYPES:
+ acc = acc or (len(t) > 2 and t[0:3] == "dec")
+ assert acc
+
+
+def test_decimals_and_doubles_are_comparable():
+ assert _is_comparable("double", "decimal(10, 2)")
+
+
+def test_report_outputs_the_column_summary(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Column Summary ******" in stdout.getvalue()
+ assert "Number of columns in common with matching schemas: 3" in stdout.getvalue()
+ assert "Number of columns in common with schema differences: 1" in stdout.getvalue()
+ assert "Number of columns in base but not compare: 1" in stdout.getvalue()
+ assert "Number of columns in compare but not base: 1" in stdout.getvalue()
+
+
+def test_report_outputs_the_column_summary_for_identical_schemas(comparison2):
+ stdout = io.StringIO()
+
+ comparison2.report(file=stdout)
+
+ assert "****** Column Summary ******" in stdout.getvalue()
+ assert "Number of columns in common with matching schemas: 5" in stdout.getvalue()
+ assert "Number of columns in common with schema differences: 0" in stdout.getvalue()
+ assert "Number of columns in base but not compare: 0" in stdout.getvalue()
+ assert "Number of columns in compare but not base: 0" in stdout.getvalue()
+
+
+def test_report_outputs_the_column_summary_for_differently_named_columns(comparison3):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Column Summary ******" in stdout.getvalue()
+ assert "Number of columns in common with matching schemas: 4" in stdout.getvalue()
+ assert "Number of columns in common with schema differences: 1" in stdout.getvalue()
+ assert "Number of columns in base but not compare: 0" in stdout.getvalue()
+ assert "Number of columns in compare but not base: 1" in stdout.getvalue()
+
+
+def test_report_outputs_the_row_summary(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Row Summary ******" in stdout.getvalue()
+ assert "Number of rows in common: 4" in stdout.getvalue()
+ assert "Number of rows in base but not compare: 1" in stdout.getvalue()
+ assert "Number of rows in compare but not base: 1" in stdout.getvalue()
+ assert "Number of duplicate rows found in base: 0" in stdout.getvalue()
+ assert "Number of duplicate rows found in compare: 1" in stdout.getvalue()
+
+
+def test_report_outputs_the_row_equality_comparison(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Row Comparison ******" in stdout.getvalue()
+ assert "Number of rows with some columns unequal: 3" in stdout.getvalue()
+ assert "Number of rows with all columns equal: 1" in stdout.getvalue()
+
+
+def test_report_outputs_the_row_summary_for_differently_named_columns(comparison3):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Row Summary ******" in stdout.getvalue()
+ assert "Number of rows in common: 5" in stdout.getvalue()
+ assert "Number of rows in base but not compare: 0" in stdout.getvalue()
+ assert "Number of rows in compare but not base: 0" in stdout.getvalue()
+ assert "Number of duplicate rows found in base: 0" in stdout.getvalue()
+ assert "Number of duplicate rows found in compare: 0" in stdout.getvalue()
+
+
+def test_report_outputs_the_row_equality_comparison_for_differently_named_columns(
+ comparison3,
+):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Row Comparison ******" in stdout.getvalue()
+ assert "Number of rows with some columns unequal: 3" in stdout.getvalue()
+ assert "Number of rows with all columns equal: 2" in stdout.getvalue()
+
+
+def test_report_outputs_column_detail_for_columns_in_only_one_dataframe(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+ comparison1.report()
+ assert "****** Columns In Base Only ******" in stdout.getvalue()
+ r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \ndate_fld \s+ date"""
+ assert re.search(r2, str(stdout.getvalue()), re.X) is not None
+
+
+def test_report_outputs_column_detail_for_columns_in_only_compare_dataframe(
+ comparison1,
+):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+ comparison1.report()
+ assert "****** Columns In Compare Only ******" in stdout.getvalue()
+ r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \n accnt_purge \s+ boolean"""
+ assert re.search(r2, str(stdout.getvalue()), re.X) is not None
+
+
+def test_report_outputs_schema_difference_details(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Schema Differences ******" in stdout.getvalue()
+ assert re.search(
+ r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n
+ -+ \s+ -+ \s+ -+ \s+ -+ \n
+ dollar_amt \s+ dollar_amt \s+ bigint \s+ double""",
+ stdout.getvalue(),
+ re.X,
+ )
+
+
+def test_report_outputs_schema_difference_details_for_differently_named_columns(
+ comparison3,
+):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Schema Differences ******" in stdout.getvalue()
+ assert re.search(
+ r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n
+ -+ \s+ -+ \s+ -+ \s+ -+ \n
+ dollar_amt \s+ dollar_amount \s+ bigint \s+ double""",
+ stdout.getvalue(),
+ re.X,
+ )
+
+
+def test_column_comparison_outputs_number_of_columns_with_differences(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Column Comparison ******" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 3" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 0" in stdout.getvalue()
+
+
+def test_column_comparison_outputs_all_columns_equal_for_identical_dataframes(
+ comparison2,
+):
+ stdout = io.StringIO()
+
+ comparison2.report(file=stdout)
+
+ assert "****** Column Comparison ******" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 0" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
+
+
+def test_column_comparison_outputs_number_of_columns_with_differences_for_differently_named_columns(
+ comparison3,
+):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Column Comparison ******" in stdout.getvalue()
+ assert "Number of columns compared with some values unequal: 3" in stdout.getvalue()
+ assert "Number of columns compared with all values equal: 1" in stdout.getvalue()
+
+
+def test_column_comparison_outputs_number_of_columns_with_differences_for_known_diffs(
+ comparison_kd1,
+):
+ stdout = io.StringIO()
+
+ comparison_kd1.report(file=stdout)
+
+ assert "****** Column Comparison ******" in stdout.getvalue()
+ assert (
+ "Number of columns compared with unexpected differences in some values: 1"
+ in stdout.getvalue()
+ )
+ assert (
+ "Number of columns compared with all values equal but known differences found: 2"
+ in stdout.getvalue()
+ )
+ assert (
+ "Number of columns compared with all values completely equal: 0"
+ in stdout.getvalue()
+ )
+
+
+def test_column_comparison_outputs_number_of_columns_with_differences_for_custom_known_diffs(
+ comparison_kd2,
+):
+ stdout = io.StringIO()
+
+ comparison_kd2.report(file=stdout)
+
+ assert "****** Column Comparison ******" in stdout.getvalue()
+ assert (
+ "Number of columns compared with unexpected differences in some values: 2"
+ in stdout.getvalue()
+ )
+ assert (
+ "Number of columns compared with all values equal but known differences found: 1"
+ in stdout.getvalue()
+ )
+ assert (
+ "Number of columns compared with all values completely equal: 0"
+ in stdout.getvalue()
+ )
+
+
+def test_columns_with_unequal_values_show_mismatch_counts(comparison1):
+ stdout = io.StringIO()
+
+ comparison1.report(file=stdout)
+
+ assert "****** Columns with Unequal Values ******" in stdout.getvalue()
+ assert re.search(
+ r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s*
+ \#\sMatches \s* \#\sMismatches \n
+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""dollar_amt \s+ dollar_amt \s+ bigint \s+ double \s+ 2 \s+ 2""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""float_fld \s+ float_fld \s+ double \s+ double \s+ 1 \s+ 3""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""name \s+ name \s+ string \s+ string \s+ 3 \s+ 1""", stdout.getvalue(), re.X
+ )
+
+
+def test_columns_with_different_names_with_unequal_values_show_mismatch_counts(
+ comparison3,
+):
+ stdout = io.StringIO()
+
+ comparison3.report(file=stdout)
+
+ assert "****** Columns with Unequal Values ******" in stdout.getvalue()
+ assert re.search(
+ r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s*
+ \#\sMatches \s* \#\sMismatches \n
+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""dollar_amt \s+ dollar_amount \s+ bigint \s+ double \s+ 2 \s+ 3""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""float_fld \s+ float_field \s+ double \s+ double \s+ 4 \s+ 1""",
+ stdout.getvalue(),
+ re.X,
+ )
+ assert re.search(
+ r"""name \s+ name \s+ string \s+ string \s+ 4 \s+ 1""", stdout.getvalue(), re.X
+ )
+
+
+def test_rows_only_base_returns_a_dataframe_with_rows_only_in_base(
+ spark_session, comparison1
+):
+ # require schema if contains only 1 row and contain field value as None
+ schema = StructType(
+ [
+ StructField("acct", LongType(), True),
+ StructField("date_fld", DateType(), True),
+ StructField("dollar_amt", LongType(), True),
+ StructField("float_fld", DoubleType(), True),
+ StructField("name", StringType(), True),
+ ]
+ )
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ acct=10000001239,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt=1,
+ float_fld=None,
+ name="Lucille Bluth",
+ )
+ ],
+ schema,
+ )
+ assert comparison1.rows_only_base.count() == 1
+ assert (
+ expected_df.union(
+ comparison1.rows_only_base.select(
+ "acct", "date_fld", "dollar_amt", "float_fld", "name"
+ )
+ )
+ .distinct()
+ .count()
+ == 1
+ )
+
+
+def test_rows_only_compare_returns_a_dataframe_with_rows_only_in_compare(
+ spark_session, comparison1
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ acct=10000001238,
+ dollar_amt=1.05,
+ name="Loose Seal Bluth",
+ float_fld=111.0,
+ accnt_purge=True,
+ )
+ ]
+ )
+
+ assert comparison1.rows_only_compare.count() == 1
+ assert expected_df.union(comparison1.rows_only_compare).distinct().count() == 1
+
+
+def test_rows_both_mismatch_returns_a_dataframe_with_rows_where_variables_mismatched(
+ spark_session, comparison1
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ accnt_purge=False,
+ acct=10000001234,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=123,
+ dollar_amt_compare=123.4,
+ dollar_amt_match=False,
+ float_fld_base=14530.1555,
+ float_fld_compare=14530.155,
+ float_fld_match=False,
+ name_base="George Maharis",
+ name_compare="George Michael Bluth",
+ name_match=False,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001235,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=0,
+ dollar_amt_compare=0.45,
+ dollar_amt_match=False,
+ float_fld_base=1.0,
+ float_fld_compare=None,
+ float_fld_match=False,
+ name_base="Michael Bluth",
+ name_compare="Michael Bluth",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001236,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=1345,
+ dollar_amt_compare=1345.0,
+ dollar_amt_match=True,
+ float_fld_base=None,
+ float_fld_compare=1.0,
+ float_fld_match=False,
+ name_base="George Bluth",
+ name_compare="George Bluth",
+ name_match=True,
+ ),
+ ]
+ )
+
+ assert comparison1.rows_both_mismatch.count() == 3
+ assert expected_df.union(comparison1.rows_both_mismatch).distinct().count() == 3
+
+
+def test_rows_both_mismatch_only_includes_rows_with_true_mismatches_when_known_diffs_are_present(
+ spark_session, comparison_kd1
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ acct=10000001237,
+ acct_seq=0,
+ cd_base="0004",
+ cd_compare=4.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 4),
+ open_dt_compare=2017124,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="*2",
+ stat_cd_compare="V3",
+ stat_cd_match=False,
+ stat_cd_match_type="MISMATCH",
+ )
+ ]
+ )
+ assert comparison_kd1.rows_both_mismatch.count() == 1
+ assert expected_df.union(comparison_kd1.rows_both_mismatch).distinct().count() == 1
+
+
+def test_rows_both_all_returns_a_dataframe_with_all_rows_in_both_dataframes(
+ spark_session, comparison1
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ accnt_purge=False,
+ acct=10000001234,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=123,
+ dollar_amt_compare=123.4,
+ dollar_amt_match=False,
+ float_fld_base=14530.1555,
+ float_fld_compare=14530.155,
+ float_fld_match=False,
+ name_base="George Maharis",
+ name_compare="George Michael Bluth",
+ name_match=False,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001235,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=0,
+ dollar_amt_compare=0.45,
+ dollar_amt_match=False,
+ float_fld_base=1.0,
+ float_fld_compare=None,
+ float_fld_match=False,
+ name_base="Michael Bluth",
+ name_compare="Michael Bluth",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001236,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=1345,
+ dollar_amt_compare=1345.0,
+ dollar_amt_match=True,
+ float_fld_base=None,
+ float_fld_compare=1.0,
+ float_fld_match=False,
+ name_base="George Bluth",
+ name_compare="George Bluth",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001237,
+ date_fld=datetime.date(2017, 1, 1),
+ dollar_amt_base=123456,
+ dollar_amt_compare=123456.0,
+ dollar_amt_match=True,
+ float_fld_base=345.12,
+ float_fld_compare=345.12,
+ float_fld_match=True,
+ name_base="Bob Loblaw",
+ name_compare="Bob Loblaw",
+ name_match=True,
+ ),
+ ]
+ )
+
+ assert comparison1.rows_both_all.count() == 4
+ assert expected_df.union(comparison1.rows_both_all).distinct().count() == 4
+
+
+def test_rows_both_all_shows_known_diffs_flag_and_known_diffs_count_as_matches(
+ spark_session, comparison_kd1
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ acct=10000001234,
+ acct_seq=0,
+ cd_base="0001",
+ cd_compare=1.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 1),
+ open_dt_compare=2017121,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="*2",
+ stat_cd_compare=None,
+ stat_cd_match=True,
+ stat_cd_match_type="KNOWN_DIFFERENCE",
+ ),
+ Row(
+ acct=10000001235,
+ acct_seq=0,
+ cd_base="0002",
+ cd_compare=2.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 2),
+ open_dt_compare=2017122,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="V1",
+ stat_cd_compare="V1",
+ stat_cd_match=True,
+ stat_cd_match_type="MATCH",
+ ),
+ Row(
+ acct=10000001236,
+ acct_seq=0,
+ cd_base="0003",
+ cd_compare=3.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 3),
+ open_dt_compare=2017123,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="V2",
+ stat_cd_compare="V2",
+ stat_cd_match=True,
+ stat_cd_match_type="MATCH",
+ ),
+ Row(
+ acct=10000001237,
+ acct_seq=0,
+ cd_base="0004",
+ cd_compare=4.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 4),
+ open_dt_compare=2017124,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="*2",
+ stat_cd_compare="V3",
+ stat_cd_match=False,
+ stat_cd_match_type="MISMATCH",
+ ),
+ Row(
+ acct=10000001238,
+ acct_seq=0,
+ cd_base="0005",
+ cd_compare=5.0,
+ cd_match=True,
+ cd_match_type="KNOWN_DIFFERENCE",
+ open_dt_base=datetime.date(2017, 5, 5),
+ open_dt_compare=2017125,
+ open_dt_match=True,
+ open_dt_match_type="KNOWN_DIFFERENCE",
+ stat_cd_base="*2",
+ stat_cd_compare=None,
+ stat_cd_match=True,
+ stat_cd_match_type="KNOWN_DIFFERENCE",
+ ),
+ ]
+ )
+
+ assert comparison_kd1.rows_both_all.count() == 5
+ assert expected_df.union(comparison_kd1.rows_both_all).distinct().count() == 5
+
+
+def test_rows_both_all_returns_a_dataframe_with_all_rows_in_identical_dataframes(
+ spark_session, comparison2
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ acct=10000001234,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=123,
+ dollar_amt_compare=123,
+ dollar_amt_match=True,
+ float_fld_base=14530.1555,
+ float_fld_compare=14530.1555,
+ float_fld_match=True,
+ name_base="George Maharis",
+ name_compare="George Maharis",
+ name_match=True,
+ ),
+ Row(
+ acct=10000001235,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=0,
+ dollar_amt_compare=0,
+ dollar_amt_match=True,
+ float_fld_base=1.0,
+ float_fld_compare=1.0,
+ float_fld_match=True,
+ name_base="Michael Bluth",
+ name_compare="Michael Bluth",
+ name_match=True,
+ ),
+ Row(
+ acct=10000001236,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=1345,
+ dollar_amt_compare=1345,
+ dollar_amt_match=True,
+ float_fld_base=None,
+ float_fld_compare=None,
+ float_fld_match=True,
+ name_base="George Bluth",
+ name_compare="George Bluth",
+ name_match=True,
+ ),
+ Row(
+ acct=10000001237,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=123456,
+ dollar_amt_compare=123456,
+ dollar_amt_match=True,
+ float_fld_base=345.12,
+ float_fld_compare=345.12,
+ float_fld_match=True,
+ name_base="Bob Loblaw",
+ name_compare="Bob Loblaw",
+ name_match=True,
+ ),
+ Row(
+ acct=10000001239,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=1,
+ dollar_amt_compare=1,
+ dollar_amt_match=True,
+ float_fld_base=None,
+ float_fld_compare=None,
+ float_fld_match=True,
+ name_base="Lucille Bluth",
+ name_compare="Lucille Bluth",
+ name_match=True,
+ ),
+ ]
+ )
+
+ assert comparison2.rows_both_all.count() == 5
+ assert expected_df.union(comparison2.rows_both_all).distinct().count() == 5
+
+
+def test_rows_both_all_returns_all_rows_in_both_dataframes_for_differently_named_columns(
+ spark_session, comparison3
+):
+ expected_df = spark_session.createDataFrame(
+ [
+ Row(
+ accnt_purge=False,
+ acct=10000001234,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=123,
+ dollar_amt_compare=123.4,
+ dollar_amt_match=False,
+ float_fld_base=14530.1555,
+ float_fld_compare=14530.155,
+ float_fld_match=False,
+ name_base="George Maharis",
+ name_compare="George Michael Bluth",
+ name_match=False,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001235,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=0,
+ dollar_amt_compare=0.45,
+ dollar_amt_match=False,
+ float_fld_base=1.0,
+ float_fld_compare=1.0,
+ float_fld_match=True,
+ name_base="Michael Bluth",
+ name_compare="Michael Bluth",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001236,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=1345,
+ dollar_amt_compare=1345.0,
+ dollar_amt_match=True,
+ float_fld_base=None,
+ float_fld_compare=None,
+ float_fld_match=True,
+ name_base="George Bluth",
+ name_compare="George Bluth",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=False,
+ acct=10000001237,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=123456,
+ dollar_amt_compare=123456.0,
+ dollar_amt_match=True,
+ float_fld_base=345.12,
+ float_fld_compare=345.12,
+ float_fld_match=True,
+ name_base="Bob Loblaw",
+ name_compare="Bob Loblaw",
+ name_match=True,
+ ),
+ Row(
+ accnt_purge=True,
+ acct=10000001239,
+ date_fld_base=datetime.date(2017, 1, 1),
+ date_fld_compare=datetime.date(2017, 1, 1),
+ date_fld_match=True,
+ dollar_amt_base=1,
+ dollar_amt_compare=1.05,
+ dollar_amt_match=False,
+ float_fld_base=None,
+ float_fld_compare=None,
+ float_fld_match=True,
+ name_base="Lucille Bluth",
+ name_compare="Lucille Bluth",
+ name_match=True,
+ ),
+ ]
+ )
+
+ assert comparison3.rows_both_all.count() == 5
+ assert expected_df.union(comparison3.rows_both_all).distinct().count() == 5
+
+
+def test_columns_with_unequal_values_text_is_aligned(comparison4):
+ stdout = io.StringIO()
+
+ comparison4.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns with Unequal Values ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(5, 6),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
+ (\#\sMatches) \s+ (\#\sMismatches)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double) \s+ (2) \s+ (2)""",
+ r"""(float_fld) \s+ (float_fld) \s+ (double) \s+ (double) \s+ (1) \s+ (3)""",
+ r"""(super_duper_big_long_name) \s+ (name) \s+ (string) \s+ (string) \s+ (3) \s+ (1)\s*""",
+ ],
+ )
+
+
+def test_columns_with_unequal_values_text_is_aligned_with_known_differences(
+ comparison_kd1,
+):
+ stdout = io.StringIO()
+
+ comparison_kd1.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns with Unequal Values ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(5, 6, 7),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
+ (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""",
+ r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (5) \s+ (0)""",
+ r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""",
+ ],
+ )
+
+
+def test_columns_with_unequal_values_text_is_aligned_with_custom_known_differences(
+ comparison_kd2,
+):
+ stdout = io.StringIO()
+
+ comparison_kd2.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns with Unequal Values ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(5, 6, 7),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
+ (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""",
+ r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (0) \s+ (5)""",
+ r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""",
+ ],
+ )
+
+
+def test_columns_with_unequal_values_text_is_aligned_for_decimals(comparison_decimal):
+ stdout = io.StringIO()
+
+ comparison_decimal.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns with Unequal Values ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(5, 6),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
+ (\#\sMatches) \s+ (\#\sMismatches)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double) \s+ (1) \s+ (1)""",
+ ],
+ )
+
+
+def test_schema_differences_text_is_aligned(comparison4):
+ stdout = io.StringIO()
+
+ comparison4.report(file=stdout)
+ comparison4.report()
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Schema Differences ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double)""",
+ ],
+ )
+
+
+def test_schema_differences_text_is_aligned_for_decimals(comparison_decimal):
+ stdout = io.StringIO()
+
+ comparison_decimal.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Schema Differences ******",
+ section_end="\n",
+ left_indices=(1, 2, 3, 4),
+ right_indices=(),
+ column_regexes=[
+ r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""",
+ r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
+ r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double)""",
+ ],
+ )
+
+
+def test_base_only_columns_text_is_aligned(comparison4):
+ stdout = io.StringIO()
+
+ comparison4.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns In Base Only ******",
+ section_end="\n",
+ left_indices=(1, 2),
+ right_indices=(),
+ column_regexes=[
+ r"""(Column\sName) \s+ (Dtype)""",
+ r"""(-+) \s+ (-+)""",
+ r"""(date_fld) \s+ (date)""",
+ ],
+ )
+
+
+def test_compare_only_columns_text_is_aligned(comparison4):
+ stdout = io.StringIO()
+
+ comparison4.report(file=stdout)
+ stdout.seek(0) # Back up to the beginning of the stream
+
+ text_alignment_validator(
+ report=stdout,
+ section_start="****** Columns In Compare Only ******",
+ section_end="\n",
+ left_indices=(1, 2),
+ right_indices=(),
+ column_regexes=[
+ r"""(Column\sName) \s+ (Dtype)""",
+ r"""(-+) \s+ (-+)""",
+ r"""(accnt_purge) \s+ (boolean)""",
+ ],
+ )
+
+
+def text_alignment_validator(
+ report, section_start, section_end, left_indices, right_indices, column_regexes
+):
+ r"""Check to make sure that report output columns are vertically aligned.
+
+ Parameters
+ ----------
+ report: An iterable returning lines of report output to be validated.
+ section_start: A string that represents the beginning of the section to be validated.
+ section_end: A string that represents the end of the section to be validated.
+ left_indices: The match group indexes (starting with 1) that should be left-aligned
+ in the output column.
+ right_indices: The match group indexes (starting with 1) that should be right-aligned
+ in the output column.
+ column_regexes: A list of regular expressions representing the expected output, with
+ each column enclosed with parentheses to return a match. The regular expression will
+ use the "X" flag, so it may contain whitespace, and any whitespace to be matched
+ should be explicitly given with \s. The first line will represent the alignments
+ that are expected in the following lines. The number of match groups should cover
+ all of the indices given in left/right_indices.
+
+ Runs assertions for every match group specified by left/right_indices to ensure that
+ all lines past the first are either left- or right-aligned with the same match group
+ on the first line.
+ """
+
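+    # Rough sketch of the check below (illustrative example, not taken from an
+    # actual report): for a header line like "Column Name    Dtype" matched by
+    # r"(Column\sName) \s+ (Dtype)", left-aligned groups record matches.start(n)
+    # and right-aligned groups record matches.end(n); every subsequent data line
+    # must then match one of the remaining regexes at those same offsets.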
+ at_column_section = False
+ processed_first_line = False
+ match_positions = [None] * (len(left_indices + right_indices) + 1)
+
+ for line in report:
+ if at_column_section:
+ if line == section_end: # Detect end of section and stop
+ break
+
+ if (
+ not processed_first_line
+ ): # First line in section - capture text start/end positions
+ matches = re.search(column_regexes[0], line, re.X)
+ assert matches is not None # Make sure we found at least this...
+
+ for n in left_indices:
+ match_positions[n] = matches.start(n)
+ for n in right_indices:
+ match_positions[n] = matches.end(n)
+ processed_first_line = True
+            else:  # Match the data rows that follow the header text
+ match = None
+ for regex in column_regexes[1:]:
+ match = re.search(regex, line, re.X)
+ if match:
+ break
+
+ if not match:
+ raise AssertionError(f'Did not find a match for line: "{line}"')
+
+ for n in left_indices:
+ assert match_positions[n] == match.start(n)
+ for n in right_indices:
+ assert match_positions[n] == match.end(n)
+
+ if not at_column_section and section_start in line:
+ at_column_section = True
+
+
+def test_unicode_columns(spark_session):
+ df1 = spark_session.createDataFrame(
+ [
+ (1, "foo", "test"),
+ (2, "bar", "test"),
+ ],
+ ["id", "例", "予測対象日"],
+ )
+ df2 = spark_session.createDataFrame(
+ [
+ (1, "foo", "test"),
+ (2, "baz", "test"),
+ ],
+ ["id", "例", "予測対象日"],
+ )
+ compare = LegacySparkCompare(spark_session, df1, df2, join_columns=["例"])
+ # Just render the report to make sure it renders.
+ compare.report()
diff --git a/tests/test_polars.py b/tests/test_polars.py
index aabbcad1..679a9ab7 100644
--- a/tests/test_polars.py
+++ b/tests/test_polars.py
@@ -1231,7 +1231,7 @@ def test_dupes_with_nulls():
),
(
pl.DataFrame(
- {"a": [datetime(2018, 1, 1), np.nan, np.nan], "b": ["1", "2", "2"]}
+ {"a": [datetime(2018, 1, 1), None, None], "b": ["1", "2", "2"]}
),
pl.Series([1, 1, 2]),
),
diff --git a/tests/test_spark.py b/tests/test_spark.py
index af8aa8f3..88acc9a0 100644
--- a/tests/test_spark.py
+++ b/tests/test_spark.py
@@ -13,2103 +13,1311 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import datetime
-import io
+"""
+Testing out the datacompy functionality
+"""
+
import logging
import re
+import sys
+from datetime import datetime
from decimal import Decimal
+from io import StringIO
+from unittest import mock
+import numpy as np
+import pandas as pd
import pytest
+from pytest import raises
pytest.importorskip("pyspark")
-from pyspark.sql import Row, SparkSession
-from pyspark.sql.types import (
- DateType,
- DecimalType,
- DoubleType,
- LongType,
- StringType,
- StructField,
- StructType,
-)
+import pyspark.pandas as ps # noqa: E402
+from pandas.testing import assert_series_equal # noqa: E402
-import datacompy
-from datacompy import SparkCompare
-from datacompy.spark import _is_comparable
+from datacompy.spark import ( # noqa: E402
+ SparkCompare,
+ calculate_max_diff,
+ columns_equal,
+ generate_id_within_group,
+ temp_column_name,
+)
-# Turn off py4j debug messages for all tests in this module
-logging.getLogger("py4j").setLevel(logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
-CACHE_INTERMEDIATES = True
+pandas_version = pytest.mark.skipif(
+ pd.__version__ >= "2.0.0", reason="Pandas 2 is currently not supported"
+)
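+# Note: the marker above relies on a plain string comparison of pd.__version__,
+# which is sufficient to distinguish the 1.x and 2.x releases exercised in CI.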
-# Declare fixtures
-# (if we need to use these in other modules, move to conftest.py)
-@pytest.fixture(scope="module", name="spark")
-def spark_fixture():
- spark = (
- SparkSession.builder.master("local[2]")
- .config("spark.driver.bindAddress", "127.0.0.1")
- .appName("pytest")
- .getOrCreate()
+pd.DataFrame.iteritems = pd.DataFrame.items  # Pandas 2+ compatibility
+np.bool = np.bool_  # NumPy 1.24.3+ compatibility
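+# Both shims above restore aliases removed by newer pandas/NumPy releases that
+# pyspark.pandas (for the Spark versions tested here) may still reference
+# internally.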
+
+
+@pandas_version
+def test_numeric_columns_equal_abs():
+ data = """a|b|expected
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df.a, df.b, abs_tol=0.2)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
- yield spark
- spark.stop()
-
-
-@pytest.fixture(scope="module", name="base_df1")
-def base_df1_fixture(spark):
- mock_data = [
- Row(
- acct=10000001234,
- dollar_amt=123,
- name="George Maharis",
- float_fld=14530.1555,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001235,
- dollar_amt=0,
- name="Michael Bluth",
- float_fld=1.0,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001236,
- dollar_amt=1345,
- name="George Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001237,
- dollar_amt=123456,
- name="Bob Loblaw",
- float_fld=345.12,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001239,
- dollar_amt=1,
- name="Lucille Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="base_df2")
-def base_df2_fixture(spark):
- mock_data = [
- Row(
- acct=10000001234,
- dollar_amt=123,
- super_duper_big_long_name="George Maharis",
- float_fld=14530.1555,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001235,
- dollar_amt=0,
- super_duper_big_long_name="Michael Bluth",
- float_fld=1.0,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001236,
- dollar_amt=1345,
- super_duper_big_long_name="George Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001237,
- dollar_amt=123456,
- super_duper_big_long_name="Bob Loblaw",
- float_fld=345.12,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001239,
- dollar_amt=1,
- super_duper_big_long_name="Lucille Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="compare_df1")
-def compare_df1_fixture(spark):
- mock_data2 = [
- Row(
- acct=10000001234,
- dollar_amt=123.4,
- name="George Michael Bluth",
- float_fld=14530.155,
- accnt_purge=False,
- ),
- Row(
- acct=10000001235,
- dollar_amt=0.45,
- name="Michael Bluth",
- float_fld=None,
- accnt_purge=False,
- ),
- Row(
- acct=10000001236,
- dollar_amt=1345.0,
- name="George Bluth",
- float_fld=1.0,
- accnt_purge=False,
- ),
- Row(
- acct=10000001237,
- dollar_amt=123456.0,
- name="Bob Loblaw",
- float_fld=345.12,
- accnt_purge=False,
- ),
- Row(
- acct=10000001238,
- dollar_amt=1.05,
- name="Loose Seal Bluth",
- float_fld=111.0,
- accnt_purge=True,
- ),
- Row(
- acct=10000001238,
- dollar_amt=1.05,
- name="Loose Seal Bluth",
- float_fld=111.0,
- accnt_purge=True,
- ),
- ]
-
- return spark.createDataFrame(mock_data2)
-
-
-@pytest.fixture(scope="module", name="compare_df2")
-def compare_df2_fixture(spark):
- mock_data = [
- Row(
- acct=10000001234,
- dollar_amt=123,
- name="George Maharis",
- float_fld=14530.1555,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001235,
- dollar_amt=0,
- name="Michael Bluth",
- float_fld=1.0,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001236,
- dollar_amt=1345,
- name="George Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001237,
- dollar_amt=123456,
- name="Bob Loblaw",
- float_fld=345.12,
- date_fld=datetime.date(2017, 1, 1),
- ),
- Row(
- acct=10000001239,
- dollar_amt=1,
- name="Lucille Bluth",
- float_fld=None,
- date_fld=datetime.date(2017, 1, 1),
- ),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="compare_df3")
-def compare_df3_fixture(spark):
- mock_data2 = [
- Row(
- account_identifier=10000001234,
- dollar_amount=123.4,
- name="George Michael Bluth",
- float_field=14530.155,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001235,
- dollar_amount=0.45,
- name="Michael Bluth",
- float_field=1.0,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001236,
- dollar_amount=1345.0,
- name="George Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001237,
- dollar_amount=123456.0,
- name="Bob Loblaw",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001239,
- dollar_amount=1.05,
- name="Lucille Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- ]
- return spark.createDataFrame(mock_data2)
-
-@pytest.fixture(scope="module", name="base_tol")
-def base_tol_fixture(spark):
- tol_data1 = [
- Row(
- account_identifier=10000001234,
- dollar_amount=123.4,
- name="Franklin Delano Bluth",
- float_field=14530.155,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001235,
- dollar_amount=500.0,
- name="Surely Funke",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001236,
- dollar_amount=-1100.0,
- name="Nichael Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001237,
- dollar_amount=0.45,
- name="Mr. F",
- float_field=1.0,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001238,
- dollar_amount=1345.0,
- name="Steve Holt!",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001239,
- dollar_amount=123456.0,
- name="Blue Man Group",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001240,
- dollar_amount=1.1,
- name="Her?",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001241,
- dollar_amount=0.0,
- name="Mrs. Featherbottom",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001242,
- dollar_amount=0.0,
- name="Ice",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ),
- Row(
- account_identifier=10000001243,
- dollar_amount=-10.0,
- name="Frank Wrench",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001244,
- dollar_amount=None,
- name="Lucille 2",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001245,
- dollar_amount=0.009999,
- name="Gene Parmesan",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- Row(
- account_identifier=10000001246,
- dollar_amount=None,
- name="Motherboy",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ),
- ]
-
- return spark.createDataFrame(tol_data1)
-
-
-@pytest.fixture(scope="module", name="compare_abs_tol")
-def compare_tol2_fixture(spark):
- tol_data2 = [
- Row(
- account_identifier=10000001234,
- dollar_amount=123.4,
- name="Franklin Delano Bluth",
- float_field=14530.155,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # full match
- Row(
- account_identifier=10000001235,
- dollar_amount=500.01,
- name="Surely Funke",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by 0.01
- Row(
- account_identifier=10000001236,
- dollar_amount=-1100.01,
- name="Nichael Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by -0.01
- Row(
- account_identifier=10000001237,
- dollar_amount=0.46000000001,
- name="Mr. F",
- float_field=1.0,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by 0.01000000001
- Row(
- account_identifier=10000001238,
- dollar_amount=1344.8999999999,
- name="Steve Holt!",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by -0.01000000001
- Row(
- account_identifier=10000001239,
- dollar_amount=123456.0099999999,
- name="Blue Man Group",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by 0.00999999999
- Row(
- account_identifier=10000001240,
- dollar_amount=1.090000001,
- name="Her?",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by -0.00999999999
- Row(
- account_identifier=10000001241,
- dollar_amount=0.0,
- name="Mrs. Featherbottom",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both zero
- Row(
- account_identifier=10000001242,
- dollar_amount=1.0,
- name="Ice",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # base 0, compare 1
- Row(
- account_identifier=10000001243,
- dollar_amount=0.0,
- name="Frank Wrench",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base -10, compare 0
- Row(
- account_identifier=10000001244,
- dollar_amount=-1.0,
- name="Lucille 2",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base NULL, compare -1
- Row(
- account_identifier=10000001245,
- dollar_amount=None,
- name="Gene Parmesan",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base 0.009999, compare NULL
- Row(
- account_identifier=10000001246,
- dollar_amount=None,
- name="Motherboy",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both NULL
- ]
-
- return spark.createDataFrame(tol_data2)
-
-
-@pytest.fixture(scope="module", name="compare_rel_tol")
-def compare_tol3_fixture(spark):
- tol_data3 = [
- Row(
- account_identifier=10000001234,
- dollar_amount=123.4,
- name="Franklin Delano Bluth",
- float_field=14530.155,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # full match #MATCH
- Row(
- account_identifier=10000001235,
- dollar_amount=550.0,
- name="Surely Funke",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by 10% #MATCH
- Row(
- account_identifier=10000001236,
- dollar_amount=-1000.0,
- name="Nichael Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by -10% #MATCH
- Row(
- account_identifier=10000001237,
- dollar_amount=0.49501,
- name="Mr. F",
- float_field=1.0,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by greater than 10%
- Row(
- account_identifier=10000001238,
- dollar_amount=1210.001,
- name="Steve Holt!",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by greater than -10%
- Row(
- account_identifier=10000001239,
- dollar_amount=135801.59999,
- name="Blue Man Group",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by just under 10% #MATCH
- Row(
- account_identifier=10000001240,
- dollar_amount=1.000001,
- name="Her?",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by just under -10% #MATCH
- Row(
- account_identifier=10000001241,
- dollar_amount=0.0,
- name="Mrs. Featherbottom",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both zero #MATCH
- Row(
- account_identifier=10000001242,
- dollar_amount=1.0,
- name="Ice",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # base 0, compare 1
- Row(
- account_identifier=10000001243,
- dollar_amount=0.0,
- name="Frank Wrench",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base -10, compare 0
- Row(
- account_identifier=10000001244,
- dollar_amount=-1.0,
- name="Lucille 2",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base NULL, compare -1
- Row(
- account_identifier=10000001245,
- dollar_amount=None,
- name="Gene Parmesan",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base 0.009999, compare NULL
- Row(
- account_identifier=10000001246,
- dollar_amount=None,
- name="Motherboy",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both NULL #MATCH
- ]
-
- return spark.createDataFrame(tol_data3)
-
-
-@pytest.fixture(scope="module", name="compare_both_tol")
-def compare_tol4_fixture(spark):
- tol_data4 = [
- Row(
- account_identifier=10000001234,
- dollar_amount=123.4,
- name="Franklin Delano Bluth",
- float_field=14530.155,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # full match
- Row(
- account_identifier=10000001235,
- dollar_amount=550.01,
- name="Surely Funke",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by 10% and +0.01
- Row(
- account_identifier=10000001236,
- dollar_amount=-1000.01,
- name="Nichael Bluth",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by -10% and -0.01
- Row(
- account_identifier=10000001237,
- dollar_amount=0.505000000001,
- name="Mr. F",
- float_field=1.0,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by greater than 10% and +0.01
- Row(
- account_identifier=10000001238,
- dollar_amount=1209.98999,
- name="Steve Holt!",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by greater than -10% and -0.01
- Row(
- account_identifier=10000001239,
- dollar_amount=135801.609999,
- name="Blue Man Group",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # off by just under 10% and just under +0.01
- Row(
- account_identifier=10000001240,
- dollar_amount=0.99000001,
- name="Her?",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # off by just under -10% and just under -0.01
- Row(
- account_identifier=10000001241,
- dollar_amount=0.0,
- name="Mrs. Featherbottom",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both zero
- Row(
- account_identifier=10000001242,
- dollar_amount=1.0,
- name="Ice",
- float_field=345.12,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=False,
- ), # base 0, compare 1
- Row(
- account_identifier=10000001243,
- dollar_amount=0.0,
- name="Frank Wrench",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base -10, compare 0
- Row(
- account_identifier=10000001244,
- dollar_amount=-1.0,
- name="Lucille 2",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base NULL, compare -1
- Row(
- account_identifier=10000001245,
- dollar_amount=None,
- name="Gene Parmesan",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # base 0.009999, compare NULL
- Row(
- account_identifier=10000001246,
- dollar_amount=None,
- name="Motherboy",
- float_field=None,
- date_field=datetime.date(2017, 1, 1),
- accnt_purge=True,
- ), # both NULL
- ]
-
- return spark.createDataFrame(tol_data4)
-
-
-@pytest.fixture(scope="module", name="base_td")
-def base_td_fixture(spark):
- mock_data = [
- Row(
- acct=10000001234,
- acct_seq=0,
- stat_cd="*2",
- open_dt=datetime.date(2017, 5, 1),
- cd="0001",
- ),
- Row(
- acct=10000001235,
- acct_seq=0,
- stat_cd="V1",
- open_dt=datetime.date(2017, 5, 2),
- cd="0002",
- ),
- Row(
- acct=10000001236,
- acct_seq=0,
- stat_cd="V2",
- open_dt=datetime.date(2017, 5, 3),
- cd="0003",
- ),
- Row(
- acct=10000001237,
- acct_seq=0,
- stat_cd="*2",
- open_dt=datetime.date(2017, 5, 4),
- cd="0004",
- ),
- Row(
- acct=10000001238,
- acct_seq=0,
- stat_cd="*2",
- open_dt=datetime.date(2017, 5, 5),
- cd="0005",
- ),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="compare_source")
-def compare_source_fixture(spark):
- mock_data = [
- Row(
- ACCOUNT_IDENTIFIER=10000001234,
- SEQ_NUMBER=0,
- STATC=None,
- ACCOUNT_OPEN=2017121,
- CODE=1.0,
- ),
- Row(
- ACCOUNT_IDENTIFIER=10000001235,
- SEQ_NUMBER=0,
- STATC="V1",
- ACCOUNT_OPEN=2017122,
- CODE=2.0,
- ),
- Row(
- ACCOUNT_IDENTIFIER=10000001236,
- SEQ_NUMBER=0,
- STATC="V2",
- ACCOUNT_OPEN=2017123,
- CODE=3.0,
- ),
- Row(
- ACCOUNT_IDENTIFIER=10000001237,
- SEQ_NUMBER=0,
- STATC="V3",
- ACCOUNT_OPEN=2017124,
- CODE=4.0,
- ),
- Row(
- ACCOUNT_IDENTIFIER=10000001238,
- SEQ_NUMBER=0,
- STATC=None,
- ACCOUNT_OPEN=2017125,
- CODE=5.0,
- ),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="base_decimal")
-def base_decimal_fixture(spark):
- mock_data = [
- Row(acct=10000001234, dollar_amt=Decimal(123.4)),
- Row(acct=10000001235, dollar_amt=Decimal(0.45)),
- ]
-
- return spark.createDataFrame(
- mock_data,
- schema=StructType(
- [
- StructField("acct", LongType(), True),
- StructField("dollar_amt", DecimalType(8, 2), True),
- ]
- ),
+@pandas_version
+def test_numeric_columns_equal_rel():
+ data = """a|b|expected
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="compare_decimal")
-def compare_decimal_fixture(spark):
- mock_data = [
- Row(acct=10000001234, dollar_amt=123.4),
- Row(acct=10000001235, dollar_amt=0.456),
- ]
-
- return spark.createDataFrame(mock_data)
-
-
-@pytest.fixture(scope="module", name="comparison_abs_tol")
-def comparison_abs_tol_fixture(base_tol, compare_abs_tol, spark):
- return SparkCompare(
- spark,
- base_tol,
- compare_abs_tol,
- join_columns=["account_identifier"],
- abs_tol=0.01,
+@pandas_version
+def test_string_columns_equal():
+ data = """a|b|expected
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |False
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |False
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="comparison_rel_tol")
-def comparison_rel_tol_fixture(base_tol, compare_rel_tol, spark):
- return SparkCompare(
- spark,
- base_tol,
- compare_rel_tol,
- join_columns=["account_identifier"],
- rel_tol=0.1,
+@pandas_version
+def test_string_columns_equal_with_ignore_spaces():
+ data = """a|b|expected
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |True
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |True
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="comparison_both_tol")
-def comparison_both_tol_fixture(base_tol, compare_both_tol, spark):
- return SparkCompare(
- spark,
- base_tol,
- compare_both_tol,
- join_columns=["account_identifier"],
- rel_tol=0.1,
- abs_tol=0.01,
+@pandas_version
+def test_string_columns_equal_with_ignore_spaces_and_case():
+ data = """a|b|expected
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |True
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |True
+datacompy|DataComPy|True
+something||False
+|something|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(
+ df.a, df.b, rel_tol=0.2, ignore_spaces=True, ignore_case=True
)
-
-
-@pytest.fixture(scope="module", name="comparison_neg_tol")
-def comparison_neg_tol_fixture(base_tol, compare_both_tol, spark):
- return SparkCompare(
- spark,
- base_tol,
- compare_both_tol,
- join_columns=["account_identifier"],
- rel_tol=-0.2,
- abs_tol=0.01,
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="show_all_columns_and_match_rate")
-def show_all_columns_and_match_rate_fixture(base_tol, compare_both_tol, spark):
- return SparkCompare(
- spark,
- base_tol,
- compare_both_tol,
- join_columns=["account_identifier"],
- show_all_columns=True,
- match_rates=True,
+@pandas_version
+def test_date_columns_equal(tmp_path):
+ data = """a|b|expected
+2017-01-01|2017-01-01|True
+2017-01-02|2017-01-02|True
+2017-10-01|2017-10-10|False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ # First compare just the strings
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-
-@pytest.fixture(scope="module", name="comparison_kd1")
-def comparison_known_diffs1(base_td, compare_source, spark):
- return SparkCompare(
- spark,
- base_td,
- compare_source,
- join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")],
- column_mapping=[
- ("stat_cd", "STATC"),
- ("open_dt", "ACCOUNT_OPEN"),
- ("cd", "CODE"),
- ],
- known_differences=[
- {
- "name": "Left-padded, four-digit numeric code",
- "types": datacompy.NUMERIC_SPARK_TYPES,
- "transformation": "lpad(cast({input} AS bigint), 4, '0')",
- },
- {
- "name": "Null to *2",
- "types": ["string"],
- "transformation": "case when {input} is null then '*2' else {input} end",
- },
- {
- "name": "Julian date -> date",
- "types": ["bigint"],
- "transformation": "to_date(cast(unix_timestamp(cast({input} AS string), 'yyyyDDD') AS timestamp))",
- },
- ],
+ # Then compare converted to datetime objects
+ df["a"] = ps.to_datetime(df["a"])
+ df["b"] = ps.to_datetime(df["b"])
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-
-
-@pytest.fixture(scope="module", name="comparison_kd2")
-def comparison_known_diffs2(base_td, compare_source, spark):
- return SparkCompare(
- spark,
- base_td,
- compare_source,
- join_columns=[("acct", "ACCOUNT_IDENTIFIER"), ("acct_seq", "SEQ_NUMBER")],
- column_mapping=[
- ("stat_cd", "STATC"),
- ("open_dt", "ACCOUNT_OPEN"),
- ("cd", "CODE"),
- ],
- known_differences=[
- {
- "name": "Left-padded, four-digit numeric code",
- "types": datacompy.NUMERIC_SPARK_TYPES,
- "transformation": "lpad(cast({input} AS bigint), 4, '0')",
- },
- {
- "name": "Null to *2",
- "types": ["string"],
- "transformation": "case when {input} is null then '*2' else {input} end",
- },
- ],
+ # and reverse
+ actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2)
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="comparison1")
-def comparison1_fixture(base_df1, compare_df1, spark):
- return SparkCompare(
- spark,
- base_df1,
- compare_df1,
- join_columns=["acct"],
- cache_intermediates=CACHE_INTERMEDIATES,
+@pandas_version
+def test_date_columns_equal_with_ignore_spaces(tmp_path):
+ data = """a|b|expected
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ # First compare just the strings
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-
-@pytest.fixture(scope="module", name="comparison2")
-def comparison2_fixture(base_df1, compare_df2, spark):
- return SparkCompare(spark, base_df1, compare_df2, join_columns=["acct"])
-
-
-@pytest.fixture(scope="module", name="comparison3")
-def comparison3_fixture(base_df1, compare_df3, spark):
- return SparkCompare(
- spark,
- base_df1,
- compare_df3,
- join_columns=[("acct", "account_identifier")],
- column_mapping=[
- ("dollar_amt", "dollar_amount"),
- ("float_fld", "float_field"),
- ("date_fld", "date_field"),
- ],
- cache_intermediates=CACHE_INTERMEDIATES,
+ # Then compare converted to datetime objects
+ df["a"] = ps.to_datetime(df["a"], errors="coerce")
+ df["b"] = ps.to_datetime(df["b"], errors="coerce")
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-
-
-@pytest.fixture(scope="module", name="comparison4")
-def comparison4_fixture(base_df2, compare_df1, spark):
- return SparkCompare(
- spark,
- base_df2,
- compare_df1,
- join_columns=["acct"],
- column_mapping=[("super_duper_big_long_name", "name")],
+ # and reverse
+ actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2, ignore_spaces=True)
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False
)
-@pytest.fixture(scope="module", name="comparison_decimal")
-def comparison_decimal_fixture(base_decimal, compare_decimal, spark):
- return SparkCompare(spark, base_decimal, compare_decimal, join_columns=["acct"])
-
-
-def test_absolute_tolerances(comparison_abs_tol):
- stdout = io.StringIO()
-
- comparison_abs_tol.report(file=stdout)
- stdout.seek(0)
- assert "****** Row Comparison ******" in stdout.getvalue()
- assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
- assert "Number of rows with all columns equal: 7" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
-
-
-def test_relative_tolerances(comparison_rel_tol):
- stdout = io.StringIO()
-
- comparison_rel_tol.report(file=stdout)
- stdout.seek(0)
- assert "****** Row Comparison ******" in stdout.getvalue()
- assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
- assert "Number of rows with all columns equal: 7" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
-
-
-def test_both_tolerances(comparison_both_tol):
- stdout = io.StringIO()
-
- comparison_both_tol.report(file=stdout)
- stdout.seek(0)
- assert "****** Row Comparison ******" in stdout.getvalue()
- assert "Number of rows with some columns unequal: 6" in stdout.getvalue()
- assert "Number of rows with all columns equal: 7" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 1" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
-
-
-def test_negative_tolerances(spark, base_tol, compare_both_tol):
- with pytest.raises(ValueError, match="Please enter positive valued tolerances"):
- comp = SparkCompare(
- spark,
- base_tol,
- compare_both_tol,
- join_columns=["account_identifier"],
- rel_tol=-0.2,
- abs_tol=0.01,
- )
- comp.report()
- pass
-
-
-def test_show_all_columns_and_match_rate(show_all_columns_and_match_rate):
- stdout = io.StringIO()
-
- show_all_columns_and_match_rate.report(file=stdout)
-
- assert "****** Columns with Equal/Unequal Values ******" in stdout.getvalue()
- assert (
- "accnt_purge accnt_purge boolean boolean 13 0 100.00000"
- in stdout.getvalue()
+@pandas_version
+def test_date_columns_equal_with_ignore_spaces_and_case(tmp_path):
+ data = """a|b|expected
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ df = ps.from_pandas(pd.read_csv(StringIO(data), sep="|"))
+ # First compare just the strings
+ actual_out = columns_equal(
+ df.a, df.b, rel_tol=0.2, ignore_spaces=True, ignore_case=True
)
- assert (
- "date_field date_field date date 13 0 100.00000"
- in stdout.getvalue()
- )
- assert (
- "dollar_amount dollar_amount double double 3 10 23.07692"
- in stdout.getvalue()
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
- assert (
- "float_field float_field double double 13 0 100.00000"
- in stdout.getvalue()
+
+ # Then compare converted to datetime objects
+ df["a"] = ps.to_datetime(df["a"], errors="coerce")
+ df["b"] = ps.to_datetime(df["b"], errors="coerce")
+ actual_out = columns_equal(df.a, df.b, rel_tol=0.2, ignore_spaces=True)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
- assert (
- "name name string string 13 0 100.00000"
- in stdout.getvalue()
+ # and reverse
+ actual_out_rev = columns_equal(df.b, df.a, rel_tol=0.2, ignore_spaces=True)
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out_rev.to_pandas(), check_names=False
)
-def test_decimal_comparisons():
- true_decimals = ["decimal", "decimal()", "decimal(20, 10)"]
- assert all(v in datacompy.NUMERIC_SPARK_TYPES for v in true_decimals)
-
-
-def test_decimal_comparator_acts_like_string():
- acc = False
- for t in datacompy.NUMERIC_SPARK_TYPES:
- acc = acc or (len(t) > 2 and t[0:3] == "dec")
- assert acc
-
-
-def test_decimals_and_doubles_are_comparable():
- assert _is_comparable("double", "decimal(10, 2)")
-
-
-def test_report_outputs_the_column_summary(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Column Summary ******" in stdout.getvalue()
- assert "Number of columns in common with matching schemas: 3" in stdout.getvalue()
- assert "Number of columns in common with schema differences: 1" in stdout.getvalue()
- assert "Number of columns in base but not compare: 1" in stdout.getvalue()
- assert "Number of columns in compare but not base: 1" in stdout.getvalue()
-
-
-def test_report_outputs_the_column_summary_for_identical_schemas(comparison2):
- stdout = io.StringIO()
-
- comparison2.report(file=stdout)
-
- assert "****** Column Summary ******" in stdout.getvalue()
- assert "Number of columns in common with matching schemas: 5" in stdout.getvalue()
- assert "Number of columns in common with schema differences: 0" in stdout.getvalue()
- assert "Number of columns in base but not compare: 0" in stdout.getvalue()
- assert "Number of columns in compare but not base: 0" in stdout.getvalue()
-
-
-def test_report_outputs_the_column_summary_for_differently_named_columns(comparison3):
- stdout = io.StringIO()
-
- comparison3.report(file=stdout)
-
- assert "****** Column Summary ******" in stdout.getvalue()
- assert "Number of columns in common with matching schemas: 4" in stdout.getvalue()
- assert "Number of columns in common with schema differences: 1" in stdout.getvalue()
- assert "Number of columns in base but not compare: 0" in stdout.getvalue()
- assert "Number of columns in compare but not base: 1" in stdout.getvalue()
-
-
-def test_report_outputs_the_row_summary(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Row Summary ******" in stdout.getvalue()
- assert "Number of rows in common: 4" in stdout.getvalue()
- assert "Number of rows in base but not compare: 1" in stdout.getvalue()
- assert "Number of rows in compare but not base: 1" in stdout.getvalue()
- assert "Number of duplicate rows found in base: 0" in stdout.getvalue()
- assert "Number of duplicate rows found in compare: 1" in stdout.getvalue()
-
-
-def test_report_outputs_the_row_equality_comparison(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Row Comparison ******" in stdout.getvalue()
- assert "Number of rows with some columns unequal: 3" in stdout.getvalue()
- assert "Number of rows with all columns equal: 1" in stdout.getvalue()
-
-
-def test_report_outputs_the_row_summary_for_differently_named_columns(comparison3):
- stdout = io.StringIO()
-
- comparison3.report(file=stdout)
-
- assert "****** Row Summary ******" in stdout.getvalue()
- assert "Number of rows in common: 5" in stdout.getvalue()
- assert "Number of rows in base but not compare: 0" in stdout.getvalue()
- assert "Number of rows in compare but not base: 0" in stdout.getvalue()
- assert "Number of duplicate rows found in base: 0" in stdout.getvalue()
- assert "Number of duplicate rows found in compare: 0" in stdout.getvalue()
-
-
-def test_report_outputs_the_row_equality_comparison_for_differently_named_columns(
- comparison3,
-):
- stdout = io.StringIO()
-
- comparison3.report(file=stdout)
-
- assert "****** Row Comparison ******" in stdout.getvalue()
- assert "Number of rows with some columns unequal: 3" in stdout.getvalue()
- assert "Number of rows with all columns equal: 2" in stdout.getvalue()
-
-
-def test_report_outputs_column_detail_for_columns_in_only_one_dataframe(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
- comparison1.report()
- assert "****** Columns In Base Only ******" in stdout.getvalue()
- r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \ndate_fld \s+ date"""
- assert re.search(r2, str(stdout.getvalue()), re.X) is not None
-
-
-def test_report_outputs_column_detail_for_columns_in_only_compare_dataframe(
- comparison1,
-):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
- comparison1.report()
- assert "****** Columns In Compare Only ******" in stdout.getvalue()
- r2 = r"""Column\s*Name \s* Dtype \n -+ \s+ -+ \n accnt_purge \s+ boolean"""
- assert re.search(r2, str(stdout.getvalue()), re.X) is not None
-
-
-def test_report_outputs_schema_difference_details(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Schema Differences ******" in stdout.getvalue()
- assert re.search(
- r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n
- -+ \s+ -+ \s+ -+ \s+ -+ \n
- dollar_amt \s+ dollar_amt \s+ bigint \s+ double""",
- stdout.getvalue(),
- re.X,
+@pandas_version
+def test_date_columns_unequal():
+ """I want datetime fields to match with dates stored as strings"""
+ df = ps.DataFrame([{"a": "2017-01-01", "b": "2017-01-02"}, {"a": "2017-01-01"}])
+ df["a_dt"] = ps.to_datetime(df["a"])
+ df["b_dt"] = ps.to_datetime(df["b"])
+ assert columns_equal(df.a, df.a_dt).all()
+ assert columns_equal(df.b, df.b_dt).all()
+ assert columns_equal(df.a_dt, df.a).all()
+ assert columns_equal(df.b_dt, df.b).all()
+ assert not columns_equal(df.b_dt, df.a).any()
+ assert not columns_equal(df.a_dt, df.b).any()
+ assert not columns_equal(df.a, df.b_dt).any()
+ assert not columns_equal(df.b, df.a_dt).any()
+
+
+@pandas_version
+def test_bad_date_columns():
+ """If strings can't be coerced into dates then it should be false for the
+ whole column.
+ """
+ df = ps.DataFrame(
+ [{"a": "2017-01-01", "b": "2017-01-01"}, {"a": "2017-01-01", "b": "217-01-01"}]
)
+ df["a_dt"] = ps.to_datetime(df["a"])
+ assert not columns_equal(df.a_dt, df.b).any()
-def test_report_outputs_schema_difference_details_for_differently_named_columns(
- comparison3,
-):
- stdout = io.StringIO()
-
- comparison3.report(file=stdout)
-
- assert "****** Schema Differences ******" in stdout.getvalue()
- assert re.search(
- r"""Base\sColumn\sName \s+ Compare\sColumn\sName \s+ Base\sDtype \s+ Compare\sDtype \n
- -+ \s+ -+ \s+ -+ \s+ -+ \n
- dollar_amt \s+ dollar_amount \s+ bigint \s+ double""",
- stdout.getvalue(),
- re.X,
+@pandas_version
+def test_rounded_date_columns():
+    """A date parsed to a midnight datetime should match timestamp strings
+    whose time component is zero, and mismatch any string with a nonzero time.
+    """
+ df = ps.DataFrame(
+ [
+ {"a": "2017-01-01", "b": "2017-01-01 00:00:00.000000", "exp": True},
+ {"a": "2017-01-01", "b": "2017-01-01 00:00:00.123456", "exp": False},
+ {"a": "2017-01-01", "b": "2017-01-01 00:00:01.000000", "exp": False},
+ {"a": "2017-01-01", "b": "2017-01-01 00:00:00", "exp": True},
+ ]
)
+ df["a_dt"] = ps.to_datetime(df["a"])
+ actual = columns_equal(df.a_dt, df.b)
+ expected = df["exp"]
+ assert_series_equal(actual.to_pandas(), expected.to_pandas(), check_names=False)
-def test_column_comparison_outputs_number_of_columns_with_differences(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Column Comparison ******" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 3" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 0" in stdout.getvalue()
-
-
-def test_column_comparison_outputs_all_columns_equal_for_identical_dataframes(
- comparison2,
-):
- stdout = io.StringIO()
-
- comparison2.report(file=stdout)
-
- assert "****** Column Comparison ******" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 0" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 4" in stdout.getvalue()
-
-
-def test_column_comparison_outputs_number_of_columns_with_differences_for_differently_named_columns(
- comparison3,
-):
- stdout = io.StringIO()
-
- comparison3.report(file=stdout)
-
- assert "****** Column Comparison ******" in stdout.getvalue()
- assert "Number of columns compared with some values unequal: 3" in stdout.getvalue()
- assert "Number of columns compared with all values equal: 1" in stdout.getvalue()
-
-
-def test_column_comparison_outputs_number_of_columns_with_differences_for_known_diffs(
- comparison_kd1,
-):
- stdout = io.StringIO()
-
- comparison_kd1.report(file=stdout)
-
- assert "****** Column Comparison ******" in stdout.getvalue()
- assert (
- "Number of columns compared with unexpected differences in some values: 1"
- in stdout.getvalue()
- )
- assert (
- "Number of columns compared with all values equal but known differences found: 2"
- in stdout.getvalue()
+@pandas_version
+def test_decimal_float_columns_equal():
+ df = ps.DataFrame(
+ [
+ {"a": Decimal("1"), "b": 1, "expected": True},
+ {"a": Decimal("1.3"), "b": 1.3, "expected": True},
+ {"a": Decimal("1.000003"), "b": 1.000003, "expected": True},
+ {"a": Decimal("1.000000004"), "b": 1.000000003, "expected": False},
+ {"a": Decimal("1.3"), "b": 1.2, "expected": False},
+ {"a": np.nan, "b": np.nan, "expected": True},
+ {"a": np.nan, "b": 1, "expected": False},
+ {"a": Decimal("1"), "b": np.nan, "expected": False},
+ ]
)
- assert (
- "Number of columns compared with all values completely equal: 0"
- in stdout.getvalue()
+ actual_out = columns_equal(df.a, df.b)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-def test_column_comparison_outputs_number_of_columns_with_differences_for_custom_known_diffs(
- comparison_kd2,
-):
- stdout = io.StringIO()
-
- comparison_kd2.report(file=stdout)
-
- assert "****** Column Comparison ******" in stdout.getvalue()
- assert (
- "Number of columns compared with unexpected differences in some values: 2"
- in stdout.getvalue()
- )
- assert (
- "Number of columns compared with all values equal but known differences found: 1"
- in stdout.getvalue()
+@pandas_version
+def test_decimal_float_columns_equal_rel():
+ df = ps.DataFrame(
+ [
+ {"a": Decimal("1"), "b": 1, "expected": True},
+ {"a": Decimal("1.3"), "b": 1.3, "expected": True},
+ {"a": Decimal("1.000003"), "b": 1.000003, "expected": True},
+ {"a": Decimal("1.000000004"), "b": 1.000000003, "expected": True},
+ {"a": Decimal("1.3"), "b": 1.2, "expected": False},
+ {"a": np.nan, "b": np.nan, "expected": True},
+ {"a": np.nan, "b": 1, "expected": False},
+ {"a": Decimal("1"), "b": np.nan, "expected": False},
+ ]
)
- assert (
- "Number of columns compared with all values completely equal: 0"
- in stdout.getvalue()
+ actual_out = columns_equal(df.a, df.b, abs_tol=0.001)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-def test_columns_with_unequal_values_show_mismatch_counts(comparison1):
- stdout = io.StringIO()
-
- comparison1.report(file=stdout)
-
- assert "****** Columns with Unequal Values ******" in stdout.getvalue()
- assert re.search(
- r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s*
- \#\sMatches \s* \#\sMismatches \n
- -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""",
- stdout.getvalue(),
- re.X,
- )
- assert re.search(
- r"""dollar_amt \s+ dollar_amt \s+ bigint \s+ double \s+ 2 \s+ 2""",
- stdout.getvalue(),
- re.X,
- )
- assert re.search(
- r"""float_fld \s+ float_fld \s+ double \s+ double \s+ 1 \s+ 3""",
- stdout.getvalue(),
- re.X,
+@pandas_version
+def test_decimal_columns_equal():
+ df = ps.DataFrame(
+ [
+ {"a": Decimal("1"), "b": Decimal("1"), "expected": True},
+ {"a": Decimal("1.3"), "b": Decimal("1.3"), "expected": True},
+ {"a": Decimal("1.000003"), "b": Decimal("1.000003"), "expected": True},
+ {
+ "a": Decimal("1.000000004"),
+ "b": Decimal("1.000000003"),
+ "expected": False,
+ },
+ {"a": Decimal("1.3"), "b": Decimal("1.2"), "expected": False},
+ {"a": np.nan, "b": np.nan, "expected": True},
+ {"a": np.nan, "b": Decimal("1"), "expected": False},
+ {"a": Decimal("1"), "b": np.nan, "expected": False},
+ ]
)
- assert re.search(
- r"""name \s+ name \s+ string \s+ string \s+ 3 \s+ 1""", stdout.getvalue(), re.X
+ actual_out = columns_equal(df.a, df.b)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
-def test_columns_with_different_names_with_unequal_values_show_mismatch_counts(
- comparison3,
-):
- stdout = io.StringIO()
+@pandas_version
+def test_decimal_columns_equal_rel():
+ df = ps.DataFrame(
+ [
+ {"a": Decimal("1"), "b": Decimal("1"), "expected": True},
+ {"a": Decimal("1.3"), "b": Decimal("1.3"), "expected": True},
+ {"a": Decimal("1.000003"), "b": Decimal("1.000003"), "expected": True},
+ {
+ "a": Decimal("1.000000004"),
+ "b": Decimal("1.000000003"),
+ "expected": True,
+ },
+ {"a": Decimal("1.3"), "b": Decimal("1.2"), "expected": False},
+ {"a": np.nan, "b": np.nan, "expected": True},
+ {"a": np.nan, "b": Decimal("1"), "expected": False},
+ {"a": Decimal("1"), "b": np.nan, "expected": False},
+ ]
+ )
+ actual_out = columns_equal(df.a, df.b, abs_tol=0.001)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
+ )
- comparison3.report(file=stdout)
- assert "****** Columns with Unequal Values ******" in stdout.getvalue()
- assert re.search(
- r"""Base\s*Column\s*Name \s+ Compare\s*Column\s*Name \s+ Base\s*Dtype \s+ Compare\sDtype \s*
- \#\sMatches \s* \#\sMismatches \n
- -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+ \s+ -+""",
- stdout.getvalue(),
- re.X,
- )
- assert re.search(
- r"""dollar_amt \s+ dollar_amount \s+ bigint \s+ double \s+ 2 \s+ 3""",
- stdout.getvalue(),
- re.X,
+@pandas_version
+def test_infinity_and_beyond():
+ # https://spark.apache.org/docs/latest/sql-ref-datatypes.html#positivenegative-infinity-semantics
+ # Positive/negative infinity multiplied by 0 returns NaN.
+ # Positive infinity sorts lower than NaN and higher than any other values.
+ # Negative infinity sorts lower than any other values.
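+    # Note: the expected values below treat +inf and -inf as equal under this comparison.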
+ df = ps.DataFrame(
+ [
+ {"a": np.inf, "b": np.inf, "expected": True},
+ {"a": -np.inf, "b": -np.inf, "expected": True},
+ {"a": -np.inf, "b": np.inf, "expected": True},
+ {"a": np.inf, "b": -np.inf, "expected": True},
+ {"a": 1, "b": 1, "expected": True},
+ {"a": 1, "b": 0, "expected": False},
+ ]
)
- assert re.search(
- r"""float_fld \s+ float_field \s+ double \s+ double \s+ 4 \s+ 1""",
- stdout.getvalue(),
- re.X,
+ actual_out = columns_equal(df.a, df.b)
+ expect_out = df["expected"]
+ assert_series_equal(
+ expect_out.to_pandas(), actual_out.to_pandas(), check_names=False
)
- assert re.search(
- r"""name \s+ name \s+ string \s+ string \s+ 4 \s+ 1""", stdout.getvalue(), re.X
+
+
+@pandas_version
+def test_compare_df_setter_bad():
+ df = ps.DataFrame([{"a": 1, "c": 2}, {"a": 2, "c": 2}])
+ with raises(TypeError, match="df1 must be a pyspark.pandas.frame.DataFrame"):
+ compare = SparkCompare("a", "a", ["a"])
+ with raises(ValueError, match="df1 must have all columns from join_columns"):
+ compare = SparkCompare(df, df.copy(), ["b"])
+ df_dupe = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 3}])
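+    # Duplicate join keys are accepted; the input frame should be stored on the comparison unchanged.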
+ assert (
+ SparkCompare(df_dupe, df_dupe.copy(), ["a", "b"])
+ .df1.equals(df_dupe)
+ .all()
+ .all()
)
-def test_rows_only_base_returns_a_dataframe_with_rows_only_in_base(spark, comparison1):
- # require schema if contains only 1 row and contain field value as None
- schema = StructType(
+@pandas_version
+def test_compare_df_setter_good():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}])
+ df2 = ps.DataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SparkCompare(df1, df2, ["a"])
+ assert compare.df1.equals(df1).all().all()
+ assert compare.df2.equals(df2).all().all()
+ assert compare.join_columns == ["a"]
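+    # Mixed-case join column names should be normalised to lower case by default.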
+ compare = SparkCompare(df1, df2, ["A", "b"])
+ assert compare.df1.equals(df1).all().all()
+ assert compare.df2.equals(df2).all().all()
+ assert compare.join_columns == ["a", "b"]
+
+
+@pandas_version
+def test_compare_df_setter_different_cases():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}])
+ df2 = ps.DataFrame([{"A": 1, "b": 2}, {"A": 2, "b": 3}])
+ compare = SparkCompare(df1, df2, ["a"])
+ assert compare.df1.equals(df1).all().all()
+ assert compare.df2.equals(df2).all().all()
+
+
+@pandas_version
+def test_columns_overlap():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 3}])
+ compare = SparkCompare(df1, df2, ["a"])
+ assert compare.df1_unq_columns() == set()
+ assert compare.df2_unq_columns() == set()
+ assert compare.intersect_columns() == {"a", "b"}
+
+
+@pandas_version
+def test_columns_no_overlap():
+ df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2, "d": "oh"}, {"a": 2, "b": 3, "d": "ya"}])
+ compare = SparkCompare(df1, df2, ["a"])
+ assert compare.df1_unq_columns() == {"c"}
+ assert compare.df2_unq_columns() == {"d"}
+ assert compare.intersect_columns() == {"a", "b"}
+
+
+@pandas_version
+def test_columns_maintain_order_through_set_operations():
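+    # Unique and intersecting column sets should come back in each frame's original column order.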
+ df1 = ps.DataFrame(
[
- StructField("acct", LongType(), True),
- StructField("date_fld", DateType(), True),
- StructField("dollar_amt", LongType(), True),
- StructField("float_fld", DoubleType(), True),
- StructField("name", StringType(), True),
- ]
+ (("A"), (0), (1), (2), (3), (4), (-2)),
+ (("B"), (0), (2), (2), (3), (4), (-3)),
+ ],
+ columns=["join", "f", "g", "b", "h", "a", "c"],
)
- expected_df = spark.createDataFrame(
+ df2 = ps.DataFrame(
[
- Row(
- acct=10000001239,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt=1,
- float_fld=None,
- name="Lucille Bluth",
- )
+ (("A"), (0), (1), (2), (-1), (4), (-3)),
+ (("B"), (1), (2), (3), (-1), (4), (-2)),
],
- schema,
+ columns=["join", "e", "h", "b", "a", "g", "d"],
)
- assert comparison1.rows_only_base.count() == 1
- assert (
- expected_df.union(
- comparison1.rows_only_base.select(
- "acct", "date_fld", "dollar_amt", "float_fld", "name"
- )
- )
- .distinct()
- .count()
- == 1
+ compare = SparkCompare(df1, df2, ["join"])
+ assert list(compare.df1_unq_columns()) == ["f", "c"]
+ assert list(compare.df2_unq_columns()) == ["e", "d"]
+ assert list(compare.intersect_columns()) == ["join", "g", "b", "h", "a"]
+
+
+@pandas_version
+def test_10k_rows():
+ df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"])
+ df1.reset_index(inplace=True)
+ df1.columns = ["a", "b", "c"]
+ df2 = df1.copy()
+ df2["b"] = df2["b"] + 0.1
+ compare_tol = SparkCompare(df1, df2, ["a"], abs_tol=0.2)
+ assert compare_tol.matches()
+ assert len(compare_tol.df1_unq_rows) == 0
+ assert len(compare_tol.df2_unq_rows) == 0
+ assert compare_tol.intersect_columns() == {"a", "b", "c"}
+ assert compare_tol.all_columns_match()
+ assert compare_tol.all_rows_overlap()
+ assert compare_tol.intersect_rows_match()
+
+ compare_no_tol = SparkCompare(df1, df2, ["a"])
+ assert not compare_no_tol.matches()
+ assert len(compare_no_tol.df1_unq_rows) == 0
+ assert len(compare_no_tol.df2_unq_rows) == 0
+ assert compare_no_tol.intersect_columns() == {"a", "b", "c"}
+ assert compare_no_tol.all_columns_match()
+ assert compare_no_tol.all_rows_overlap()
+ assert not compare_no_tol.intersect_rows_match()
+
+
+@pandas_version
+def test_subset(caplog):
+ caplog.set_level(logging.DEBUG)
+ df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}])
+ df2 = ps.DataFrame([{"a": 1, "c": "hi"}])
+ comp = SparkCompare(df1, df2, ["a"])
+ assert comp.subset()
+ assert "Checking equality" in caplog.text
+
+
+@pandas_version
+def test_not_subset(caplog):
+ caplog.set_level(logging.INFO)
+ df1 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "yo"}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2, "c": "hi"}, {"a": 2, "b": 2, "c": "great"}])
+ comp = SparkCompare(df1, df2, ["a"])
+ assert not comp.subset()
+ assert "c: 1 / 2 (50.00%) match" in caplog.text
+
+
+@pandas_version
+def test_large_subset():
+ df1 = ps.DataFrame(np.random.randint(0, 100, size=(10000, 2)), columns=["b", "c"])
+ df1.reset_index(inplace=True)
+ df1.columns = ["a", "b", "c"]
+ df2 = df1[["a", "b"]].head(50).copy()
+ comp = SparkCompare(df1, df2, ["a"])
+ assert not comp.matches()
+ assert comp.subset()
+
+
+@pandas_version
+def test_string_joiner():
+ df1 = ps.DataFrame([{"ab": 1, "bc": 2}, {"ab": 2, "bc": 2}])
+ df2 = ps.DataFrame([{"ab": 1, "bc": 2}, {"ab": 2, "bc": 2}])
+ compare = SparkCompare(df1, df2, "ab")
+ assert compare.matches()
+
+
+@pandas_version
+def test_decimal_with_joins():
+ df1 = ps.DataFrame([{"a": Decimal("1"), "b": 2}, {"a": Decimal("2"), "b": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}])
+ compare = SparkCompare(df1, df2, "a")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_decimal_with_nulls():
+ df1 = ps.DataFrame([{"a": 1, "b": Decimal("2")}, {"a": 2, "b": Decimal("2")}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 2}, {"a": 3, "b": 2}])
+ compare = SparkCompare(df1, df2, "a")
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_strings_with_joins():
+ df1 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}])
+ df2 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}])
+ compare = SparkCompare(df1, df2, "a")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
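+# temp_column_name should return the first "_temp_N" name not already present as a column in either frame.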
+@pandas_version
+def test_temp_column_name():
+ df1 = ps.DataFrame([{"a": "hi", "b": 2}, {"a": "bye", "b": 2}])
+ df2 = ps.DataFrame(
+ [{"a": "hi", "b": 2}, {"a": "bye", "b": 2}, {"a": "back fo mo", "b": 3}]
)
+ actual = temp_column_name(df1, df2)
+ assert actual == "_temp_0"
-def test_rows_only_compare_returns_a_dataframe_with_rows_only_in_compare(
- spark, comparison1
-):
- expected_df = spark.createDataFrame(
- [
- Row(
- acct=10000001238,
- dollar_amt=1.05,
- name="Loose Seal Bluth",
- float_fld=111.0,
- accnt_purge=True,
- )
- ]
+@pandas_version
+def test_temp_column_name_one_has():
+ df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}])
+ df2 = ps.DataFrame(
+ [{"a": "hi", "b": 2}, {"a": "bye", "b": 2}, {"a": "back fo mo", "b": 3}]
)
-
- assert comparison1.rows_only_compare.count() == 1
- assert expected_df.union(comparison1.rows_only_compare).distinct().count() == 1
+ actual = temp_column_name(df1, df2)
+ assert actual == "_temp_1"
-def test_rows_both_mismatch_returns_a_dataframe_with_rows_where_variables_mismatched(
- spark, comparison1
-):
- expected_df = spark.createDataFrame(
+@pandas_version
+def test_temp_column_name_both_have_temp_1():
+ df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}])
+ df2 = ps.DataFrame(
[
- Row(
- accnt_purge=False,
- acct=10000001234,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=123,
- dollar_amt_compare=123.4,
- dollar_amt_match=False,
- float_fld_base=14530.1555,
- float_fld_compare=14530.155,
- float_fld_match=False,
- name_base="George Maharis",
- name_compare="George Michael Bluth",
- name_match=False,
- ),
- Row(
- accnt_purge=False,
- acct=10000001235,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=0,
- dollar_amt_compare=0.45,
- dollar_amt_match=False,
- float_fld_base=1.0,
- float_fld_compare=None,
- float_fld_match=False,
- name_base="Michael Bluth",
- name_compare="Michael Bluth",
- name_match=True,
- ),
- Row(
- accnt_purge=False,
- acct=10000001236,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=1345,
- dollar_amt_compare=1345.0,
- dollar_amt_match=True,
- float_fld_base=None,
- float_fld_compare=1.0,
- float_fld_match=False,
- name_base="George Bluth",
- name_compare="George Bluth",
- name_match=True,
- ),
+ {"_temp_0": "hi", "b": 2},
+ {"_temp_0": "bye", "b": 2},
+ {"a": "back fo mo", "b": 3},
]
)
-
- assert comparison1.rows_both_mismatch.count() == 3
- assert expected_df.union(comparison1.rows_both_mismatch).distinct().count() == 3
+ actual = temp_column_name(df1, df2)
+ assert actual == "_temp_1"
-def test_rows_both_mismatch_only_includes_rows_with_true_mismatches_when_known_diffs_are_present(
- spark, comparison_kd1
-):
- expected_df = spark.createDataFrame(
+@pandas_version
+def test_temp_column_name_both_have_temp_2():
+ df1 = ps.DataFrame([{"_temp_0": "hi", "b": 2}, {"_temp_0": "bye", "b": 2}])
+ df2 = ps.DataFrame(
[
- Row(
- acct=10000001237,
- acct_seq=0,
- cd_base="0004",
- cd_compare=4.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 4),
- open_dt_compare=2017124,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="*2",
- stat_cd_compare="V3",
- stat_cd_match=False,
- stat_cd_match_type="MISMATCH",
- )
+ {"_temp_0": "hi", "b": 2},
+ {"_temp_1": "bye", "b": 2},
+ {"a": "back fo mo", "b": 3},
]
)
- assert comparison_kd1.rows_both_mismatch.count() == 1
- assert expected_df.union(comparison_kd1.rows_both_mismatch).distinct().count() == 1
+ actual = temp_column_name(df1, df2)
+ assert actual == "_temp_2"
-def test_rows_both_all_returns_a_dataframe_with_all_rows_in_both_dataframes(
- spark, comparison1
-):
- expected_df = spark.createDataFrame(
+@pandas_version
+def test_temp_column_name_one_already():
+ df1 = ps.DataFrame([{"_temp_1": "hi", "b": 2}, {"_temp_1": "bye", "b": 2}])
+ df2 = ps.DataFrame(
[
- Row(
- accnt_purge=False,
- acct=10000001234,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=123,
- dollar_amt_compare=123.4,
- dollar_amt_match=False,
- float_fld_base=14530.1555,
- float_fld_compare=14530.155,
- float_fld_match=False,
- name_base="George Maharis",
- name_compare="George Michael Bluth",
- name_match=False,
- ),
- Row(
- accnt_purge=False,
- acct=10000001235,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=0,
- dollar_amt_compare=0.45,
- dollar_amt_match=False,
- float_fld_base=1.0,
- float_fld_compare=None,
- float_fld_match=False,
- name_base="Michael Bluth",
- name_compare="Michael Bluth",
- name_match=True,
- ),
- Row(
- accnt_purge=False,
- acct=10000001236,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=1345,
- dollar_amt_compare=1345.0,
- dollar_amt_match=True,
- float_fld_base=None,
- float_fld_compare=1.0,
- float_fld_match=False,
- name_base="George Bluth",
- name_compare="George Bluth",
- name_match=True,
- ),
- Row(
- accnt_purge=False,
- acct=10000001237,
- date_fld=datetime.date(2017, 1, 1),
- dollar_amt_base=123456,
- dollar_amt_compare=123456.0,
- dollar_amt_match=True,
- float_fld_base=345.12,
- float_fld_compare=345.12,
- float_fld_match=True,
- name_base="Bob Loblaw",
- name_compare="Bob Loblaw",
- name_match=True,
- ),
+ {"_temp_1": "hi", "b": 2},
+ {"_temp_1": "bye", "b": 2},
+ {"a": "back fo mo", "b": 3},
]
)
+ actual = temp_column_name(df1, df2)
+ assert actual == "_temp_0"
- assert comparison1.rows_both_all.count() == 4
- assert expected_df.union(comparison1.rows_both_all).distinct().count() == 4
+### Duplicate testing!
+@pandas_version
+def test_simple_dupes_one_field():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ t = compare.report()
-def test_rows_both_all_shows_known_diffs_flag_and_known_diffs_count_as_matches(
- spark, comparison_kd1
-):
- expected_df = spark.createDataFrame(
- [
- Row(
- acct=10000001234,
- acct_seq=0,
- cd_base="0001",
- cd_compare=1.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 1),
- open_dt_compare=2017121,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="*2",
- stat_cd_compare=None,
- stat_cd_match=True,
- stat_cd_match_type="KNOWN_DIFFERENCE",
- ),
- Row(
- acct=10000001235,
- acct_seq=0,
- cd_base="0002",
- cd_compare=2.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 2),
- open_dt_compare=2017122,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="V1",
- stat_cd_compare="V1",
- stat_cd_match=True,
- stat_cd_match_type="MATCH",
- ),
- Row(
- acct=10000001236,
- acct_seq=0,
- cd_base="0003",
- cd_compare=3.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 3),
- open_dt_compare=2017123,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="V2",
- stat_cd_compare="V2",
- stat_cd_match=True,
- stat_cd_match_type="MATCH",
- ),
- Row(
- acct=10000001237,
- acct_seq=0,
- cd_base="0004",
- cd_compare=4.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 4),
- open_dt_compare=2017124,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="*2",
- stat_cd_compare="V3",
- stat_cd_match=False,
- stat_cd_match_type="MISMATCH",
- ),
- Row(
- acct=10000001238,
- acct_seq=0,
- cd_base="0005",
- cd_compare=5.0,
- cd_match=True,
- cd_match_type="KNOWN_DIFFERENCE",
- open_dt_base=datetime.date(2017, 5, 5),
- open_dt_compare=2017125,
- open_dt_match=True,
- open_dt_match_type="KNOWN_DIFFERENCE",
- stat_cd_base="*2",
- stat_cd_compare=None,
- stat_cd_match=True,
- stat_cd_match_type="KNOWN_DIFFERENCE",
- ),
- ]
- )
- assert comparison_kd1.rows_both_all.count() == 5
- assert expected_df.union(comparison_kd1.rows_both_all).distinct().count() == 5
+@pandas_version
+def test_simple_dupes_two_fields():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 2}])
+ compare = SparkCompare(df1, df2, join_columns=["a", "b"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ t = compare.report()
-def test_rows_both_all_returns_a_dataframe_with_all_rows_in_identical_dataframes(
- spark, comparison2
-):
- expected_df = spark.createDataFrame(
- [
- Row(
- acct=10000001234,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=123,
- dollar_amt_compare=123,
- dollar_amt_match=True,
- float_fld_base=14530.1555,
- float_fld_compare=14530.1555,
- float_fld_match=True,
- name_base="George Maharis",
- name_compare="George Maharis",
- name_match=True,
- ),
- Row(
- acct=10000001235,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=0,
- dollar_amt_compare=0,
- dollar_amt_match=True,
- float_fld_base=1.0,
- float_fld_compare=1.0,
- float_fld_match=True,
- name_base="Michael Bluth",
- name_compare="Michael Bluth",
- name_match=True,
- ),
- Row(
- acct=10000001236,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=1345,
- dollar_amt_compare=1345,
- dollar_amt_match=True,
- float_fld_base=None,
- float_fld_compare=None,
- float_fld_match=True,
- name_base="George Bluth",
- name_compare="George Bluth",
- name_match=True,
- ),
- Row(
- acct=10000001237,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=123456,
- dollar_amt_compare=123456,
- dollar_amt_match=True,
- float_fld_base=345.12,
- float_fld_compare=345.12,
- float_fld_match=True,
- name_base="Bob Loblaw",
- name_compare="Bob Loblaw",
- name_match=True,
- ),
- Row(
- acct=10000001239,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=1,
- dollar_amt_compare=1,
- dollar_amt_match=True,
- float_fld_base=None,
- float_fld_compare=None,
- float_fld_match=True,
- name_base="Lucille Bluth",
- name_compare="Lucille Bluth",
- name_match=True,
- ),
- ]
+@pandas_version
+def test_simple_dupes_one_field_two_vals_1():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ t = compare.report()
+
+
+@pandas_version
+def test_simple_dupes_one_field_two_vals_2():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 0}])
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert not compare.matches()
+ assert len(compare.df1_unq_rows) == 1
+ assert len(compare.df2_unq_rows) == 1
+ assert len(compare.intersect_rows) == 1
+ # Just render the report to make sure it renders.
+ t = compare.report()
+
+
+@pandas_version
+def test_simple_dupes_one_field_three_to_two_vals():
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}, {"a": 1, "b": 0}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 0}])
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert not compare.matches()
+ assert len(compare.df1_unq_rows) == 1
+ assert len(compare.df2_unq_rows) == 0
+ assert len(compare.intersect_rows) == 2
+ # Just render the report to make sure it renders.
+ t = compare.report()
+
+ assert "(First 1 Columns)" in compare.report(column_count=1)
+ assert "(First 2 Columns)" in compare.report(column_count=2)
+
+
+@pandas_version
+def test_dupes_from_real_data():
+ data = """acct_id,acct_sfx_num,trxn_post_dt,trxn_post_seq_num,trxn_amt,trxn_dt,debit_cr_cd,cash_adv_trxn_comn_cntry_cd,mrch_catg_cd,mrch_pstl_cd,visa_mail_phn_cd,visa_rqstd_pmt_svc_cd,mc_pmt_facilitator_idn_num
+100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0
+200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1,
+100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A,
+100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0
+200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,,
+100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0
+200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0
+200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,,
+100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0
+200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F,
+100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0
+200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F,
+100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0
+200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A,
+100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0
+200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,"""
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data), sep=","))
+ df2 = df1.copy()
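+    # "acct_id" alone is duplicated in this data, so the first compare exercises duplicate-key joins;
+    # the four-column key below is unique.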
+ compare_acct = SparkCompare(df1, df2, join_columns=["acct_id"])
+ assert compare_acct.matches()
+ compare_unq = SparkCompare(
+ df1,
+ df2,
+ join_columns=["acct_id", "acct_sfx_num", "trxn_post_dt", "trxn_post_seq_num"],
)
+ assert compare_unq.matches()
+ # Just render the report to make sure it renders.
+ t = compare_acct.report()
+ r = compare_unq.report()
+
+
+@pandas_version
+def test_strings_with_joins_with_ignore_spaces():
+ df1 = ps.DataFrame([{"a": "hi", "b": " A"}, {"a": "bye", "b": "A"}])
+ df2 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "A "}])
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_strings_with_joins_with_ignore_case():
+ df1 = ps.DataFrame([{"a": "hi", "b": "a"}, {"a": "bye", "b": "A"}])
+ df2 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "a"}])
+ compare = SparkCompare(df1, df2, "a", ignore_case=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SparkCompare(df1, df2, "a", ignore_case=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_decimal_with_joins_with_ignore_spaces():
+ df1 = ps.DataFrame([{"a": 1, "b": " A"}, {"a": 2, "b": "A"}])
+ df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A "}])
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_decimal_with_joins_with_ignore_case():
+ df1 = ps.DataFrame([{"a": 1, "b": "a"}, {"a": 2, "b": "A"}])
+ df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "a"}])
+ compare = SparkCompare(df1, df2, "a", ignore_case=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SparkCompare(df1, df2, "a", ignore_case=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_joins_with_ignore_spaces():
+ df1 = ps.DataFrame([{"a": 1, "b": " A"}, {"a": 2, "b": "A"}])
+ df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A "}])
+
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_joins_with_ignore_case():
+ df1 = ps.DataFrame([{"a": 1, "b": "a"}, {"a": 2, "b": "A"}])
+ df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "a"}])
+
+ compare = SparkCompare(df1, df2, "a", ignore_case=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+@pandas_version
+def test_strings_with_ignore_spaces_and_join_columns():
+ df1 = ps.DataFrame([{"a": "hi", "b": "A"}, {"a": "bye", "b": "A"}])
+ df2 = ps.DataFrame([{"a": " hi ", "b": "A"}, {"a": " bye ", "b": "A"}])
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.count_matching_rows() == 0
+
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
+@pandas_version
+def test_integers_with_ignore_spaces_and_join_columns():
+ df1 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A"}])
+ df2 = ps.DataFrame([{"a": 1, "b": "A"}, {"a": 2, "b": "A"}])
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=False)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+ compare = SparkCompare(df1, df2, "a", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
+@pandas_version
+def test_sample_mismatch():
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
- assert comparison2.rows_both_all.count() == 5
- assert expected_df.union(comparison2.rows_both_all).distinct().count() == 5
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
-def test_rows_both_all_returns_all_rows_in_both_dataframes_for_differently_named_columns(
- spark, comparison3
-):
- expected_df = spark.createDataFrame(
- [
- Row(
- accnt_purge=False,
- acct=10000001234,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=123,
- dollar_amt_compare=123.4,
- dollar_amt_match=False,
- float_fld_base=14530.1555,
- float_fld_compare=14530.155,
- float_fld_match=False,
- name_base="George Maharis",
- name_compare="George Michael Bluth",
- name_match=False,
- ),
- Row(
- accnt_purge=False,
- acct=10000001235,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=0,
- dollar_amt_compare=0.45,
- dollar_amt_match=False,
- float_fld_base=1.0,
- float_fld_compare=1.0,
- float_fld_match=True,
- name_base="Michael Bluth",
- name_compare="Michael Bluth",
- name_match=True,
- ),
- Row(
- accnt_purge=False,
- acct=10000001236,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=1345,
- dollar_amt_compare=1345.0,
- dollar_amt_match=True,
- float_fld_base=None,
- float_fld_compare=None,
- float_fld_match=True,
- name_base="George Bluth",
- name_compare="George Bluth",
- name_match=True,
- ),
- Row(
- accnt_purge=False,
- acct=10000001237,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=123456,
- dollar_amt_compare=123456.0,
- dollar_amt_match=True,
- float_fld_base=345.12,
- float_fld_compare=345.12,
- float_fld_match=True,
- name_base="Bob Loblaw",
- name_compare="Bob Loblaw",
- name_match=True,
- ),
- Row(
- accnt_purge=True,
- acct=10000001239,
- date_fld_base=datetime.date(2017, 1, 1),
- date_fld_compare=datetime.date(2017, 1, 1),
- date_fld_match=True,
- dollar_amt_base=1,
- dollar_amt_compare=1.05,
- dollar_amt_match=False,
- float_fld_base=None,
- float_fld_compare=None,
- float_fld_match=True,
- name_base="Lucille Bluth",
- name_compare="Lucille Bluth",
- name_match=True,
- ),
- ]
- )
+ compare = SparkCompare(df1, df2, "acct_id")
- assert comparison3.rows_both_all.count() == 5
- assert expected_df.union(comparison3.rows_both_all).distinct().count() == 5
-
-
-def test_columns_with_unequal_values_text_is_aligned(comparison4):
- stdout = io.StringIO()
-
- comparison4.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
-
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns with Unequal Values ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(5, 6),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
- (\#\sMatches) \s+ (\#\sMismatches)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double) \s+ (2) \s+ (2)""",
- r"""(float_fld) \s+ (float_fld) \s+ (double) \s+ (double) \s+ (1) \s+ (3)""",
- r"""(super_duper_big_long_name) \s+ (name) \s+ (string) \s+ (string) \s+ (3) \s+ (1)\s*""",
- ],
- )
+ output = compare.sample_mismatch(column="name", sample_count=1)
+ assert output.shape[0] == 1
+ assert (output.name_df1 != output.name_df2).all()
+ output = compare.sample_mismatch(column="name", sample_count=2)
+ assert output.shape[0] == 2
+ assert (output.name_df1 != output.name_df2).all()
-def test_columns_with_unequal_values_text_is_aligned_with_known_differences(
- comparison_kd1,
-):
- stdout = io.StringIO()
-
- comparison_kd1.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
-
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns with Unequal Values ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(5, 6, 7),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
- (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""",
- r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (5) \s+ (0)""",
- r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""",
- ],
- )
+ output = compare.sample_mismatch(column="name", sample_count=3)
+ assert output.shape[0] == 2
+ assert (output.name_df1 != output.name_df2).all()
-def test_columns_with_unequal_values_text_is_aligned_with_custom_known_differences(
- comparison_kd2,
-):
- stdout = io.StringIO()
-
- comparison_kd2.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
-
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns with Unequal Values ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(5, 6, 7),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
- (\#\sMatches) \s+ (\#\sKnown\sDiffs) \s+ (\#\sMismatches)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(stat_cd) \s+ (STATC) \s+ (string) \s+ (string) \s+ (2) \s+ (2) \s+ (1)""",
- r"""(open_dt) \s+ (ACCOUNT_OPEN) \s+ (date) \s+ (bigint) \s+ (0) \s+ (0) \s+ (5)""",
- r"""(cd) \s+ (CODE) \s+ (string) \s+ (double) \s+ (0) \s+ (5) \s+ (0)\s*""",
- ],
- )
+@pandas_version
+def test_all_mismatch_not_ignore_matching_cols_no_cols_matching():
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
+ compare = SparkCompare(df1, df2, "acct_id")
+
+ output = compare.all_mismatch()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 10
+
+ assert (output.name_df1 != output.name_df2).values.sum() == 2
+ assert (~(output.name_df1 != output.name_df2)).values.sum() == 2
+
+ assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 1
+ assert (~(output.dollar_amt_df1 != output.dollar_amt_df2)).values.sum() == 3
+
+ assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3
+ assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1
+
+ assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4
+ assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0
+
+
+@pandas_version
+def test_all_mismatch_not_ignore_matching_cols_some_cols_matching():
+ # Columns dollar_amt and name are matching
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
+ compare = SparkCompare(df1, df2, "acct_id")
+
+ output = compare.all_mismatch()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 10
+
+ assert (output.name_df1 != output.name_df2).values.sum() == 0
+ assert (~(output.name_df1 != output.name_df2)).values.sum() == 4
+
+ assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 0
+ assert (~(output.dollar_amt_df1 != output.dollar_amt_df2)).values.sum() == 4
+
+ assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3
+ assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1
+
+ assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4
+ assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0
+
+
+@pandas_version
+def test_all_mismatch_ignore_matching_cols_some_cols_matching_diff_rows():
+ # Case where there are rows on either dataset which don't match up.
+ # Columns dollar_amt and name are matching
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ 10000001241,1111.05,Lucille Bluth,
+ """
-def test_columns_with_unequal_values_text_is_aligned_for_decimals(comparison_decimal):
- stdout = io.StringIO()
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
+ compare = SparkCompare(df1, df2, "acct_id")
- comparison_decimal.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
+ output = compare.all_mismatch(ignore_matching_cols=True)
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns with Unequal Values ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(5, 6),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype) \s+
- (\#\sMatches) \s+ (\#\sMismatches)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double) \s+ (1) \s+ (1)""",
- ],
- )
+ assert output.shape[0] == 4
+ assert output.shape[1] == 6
+ assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3
+ assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1
-def test_schema_differences_text_is_aligned(comparison4):
- stdout = io.StringIO()
+ assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4
+ assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0
- comparison4.report(file=stdout)
- comparison4.report()
- stdout.seek(0) # Back up to the beginning of the stream
+ assert not ("name_df1" in output and "name_df2" in output)
+ assert not ("dollar_amt_df1" in output and "dollar_amt_df1" in output)
- text_alignment_validator(
- report=stdout,
- section_start="****** Schema Differences ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(dollar_amt) \s+ (dollar_amt) \s+ (bigint) \s+ (double)""",
- ],
- )
+@pandas_version
+def test_all_mismatch_ignore_matching_cols_some_calls_matching():
+ # Columns dollar_amt and name are matching
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
-def test_schema_differences_text_is_aligned_for_decimals(comparison_decimal):
- stdout = io.StringIO()
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
+ compare = SparkCompare(df1, df2, "acct_id")
- comparison_decimal.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
+ output = compare.all_mismatch(ignore_matching_cols=True)
- text_alignment_validator(
- report=stdout,
- section_start="****** Schema Differences ******",
- section_end="\n",
- left_indices=(1, 2, 3, 4),
- right_indices=(),
- column_regexes=[
- r"""(Base\sColumn\sName) \s+ (Compare\sColumn\sName) \s+ (Base\sDtype) \s+ (Compare\sDtype)""",
- r"""(-+) \s+ (-+) \s+ (-+) \s+ (-+)""",
- r"""(dollar_amt) \s+ (dollar_amt) \s+ (decimal\(8,2\)) \s+ (double)""",
- ],
- )
+ assert output.shape[0] == 4
+ assert output.shape[1] == 6
+ assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3
+ assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1
-def test_base_only_columns_text_is_aligned(comparison4):
- stdout = io.StringIO()
+ assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4
+ assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0
- comparison4.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
+ assert not ("name_df1" in output and "name_df2" in output)
+ assert not ("dollar_amt_df1" in output and "dollar_amt_df1" in output)
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns In Base Only ******",
- section_end="\n",
- left_indices=(1, 2),
- right_indices=(),
- column_regexes=[
- r"""(Column\sName) \s+ (Dtype)""",
- r"""(-+) \s+ (-+)""",
- r"""(date_fld) \s+ (date)""",
- ],
- )
+@pandas_version
+def test_all_mismatch_ignore_matching_cols_no_cols_matching():
+ data1 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
-def test_compare_only_columns_text_is_aligned(comparison4):
- stdout = io.StringIO()
+ data2 = """acct_id,dollar_amt,name,float_fld,date_fld
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = ps.from_pandas(pd.read_csv(StringIO(data1), sep=","))
+ df2 = ps.from_pandas(pd.read_csv(StringIO(data2), sep=","))
+ compare = SparkCompare(df1, df2, "acct_id")
+
+ output = compare.all_mismatch()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 10
+
+ assert (output.name_df1 != output.name_df2).values.sum() == 2
+ assert (~(output.name_df1 != output.name_df2)).values.sum() == 2
+
+ assert (output.dollar_amt_df1 != output.dollar_amt_df2).values.sum() == 1
+ assert (~(output.dollar_amt_df1 != output.dollar_amt_df2)).values.sum() == 3
+
+ assert (output.float_fld_df1 != output.float_fld_df2).values.sum() == 3
+ assert (~(output.float_fld_df1 != output.float_fld_df2)).values.sum() == 1
+
+ assert (output.date_fld_df1 != output.date_fld_df2).values.sum() == 4
+ assert (~(output.date_fld_df1 != output.date_fld_df2)).values.sum() == 0
+
+
+@pandas_version
+@pytest.mark.parametrize(
+ "column,expected",
+ [
+ ("base", 0),
+ ("floats", 0.2),
+ ("decimals", 0.1),
+ ("null_floats", 0.1),
+ ("strings", 0.1),
+ ("mixed_strings", 1),
+ ("infinity", np.inf),
+ ],
+)
+def test_calculate_max_diff(column, expected):
+ MAX_DIFF_DF = ps.DataFrame(
+ {
+ "base": [1, 1, 1, 1, 1],
+ "floats": [1.1, 1.1, 1.1, 1.2, 0.9],
+ "decimals": [
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ ],
+ "null_floats": [np.nan, 1.1, 1, 1, 1],
+ "strings": ["1", "1", "1", "1.1", "1"],
+ "mixed_strings": ["1", "1", "1", "2", "some string"],
+ "infinity": [1, 1, 1, 1, np.inf],
+ }
+ )
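+    # calculate_max_diff should return the largest absolute difference from "base" for each parametrized column.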
+ assert np.isclose(
+ calculate_max_diff(MAX_DIFF_DF["base"], MAX_DIFF_DF[column]), expected
+ )
- comparison4.report(file=stdout)
- stdout.seek(0) # Back up to the beginning of the stream
- text_alignment_validator(
- report=stdout,
- section_start="****** Columns In Compare Only ******",
- section_end="\n",
- left_indices=(1, 2),
- right_indices=(),
- column_regexes=[
- r"""(Column\sName) \s+ (Dtype)""",
- r"""(-+) \s+ (-+)""",
- r"""(accnt_purge) \s+ (boolean)""",
- ],
+@pandas_version
+def test_dupes_with_nulls_strings():
+ df1 = ps.DataFrame(
+ {
+ "fld_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "fld_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "fld_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
)
+ df2 = ps.DataFrame(
+ {
+ "fld_1": [1, 2, 3, 4, 5],
+ "fld_2": ["A", np.nan, np.nan, np.nan, np.nan],
+ "fld_3": [1, 2, 3, 4, 5],
+ }
+ )
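+    # Despite the duplicated rows and null join keys in df1, df2 should still be reported as a subset.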
+ comp = SparkCompare(df1, df2, join_columns=["fld_1", "fld_2"])
+ assert comp.subset()
+
+
+@pandas_version
+def test_dupes_with_nulls_ints():
+ df1 = ps.DataFrame(
+ {
+ "fld_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "fld_2": [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "fld_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
+ )
+ df2 = ps.DataFrame(
+ {
+ "fld_1": [1, 2, 3, 4, 5],
+ "fld_2": [1, np.nan, np.nan, np.nan, np.nan],
+ "fld_3": [1, 2, 3, 4, 5],
+ }
+ )
+ comp = SparkCompare(df1, df2, join_columns=["fld_1", "fld_2"])
+ assert comp.subset()
+
+
+@pandas_version
+@pytest.mark.parametrize(
+ "dataframe,expected",
+ [
+ (ps.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), ps.Series([0, 0, 0])),
+ (
+ ps.DataFrame({"a": ["a", "a", "DATACOMPY_NULL"], "b": [1, 1, 2]}),
+ ps.Series([0, 1, 0]),
+ ),
+ (ps.DataFrame({"a": [-999, 2, 3], "b": [1, 2, 3]}), ps.Series([0, 0, 0])),
+ (
+ ps.DataFrame({"a": [1, np.nan, np.nan], "b": [1, 2, 2]}),
+ ps.Series([0, 0, 1]),
+ ),
+ (
+ ps.DataFrame({"a": ["1", np.nan, np.nan], "b": ["1", "2", "2"]}),
+ ps.Series([0, 0, 1]),
+ ),
+ (
+ ps.DataFrame(
+ {"a": [datetime(2018, 1, 1), np.nan, np.nan], "b": ["1", "2", "2"]}
+ ),
+ ps.Series([0, 0, 1]),
+ ),
+ ],
+)
+def test_generate_id_within_group(dataframe, expected):
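+    # The expected series is the 0-based occurrence index of each row within its ("a", "b") duplicate group.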
+ assert (generate_id_within_group(dataframe, ["a", "b"]) == expected).all()
+
+
+@pandas_version
+def test_lower():
+ """This function tests the toggle to use lower case for column names or not"""
+ # should match
+ df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]})
+ df2 = ps.DataFrame({"a": [1, 2, 3], "B": [0, 1, 2]})
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert compare.matches()
+ # should not match
+ df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]})
+ df2 = ps.DataFrame({"a": [1, 2, 3], "B": [0, 1, 2]})
+ compare = SparkCompare(df1, df2, join_columns=["a"], cast_column_names_lower=False)
+ assert not compare.matches()
+
+ # test join column
+ # should match
+ df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]})
+ df2 = ps.DataFrame({"A": [1, 2, 3], "B": [0, 1, 2]})
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+ assert compare.matches()
+ # should fail because "a" is not found in df2
+ df1 = ps.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]})
+ df2 = ps.DataFrame({"A": [1, 2, 3], "B": [0, 1, 2]})
+ expected_message = "df2 must have all columns from join_columns"
+ with raises(ValueError, match=expected_message):
+ compare = SparkCompare(
+ df1, df2, join_columns=["a"], cast_column_names_lower=False
+ )
-def text_alignment_validator(
- report, section_start, section_end, left_indices, right_indices, column_regexes
-):
- r"""Check to make sure that report output columns are vertically aligned.
-
- Parameters
- ----------
- report: An iterable returning lines of report output to be validated.
- section_start: A string that represents the beginning of the section to be validated.
- section_end: A string that represents the end of the section to be validated.
- left_indices: The match group indexes (starting with 1) that should be left-aligned
- in the output column.
- right_indices: The match group indexes (starting with 1) that should be right-aligned
- in the output column.
- column_regexes: A list of regular expressions representing the expected output, with
- each column enclosed with parentheses to return a match. The regular expression will
- use the "X" flag, so it may contain whitespace, and any whitespace to be matched
- should be explicitly given with \s. The first line will represent the alignments
- that are expected in the following lines. The number of match groups should cover
- all of the indices given in left/right_indices.
-
- Runs assertions for every match group specified by left/right_indices to ensure that
- all lines past the first are either left- or right-aligned with the same match group
- on the first line.
- """
-
- at_column_section = False
- processed_first_line = False
- match_positions = [None] * (len(left_indices + right_indices) + 1)
-
- for line in report:
- if at_column_section:
- if line == section_end: # Detect end of section and stop
- break
-
- if (
- not processed_first_line
- ): # First line in section - capture text start/end positions
- matches = re.search(column_regexes[0], line, re.X)
- assert matches is not None # Make sure we found at least this...
-
- for n in left_indices:
- match_positions[n] = matches.start(n)
- for n in right_indices:
- match_positions[n] = matches.end(n)
- processed_first_line = True
- else: # Match the stuff after the header text
- match = None
- for regex in column_regexes[1:]:
- match = re.search(regex, line, re.X)
- if match:
- break
-
- if not match:
- raise AssertionError(f'Did not find a match for line: "{line}"')
-
- for n in left_indices:
- assert match_positions[n] == match.start(n)
- for n in right_indices:
- assert match_positions[n] == match.end(n)
-
- if not at_column_section and section_start in line:
- at_column_section = True
-
-
-def test_unicode_columns(spark_session):
- df1 = spark_session.createDataFrame(
- [
- (1, "foo", "test"),
- (2, "bar", "test"),
- ],
- ["id", "例", "予測対象日"],
+@pandas_version
+def test_integer_column_names():
+ """This function tests that integer column names would also work"""
+ df1 = ps.DataFrame({1: [1, 2, 3], 2: [0, 1, 2]})
+ df2 = ps.DataFrame({1: [1, 2, 3], 2: [0, 1, 2]})
+ compare = SparkCompare(df1, df2, join_columns=[1])
+ assert compare.matches()
+
+
+@pandas_version
+@mock.patch("datacompy.spark.render")
+def test_save_html(mock_render):
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
+ compare = SparkCompare(df1, df2, join_columns=["a"])
+
+ m = mock.mock_open()
+ with mock.patch("datacompy.spark.open", m, create=True):
+ # assert without HTML call
+ compare.report()
+ assert mock_render.call_count == 4
+ m.assert_not_called()
+
+ mock_render.reset_mock()
+ m = mock.mock_open()
+ with mock.patch("datacompy.spark.open", m, create=True):
+ # assert with HTML call
+ compare.report(html_file="test.html")
+ assert mock_render.call_count == 4
+ m.assert_called_with("test.html", "w")
+
+
+def test_pandas_version():
+ expected_message = "It seems like you are running Pandas 2+. Please note that Pandas 2+ will only be supported in Spark 4+. See: https://issues.apache.org/jira/browse/SPARK-44101. If you need to use Spark DataFrame with Pandas 2+ then consider using Fugue otherwise downgrade to Pandas 1.5.3"
+ df1 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
+ df2 = ps.DataFrame([{"a": 1, "b": 2}, {"a": 1, "b": 2}])
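+    # Patch pandas.__version__ so both branches run regardless of the installed pandas version.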
+ with mock.patch("pandas.__version__", "2.0.0"):
+ with raises(Exception, match=re.escape(expected_message)):
+ SparkCompare(df1, df2, join_columns=["a"])
+
+ with mock.patch("pandas.__version__", "1.5.3"):
+ SparkCompare(df1, df2, join_columns=["a"])
+
+
+@pandas_version
+def test_unicode_columns():
+ df1 = ps.DataFrame(
+ [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}]
)
- df2 = spark_session.createDataFrame(
- [
- (1, "foo", "test"),
- (2, "baz", "test"),
- ],
- ["id", "例", "予測対象日"],
+ df2 = ps.DataFrame(
+ [{"a": 1, "例": 2, "予測対象日": "test"}, {"a": 1, "例": 3, "予測対象日": "test"}]
)
- compare = SparkCompare(spark_session, df1, df2, join_columns=["例"])
+ compare = SparkCompare(df1, df2, join_columns=["例"])
+ assert compare.matches()
# Just render the report to make sure it renders.
- compare.report()
+ t = compare.report()