diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index 472444d..1d29ef2 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -1,7 +1,7 @@ from chispa.prettytable import PrettyTable from chispa.bcolors import * import chispa.six as six -from chispa.structfield_comparer import are_structfields_equal +from chispa.structfield_comparer import are_structfields_equal, check_type_equal_ignore_nullable class SchemasNotEqualError(Exception): @@ -52,17 +52,4 @@ def are_schemas_equal_ignore_nullable(s1, s2): return True -def check_type_equal_ignore_nullable(sf1, sf2): - """Checks StructField data types ignoring nullables. - Handles array element types also. - """ - dt1, dt2 = sf1.dataType, sf2.dataType - if dt1.typeName() == dt2.typeName(): - # Account for array types by inspecting elementType. - if dt1.typeName() == 'array': - return dt1.elementType == dt2.elementType - else: - return True - else: - return False diff --git a/chispa/structfield_comparer.py b/chispa/structfield_comparer.py index b76adea..c5f340f 100644 --- a/chispa/structfield_comparer.py +++ b/chispa/structfield_comparer.py @@ -1,10 +1,27 @@ + +def check_type_equal_ignore_nullable(sf1, sf2): + """Checks StructField data types ignoring nullables. + + Handles array element types also. + """ + dt1, dt2 = sf1.dataType, sf2.dataType + if dt1.typeName() == dt2.typeName(): + # Account for array types by inspecting elementType. + if dt1.typeName() == 'array': + return dt1.elementType == dt2.elementType + else: + return True + else: + return False + + def are_structfields_equal(sf1, sf2, ignore_nullability=False): if ignore_nullability: if sf1 is None and sf2 is not None: return False elif sf1 is not None and sf2 is None: return False - elif sf1.name != sf2.name or sf1.dataType != sf2.dataType: + elif sf1.name != sf2.name or not check_type_equal_ignore_nullable(sf1, sf2): return False else: return True diff --git a/tests/test_structfield_comparer.py b/tests/test_structfield_comparer.py index 920ac72..a53b728 100644 --- a/tests/test_structfield_comparer.py +++ b/tests/test_structfield_comparer.py @@ -24,3 +24,8 @@ def it_can_perform_nullability_insensitive_comparisons(): sf1 = StructField("hi", IntegerType(), False) sf2 = StructField("hi", IntegerType(), True) assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True + + def it_can_perform_nullability_insensitive_comparisons_with_arrays(): + sf1 = StructField("hi", ArrayType(IntegerType(), True), False) + sf2 = StructField("hi", ArrayType(IntegerType(), False), True) + assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True