Skip to content

Commit

Permalink
Merge pull request #580 from boozallen/568-spark-schema-required-fiel…
Browse files Browse the repository at this point in the history
…d-null

#568 Spark/PySpark schema validation should not fail on non required fields
  • Loading branch information
carter-cundiff authored Feb 14, 2025
2 parents 9976fbe + ea0c771 commit f93c103
Show file tree
Hide file tree
Showing 11 changed files with 530 additions and 52 deletions.
7 changes: 7 additions & 0 deletions DRAFT_RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
# Breaking Changes
_Note: instructions for adapting to these changes are outlined in the upgrade instructions below._

- PySpark will no longer throw an exception when a required field is `None` but instead filter it out. See Changes in Spark/PySpark Schema Behavior below for more details.
- Spark/PySpark will no longer filter out records with `null`/`None` fields that are not required and have validation. See Changes in Spark/PySpark Schema Behavior below for more details.

## Changes in Spark/PySpark Schema Behavior
- When creating a data frame from a record schema with [required fields](https://boozallen.github.io/aissemble/aissemble/current/record-metamodel.html#_record_field_options) using PySpark, creation of the data frame (`spark_session.createDataFrame()`) will no longer throw an exception if a required field is `None` but instead filter out the record from the data frame as part of validation (`record_schema.validate_dataset()`).
- When validating a data frame from a record schema with [non-required fields](https://boozallen.github.io/aissemble/aissemble/current/record-metamodel.html#_record_field_options) and [dictionary validation](https://boozallen.github.io/aissemble/aissemble/current/dictionary-metamodel.html#_validation_options) using Spark/PySpark, validation (`recordSchema.validateDataFrame()/record_schema.validate_dataset()`) will no longer mistakenly filter out a record from the data frame if the field value is `None`/`null`.

# Known Issues

## Docker Module Build Failures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,36 @@ class ${record.capitalizedName}SchemaBase(ABC):
Generated from: ${templateName}
"""

#set($columnVars = {})
#foreach ($field in $record.fields)
${field.upperSnakecaseName}_COLUMN: str = '${field.sparkAttributes.columnName}'
#set ($columnVars[$field.name] = "${field.upperSnakecaseName}_COLUMN")
${columnVars[$field.name]}: str = '${field.sparkAttributes.columnName}'
#end
#set($relationVars = {})
#foreach ($relation in $record.relations)
${relation.upperSnakecaseName}_COLUMN: str = '${relation.columnName}'
#set ($relationVars[$relation.name] = "${relation.upperSnakecaseName}_COLUMN")
${relationVars[$relation.name]}: str = '${relation.columnName}'
#end


def __init__(self):
self._schema = StructType()

## Setting the nullable parameter to True for every column due to inconsistencies in the behavior from different data sources/toolings (Spark vs Pyspark)
## This allows all data to be read in, and None values will be filtered out as part of the validate_dataset method if the field is required
## Previously Pyspark would throw an exception if it encountered a None value with nullable set to False, resulting in the all previous data processed being lost
#foreach ($field in $record.fields)
#set ($nullable = "#if($field.sparkAttributes.isNullable())True#{else}False#end")
#if ($field.sparkAttributes.isDecimalType())
self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), ${nullable})
self.add(self.${columnVars[$field.name]}, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), True)
#else
self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(), ${nullable})
self.add(self.${columnVars[$field.name]}, ${field.shortType}(), True)
#end
#end
#foreach ($relation in $record.relations)
#set ($nullable = "#if($relation.isNullable())True#{else}False#end")
#if ($relation.isOneToManyRelation())
self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ArrayType(${relation.capitalizedName}Schema().struct_type), ${nullable})
self.add(self.${relationVars[$relation.name]}, ArrayType(${relation.capitalizedName}Schema().struct_type), True)
#else
self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ${relation.capitalizedName}Schema().struct_type, ${nullable})
self.add(self.${relationVars[$relation.name]}, ${relation.capitalizedName}Schema().struct_type, True)
#end
#end

Expand All @@ -66,18 +71,18 @@ class ${record.capitalizedName}SchemaBase(ABC):
Returns the given dataset cast to this schema.
"""
#foreach ($field in $record.fields)
${field.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN)
${field.snakeCaseName}_type = self.get_data_type(self.${columnVars[$field.name]})
#end
#foreach ($relation in $record.relations)
${relation.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN)
${relation.snakeCaseName}_type = self.get_data_type(self.${relationVars[$relation.name]})
#end

return dataset \
#foreach ($field in $record.fields)
.withColumn(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end
.withColumn(self.${columnVars[$field.name]}, dataset[self.${columnVars[$field.name]}].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end
#end
#foreach ($relation in $record.relations)
.withColumn(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end
.withColumn(self.${relationVars[$relation.name]}, dataset[self.${relationVars[$relation.name]}].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end
#end
#end

Expand Down Expand Up @@ -137,31 +142,32 @@ class ${record.capitalizedName}SchemaBase(ABC):
"""
data_with_validations = ingest_dataset
#foreach ($field in $record.fields)
#set ( $columnName = "#if($field.column)$field.column#{else}$field.upperSnakecaseName#end" )
#if (${field.isRequired()})
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_IS_NOT_NULL", col(column_prefix + "${columnName}").isNotNull())
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NOT_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNotNull())
#else
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull())
#end
#if (${field.getValidation().getMinValue()})
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_MIN", col(column_prefix + "${columnName}").cast('double') >= ${field.getValidation().getMinValue()})
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(column_prefix + self.${columnVars[$field.name]}).cast('double') >= ${field.getValidation().getMinValue()})
#end
#if (${field.getValidation().getMaxValue()})
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_MAX", col(column_prefix + "${columnName}").cast('double') <= ${field.getValidation().getMaxValue()})
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_MAX", col(column_prefix + self.${columnVars[$field.name]}).cast('double') <= ${field.getValidation().getMaxValue()})
#end
#if (${field.getValidation().getScale()})
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_SCALE", col(column_prefix + "${columnName}").cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$"))
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_SCALE", col(column_prefix + self.${columnVars[$field.name]}).cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$"))
#end
#if (${field.getValidation().getMinLength()})
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + "${columnName}").rlike("^.{${field.getValidation().getMinLength()},}"))
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{${field.getValidation().getMinLength()},}"))
#end
#if (${field.getValidation().getMaxLength()})
#set($max = ${field.getValidation().getMaxLength()} + 1)
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + "${columnName}").rlike("^.{$max,}").eqNullSafe(False))
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{$max,}").eqNullSafe(False))
#end
#foreach ($format in $field.getValidation().getFormats())
#if ($foreach.first)
data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_FORMAT", col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end
data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_FORMAT", col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end
#else
| col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end
| col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end
#end
#end
#end
Expand All @@ -170,28 +176,63 @@ class ${record.capitalizedName}SchemaBase(ABC):
#if (false)
#foreach($relation in $record.relations)
#if($relation.isOneToManyRelation())
data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN)))))
data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relationVars[$relation.name]})))))
#else
${relation.snakeCaseName}_schema = ${relation.name}Schema()
data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN)), '${relation.columnName}.').isEmpty()))
data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relationVars[$relation.name]})), '${relation.columnName}.').isEmpty()))
#end
#end
#end

validation_columns = [x for x in data_with_validations.columns if x not in ingest_dataset.columns]
column_filter_schemas = []
validation_columns = [col for col in data_with_validations.columns if col not in ingest_dataset.columns]

# Schema for filtering for valid data
filter_schema = None
for column_name in validation_columns:
if isinstance(filter_schema, Column):
filter_schema = filter_schema & col(column_name).eqNullSafe(True)
else:
filter_schema = col(column_name).eqNullSafe(True)
# Separate columns into groups based on their field name
columns_grouped_by_field = []

#foreach ($field in $record.fields)
columns_grouped_by_field.append([col for col in validation_columns if col.startswith(self.${columnVars[$field.name]})])
#end

# Create a schema filter for each field represented as a column group
for column_group in columns_grouped_by_field:
column_group_filter_schema = None

# This column tracks if a non-required field is None. This enables
# non-required validated fields to still pass filtering when they are None
nullable_column = None

for column_name in column_group:
if column_name.endswith("_IS_NULL"):
nullable_column = col(column_name).eqNullSafe(True)
elif column_group_filter_schema is not None:
column_group_filter_schema = column_group_filter_schema & col(column_name).eqNullSafe(True)
else:
column_group_filter_schema = col(column_name).eqNullSafe(True)

# Add the nullable column filter as a OR statement at the end of the given field schema
# If there is no other schema filters for the field, then it can be ignored
if nullable_column is not None and column_group_filter_schema is not None:
column_group_filter_schema = nullable_column | column_group_filter_schema

if column_group_filter_schema is not None:
column_filter_schemas.append(column_group_filter_schema)

valid_data = data_with_validations
# Isolate the valid data and drop validation columns
if isinstance(filter_schema, Column):
valid_data = data_with_validations.filter(filter_schema)
valid_data = data_with_validations
if column_filter_schemas:

# Combine all the field filter schemas into one final schema for the row
final_column_filter_schemas = None

for column_group_filter_schema in column_filter_schemas:
if final_column_filter_schemas is not None:
final_column_filter_schemas = final_column_filter_schemas & column_group_filter_schema
else:
final_column_filter_schemas = column_group_filter_schema

valid_data = data_with_validations.filter(final_column_filter_schemas)

valid_data = valid_data.drop(*validation_columns)
return valid_data

Expand Down
Loading

0 comments on commit f93c103

Please sign in to comment.