From ea0c77140d232e4555c050d4f887715ea3f0ece6 Mon Sep 17 00:00:00 2001
From: Carter Cundiff
Date: Fri, 14 Feb 2025 12:06:02 -0500
Subject: [PATCH] #568 Spark/PySpark schema validation should not fail on non required fields

---
 DRAFT_RELEASE_NOTES.md                        |   7 ++
 .../pyspark.schema.base.py.vm                 | 109 +++++++++++-----
 .../spark.schema.base.java.vm                 |  75 +++++++++--
 .../RecordWithNonRequiredValidation.json      |  32 +++++
 .../records/RecordWithRequiredValidation.json |  35 ++++++
 ...s.feature => pyspark_spark_schema.feature} |  19 ++-
 ...steps.py => pyspark_spark_schema_steps.py} | 119 +++++++++++++++++-
 .../RecordWithNonRequiredValidation.json      |  29 +++++
 .../records/RecordWithRequiredValidation.json |  31 +++++
 .../aiops/mda/pattern/SparkSchemaTest.java    | 108 +++++++++++++++-
 .../specifications/sparkSchema.feature        |  18 ++-
 11 files changed, 530 insertions(+), 52 deletions(-)
 create mode 100644 test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithNonRequiredValidation.json
 create mode 100644 test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithRequiredValidation.json
 rename test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/{pyspark_schema_relations.feature => pyspark_spark_schema.feature} (79%)
 rename test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/{pyspark_schema_relation_steps.py => pyspark_spark_schema_steps.py} (65%)
 create mode 100644 test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithNonRequiredValidation.json
 create mode 100644 test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithRequiredValidation.json

diff --git a/DRAFT_RELEASE_NOTES.md b/DRAFT_RELEASE_NOTES.md
index e7948f587..f26bc3580 100644
--- a/DRAFT_RELEASE_NOTES.md
+++ b/DRAFT_RELEASE_NOTES.md
@@ -3,6 +3,13 @@
 # Breaking Changes
 _Note: instructions for adapting to these changes are outlined in the upgrade instructions below._
+- PySpark will no longer throw an exception when a required field is `None`; instead, the record is filtered out during validation. See Changes in Spark/PySpark Schema Behavior below for more details.
+- Spark/PySpark will no longer filter out records whose non-required fields with validation are `null`/`None`. See Changes in Spark/PySpark Schema Behavior below for more details.
+
+## Changes in Spark/PySpark Schema Behavior
+- When creating a data frame from a record schema with [required fields](https://boozallen.github.io/aissemble/aissemble/current/record-metamodel.html#_record_field_options) using PySpark, data frame creation (`spark_session.createDataFrame()`) will no longer throw an exception if a required field is `None`; instead, the record is filtered out of the data frame as part of validation (`record_schema.validate_dataset()`).
+- When validating a data frame from a record schema with [non-required fields](https://boozallen.github.io/aissemble/aissemble/current/record-metamodel.html#_record_field_options) and [dictionary validation](https://boozallen.github.io/aissemble/aissemble/current/dictionary-metamodel.html#_validation_options) using Spark/PySpark, validation (`recordSchema.validateDataFrame()`/`record_schema.validate_dataset()`) will no longer mistakenly filter out a record whose field value is `None`/`null`.
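
As an illustration of the behavior change described above, here is a minimal PySpark sketch; the `example_pipeline` module path and `ExampleRecordSchema` class are hypothetical stand-ins for any aiSSEMBLE-generated record schema with a required field, while `struct_type` and `validate_dataset()` are the members produced by the generated schema base:

```python
from pyspark.sql import SparkSession

# Hypothetical generated schema for a record whose first column is required
from example_pipeline.schema.example_record_schema import ExampleRecordSchema

spark = SparkSession.builder.master("local[*]").getOrCreate()
schema = ExampleRecordSchema()

rows = [
    ("abc-123", 42),  # required field populated
    (None, 17),       # required field missing
]

# Previously this raised because required columns were declared non-nullable;
# all columns are now nullable, so the data frame is created successfully
df = spark.createDataFrame(rows, schema.struct_type)

# The record whose required field is None is filtered out here instead
valid_df = schema.validate_dataset(df)
assert valid_df.count() == 1
```

The Spark (Java) `validateDataFrame()` method applies the same filtering.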
+
 
 # Known Issues
 
 ## Docker Module Build Failures
diff --git a/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/pyspark.schema.base.py.vm b/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/pyspark.schema.base.py.vm
index 5b1e0f754..2b05a9758 100644
--- a/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/pyspark.schema.base.py.vm
+++ b/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/pyspark.schema.base.py.vm
@@ -32,31 +32,36 @@ class ${record.capitalizedName}SchemaBase(ABC):
     Generated from: ${templateName}
     """
 
+#set($columnVars = {})
 #foreach ($field in $record.fields)
-    ${field.upperSnakecaseName}_COLUMN: str = '${field.sparkAttributes.columnName}'
+    #set ($columnVars[$field.name] = "${field.upperSnakecaseName}_COLUMN")
+    ${columnVars[$field.name]}: str = '${field.sparkAttributes.columnName}'
 #end
+#set($relationVars = {})
 #foreach ($relation in $record.relations)
-    ${relation.upperSnakecaseName}_COLUMN: str = '${relation.columnName}'
+    #set ($relationVars[$relation.name] = "${relation.upperSnakecaseName}_COLUMN")
+    ${relationVars[$relation.name]}: str = '${relation.columnName}'
 #end
 
     def __init__(self):
         self._schema = StructType()
+## Setting the nullable parameter to True for every column due to inconsistencies in the behavior from different data sources/toolings (Spark vs Pyspark)
+## This allows all data to be read in, and None values will be filtered out as part of the validate_dataset method if the field is required
+## Previously, Pyspark would throw an exception if it encountered a None value with nullable set to False, resulting in all previously processed data being lost
 #foreach ($field in $record.fields)
-    #set ($nullable = "#if($field.sparkAttributes.isNullable())True#{else}False#end")
     #if ($field.sparkAttributes.isDecimalType())
-        self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), ${nullable})
+        self.add(self.${columnVars[$field.name]}, ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), True)
     #else
-        self.add(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, ${field.shortType}(), ${nullable})
+        self.add(self.${columnVars[$field.name]}, ${field.shortType}(), True)
     #end
 #end
 #foreach ($relation in $record.relations)
-    #set ($nullable = "#if($relation.isNullable())True#{else}False#end")
     #if ($relation.isOneToManyRelation())
-        self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ArrayType(${relation.capitalizedName}Schema().struct_type), ${nullable})
+        self.add(self.${relationVars[$relation.name]}, ArrayType(${relation.capitalizedName}Schema().struct_type), True)
     #else
-        self.add(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, ${relation.capitalizedName}Schema().struct_type, ${nullable})
+        self.add(self.${relationVars[$relation.name]}, ${relation.capitalizedName}Schema().struct_type, True)
     #end
 #end
 
@@ -66,18 +71,18 @@ class ${record.capitalizedName}SchemaBase(ABC):
         Returns the given dataset cast to this schema.
""" #foreach ($field in $record.fields) - ${field.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN) + ${field.snakeCaseName}_type = self.get_data_type(self.${columnVars[$field.name]}) #end #foreach ($relation in $record.relations) - ${relation.snakeCaseName}_type = self.get_data_type(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN) + ${relation.snakeCaseName}_type = self.get_data_type(self.${relationVars[$relation.name]}) #end return dataset \ #foreach ($field in $record.fields) - .withColumn(${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${field.upperSnakecaseName}_COLUMN].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end + .withColumn(self.${columnVars[$field.name]}, dataset[self.${columnVars[$field.name]}].cast(${field.snakeCaseName}_type))#if ($foreach.hasNext || $record.hasRelations()) \\#end #end #foreach ($relation in $record.relations) - .withColumn(${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN, dataset[${record.capitalizedName}SchemaBase.${relation.upperSnakecaseName}_COLUMN].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end + .withColumn(self.${relationVars[$relation.name]}, dataset[self.${relationVars[$relation.name]}].cast(${relation.snakeCaseName}_type))#if ($foreach.hasNext) \\#end #end #end @@ -137,31 +142,32 @@ class ${record.capitalizedName}SchemaBase(ABC): """ data_with_validations = ingest_dataset #foreach ($field in $record.fields) - #set ( $columnName = "#if($field.column)$field.column#{else}$field.upperSnakecaseName#end" ) #if (${field.isRequired()}) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_IS_NOT_NULL", col(column_prefix + "${columnName}").isNotNull()) + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NOT_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNotNull()) + #else + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_IS_NULL", col(column_prefix + self.${columnVars[$field.name]}).isNull()) #end #if (${field.getValidation().getMinValue()}) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_MIN", col(column_prefix + "${columnName}").cast('double') >= ${field.getValidation().getMinValue()}) + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(column_prefix + self.${columnVars[$field.name]}).cast('double') >= ${field.getValidation().getMinValue()}) #end #if (${field.getValidation().getMaxValue()}) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_MAX", col(column_prefix + "${columnName}").cast('double') <= ${field.getValidation().getMaxValue()}) + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_MAX", col(column_prefix + self.${columnVars[$field.name]}).cast('double') <= ${field.getValidation().getMaxValue()}) #end #if (${field.getValidation().getScale()}) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_SCALE", col(column_prefix + "${columnName}").cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$")) + data_with_validations = 
data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_SCALE", col(column_prefix + self.${columnVars[$field.name]}).cast(StringType()).rlike(r"^[0-9]*(?:\.[0-9]{0,${field.getValidation().getScale()}})?$")) #end #if (${field.getValidation().getMinLength()}) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + "${columnName}").rlike("^.{${field.getValidation().getMinLength()},}")) + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_GREATER_THAN_OR_EQUAL_TO_MIN_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{${field.getValidation().getMinLength()},}")) #end #if (${field.getValidation().getMaxLength()}) #set($max = ${field.getValidation().getMaxLength()} + 1) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + "${columnName}").rlike("^.{$max,}").eqNullSafe(False)) + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_LESS_THAN_OR_EQUAL_TO_MAX_LENGTH", col(column_prefix + self.${columnVars[$field.name]}).rlike("^.{$max,}").eqNullSafe(False)) #end #foreach ($format in $field.getValidation().getFormats()) #if ($foreach.first) - data_with_validations = data_with_validations.withColumn("${field.upperSnakecaseName}_MATCHES_FORMAT", col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end + data_with_validations = data_with_validations.withColumn(self.${columnVars[$field.name]} + "_MATCHES_FORMAT", col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end #else - | col(column_prefix + "${columnName}").rlike("$format.replace("\","\\")")#if($foreach.last))#end + | col(column_prefix + self.${columnVars[$field.name]}).rlike("$format.replace("\","\\")")#if($foreach.last))#end #end #end #end @@ -170,28 +176,63 @@ class ${record.capitalizedName}SchemaBase(ABC): #if (false) #foreach($relation in $record.relations) #if($relation.isOneToManyRelation()) - data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN))))) + data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(self._validate_with_${relation.snakeCaseName}_schema(data_with_validations.select(col(self.${relationVars[$relation.name]}))))) #else ${relation.snakeCaseName}_schema = ${relation.name}Schema() - data_with_validations = data_with_validations.withColumn(self.${relation.upperSnakecaseName}_COLUMN + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relation.upperSnakecaseName}_COLUMN)), '${relation.columnName}.').isEmpty())) + data_with_validations = data_with_validations.withColumn(self.${relationVars[$relation.name]} + "_VALID", lit(not ${relation.snakeCaseName}_schema.validate_dataset_with_prefix(data_with_validations.select(col(self.${relationVars[$relation.name]})), '${relation.columnName}.').isEmpty())) #end #end #end - validation_columns = [x for x in data_with_validations.columns if x not in ingest_dataset.columns] + column_filter_schemas = [] + validation_columns = [col for col in data_with_validations.columns if col not in 
ingest_dataset.columns] - # Schema for filtering for valid data - filter_schema = None - for column_name in validation_columns: - if isinstance(filter_schema, Column): - filter_schema = filter_schema & col(column_name).eqNullSafe(True) - else: - filter_schema = col(column_name).eqNullSafe(True) + # Separate columns into groups based on their field name + columns_grouped_by_field = [] + + #foreach ($field in $record.fields) + columns_grouped_by_field.append([col for col in validation_columns if col.startswith(self.${columnVars[$field.name]})]) + #end + + # Create a schema filter for each field represented as a column group + for column_group in columns_grouped_by_field: + column_group_filter_schema = None + + # This column tracks if a non-required field is None. This enables + # non-required validated fields to still pass filtering when they are None + nullable_column = None + + for column_name in column_group: + if column_name.endswith("_IS_NULL"): + nullable_column = col(column_name).eqNullSafe(True) + elif column_group_filter_schema is not None: + column_group_filter_schema = column_group_filter_schema & col(column_name).eqNullSafe(True) + else: + column_group_filter_schema = col(column_name).eqNullSafe(True) + + # Add the nullable column filter as a OR statement at the end of the given field schema + # If there is no other schema filters for the field, then it can be ignored + if nullable_column is not None and column_group_filter_schema is not None: + column_group_filter_schema = nullable_column | column_group_filter_schema + + if column_group_filter_schema is not None: + column_filter_schemas.append(column_group_filter_schema) - valid_data = data_with_validations # Isolate the valid data and drop validation columns - if isinstance(filter_schema, Column): - valid_data = data_with_validations.filter(filter_schema) + valid_data = data_with_validations + if column_filter_schemas: + + # Combine all the field filter schemas into one final schema for the row + final_column_filter_schemas = None + + for column_group_filter_schema in column_filter_schemas: + if final_column_filter_schemas is not None: + final_column_filter_schemas = final_column_filter_schemas & column_group_filter_schema + else: + final_column_filter_schemas = column_group_filter_schema + + valid_data = data_with_validations.filter(final_column_filter_schemas) + valid_data = valid_data.drop(*validation_columns) return valid_data diff --git a/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/spark.schema.base.java.vm b/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/spark.schema.base.java.vm index 42d2fbbf0..a6736c4e2 100644 --- a/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/spark.schema.base.java.vm +++ b/foundation/foundation-mda/src/main/resources/templates/data-delivery-data-records/spark.schema.base.java.vm @@ -8,6 +8,7 @@ import ${import}; import ${import}; #end +import java.util.stream.Collectors; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -53,18 +54,21 @@ public abstract class ${record.capitalizedName}SchemaBase extends SparkSchema { public ${record.capitalizedName}SchemaBase() { super(); +## Setting the nullable parameter to true for every column due to inconsistencies in the behavior from different data sources/toolings (Spark vs Pyspark) +## This allows all data to be read in, and null values will be filtered out as part of the validateDataFrame method if the field is 
required +## Spark will still allow null values through even when nullable is set to false, so setting to true here to be consistent with Pyspark implementation #foreach ($field in $record.fields) #if ($field.sparkAttributes.isDecimalType()) - add(${columnVars[$field.name]}, new ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), ${field.sparkAttributes.isNullable()}, "${field.description}"); + add(${columnVars[$field.name]}, new ${field.shortType}(${field.sparkAttributes.defaultDecimalPrecision}, ${field.sparkAttributes.decimalScale}), true, "${field.description}"); #else - add(${columnVars[$field.name]}, ${field.shortType}, ${field.sparkAttributes.isNullable()}, "${field.description}"); + add(${columnVars[$field.name]}, ${field.shortType}, true, "${field.description}"); #end #end #foreach ($relation in $record.relations) #if ($relation.isOneToManyRelation()) - add(${relationVars[$relation.name]}, DataTypes.createArrayType(new ${relation.name}Schema().getStructType()), ${relation.isNullable()}, "${relation.documentation}"); + add(${relationVars[$relation.name]}, DataTypes.createArrayType(new ${relation.name}Schema().getStructType()), true, "${relation.documentation}"); #else - add(${relationVars[$relation.name]}, new ${relation.name}Schema().getStructType(), ${relation.isNullable()}, "${relation.documentation}"); + add(${relationVars[$relation.name]}, new ${relation.name}Schema().getStructType(), true, "${relation.documentation}"); #end #end } @@ -135,6 +139,8 @@ public abstract class ${record.capitalizedName}SchemaBase extends SparkSchema { #foreach ($field in $record.fields) #if (${field.isRequired()}) .withColumn(${columnVars[$field.name]} + "_IS_NOT_NULL", col(columnPrefix + ${columnVars[$field.name]}).isNotNull()) + #else + .withColumn(${columnVars[$field.name]} + "_IS_NULL", col(columnPrefix + ${columnVars[$field.name]}).isNull()) #end #if (${field.getValidation().getMinValue()}) .withColumn(${columnVars[$field.name]} + "_GREATER_THAN_MIN", col(columnPrefix + ${columnVars[$field.name]}).gt(lit(${field.getValidation().getMinValue()})).or(col(columnPrefix + ${columnVars[$field.name]}).equalTo(lit(${field.getValidation().getMinValue()})))) @@ -176,21 +182,66 @@ public abstract class ${record.capitalizedName}SchemaBase extends SparkSchema { #end #end - Column filterSchema = null; + List columnFilterSchemas = new ArrayList<>(); List validationColumns = new ArrayList<>(); Collections.addAll(validationColumns, dataWithValidations.columns()); validationColumns.removeAll(Arrays.asList(data.columns())); - for (String columnName : validationColumns) { - if (filterSchema == null) { - filterSchema = col(columnName).equalTo(lit(true)); - } else { - filterSchema = filterSchema.and(col(columnName).equalTo(lit(true))); + + // Separate columns into groups based on their field name + List> columnsGroupedByField = new ArrayList<>(); + + #foreach ($field in $record.fields) + columnsGroupedByField.add(validationColumns.stream() + .filter(col -> col.startsWith(${columnVars[$field.name]})) + .collect(Collectors.toList())); + + #end + + // Create a schema filter for each field represented as a column group + for (List columnGroup: columnsGroupedByField) { + Column columnGroupFilterSchema = null; + + // This column tracks if a non-required field is null. 
This enables + // non-required validated fields to still pass filtering when they are null + Column nullableColumn = null; + + for (String columnName : columnGroup) { + if (columnName.endsWith("_IS_NULL")) { + nullableColumn = col(columnName).equalTo(lit(true)); + } else if (columnGroupFilterSchema == null) { + columnGroupFilterSchema = col(columnName).equalTo(lit(true)); + } else { + columnGroupFilterSchema = columnGroupFilterSchema.and(col(columnName).equalTo(lit(true))); + } + } + + // Add the nullable column filter as a OR statement at the end of the given field schema + // If there is no other schema filters for the field, then it can be ignored + if (nullableColumn != null && columnGroupFilterSchema != null) { + columnGroupFilterSchema = nullableColumn.or(columnGroupFilterSchema); + } + + if (columnGroupFilterSchema != null) { + columnFilterSchemas.add(columnGroupFilterSchema); } } + // Isolate the valid data Dataset validData = data; - if (filterSchema != null) { - validData = dataWithValidations.filter(filterSchema); + if (!columnFilterSchemas.isEmpty()) { + + // Combine all the field filter schemas into one final schema for the row + Column finalColumnsFilterSchema = null; + + for (Column columnGroupFilterSchema: columnFilterSchemas) { + if (finalColumnsFilterSchema == null) { + finalColumnsFilterSchema = columnGroupFilterSchema; + } else { + finalColumnsFilterSchema = finalColumnsFilterSchema.and(columnGroupFilterSchema); + } + } + + validData = dataWithValidations.filter(finalColumnsFilterSchema); } // Remove validation columns from valid data diff --git a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithNonRequiredValidation.json b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithNonRequiredValidation.json new file mode 100644 index 000000000..9b7fd1b50 --- /dev/null +++ b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithNonRequiredValidation.json @@ -0,0 +1,32 @@ +{ + "name": "RecordWithNonRequiredValidation", + "package": "com.boozallen.aiops.mda.pattern.record", + "description": "Example record with non required field that has validation", + "frameworks": [{ + "name": "pyspark" + }], + "fields": [ + { + "name": "integerValidation", + "type": { + "name": "integerWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + }, + { + "name": "stringValidation", + "type": { + "name": "stringWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + }, + { + "name": "stringSimple", + "type": { + "name": "string", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + } + ] + } + \ No newline at end of file diff --git a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithRequiredValidation.json b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithRequiredValidation.json new file mode 100644 index 000000000..8dfccdebf --- /dev/null +++ b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/src/aissemble_test_data_delivery_pyspark_model/resources/records/RecordWithRequiredValidation.json @@ -0,0 +1,35 @@ +{ + "name": "RecordWithRequiredValidation", + "package": 
"com.boozallen.aiops.mda.pattern.record", + "description": "Example record with a required field that has validation", + "frameworks": [{ + "name": "pyspark" + }], + "fields": [ + { + "name": "integerValidation", + "type": { + "name": "integerWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + }, + { + "name": "stringValidation", + "type": { + "name": "stringWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + }, + { + "name": "stringSimple", + "type": { + "name": "string", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + } + ] + } + \ No newline at end of file diff --git a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_schema_relations.feature b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_spark_schema.feature similarity index 79% rename from test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_schema_relations.feature rename to test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_spark_schema.feature index 8108916a6..1a92e685e 100644 --- a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_schema_relations.feature +++ b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/pyspark_spark_schema.feature @@ -1,5 +1,5 @@ -@pyspark_schema_relation -Feature: Pyspark schema functionality works for relations +@pyspark_schema +Feature: Pyspark schema functionality works for records Background: Given the record "City" exists with the following relations @@ -57,3 +57,18 @@ Feature: Pyspark schema functionality works for relations | 0 | 1 | | 1 | 1 | + Scenario Outline: Records with fields with validation rules can be validated using the spark schema + Given a record with a "" field with validation rules + And the field is set to a "" value + And a dataSet containing the record + And the dataset contains one valid record + When the generated spark schema validation is performed on the dataSet + Then the resulting dataSet contains row(s) + Examples: + | requirement | validity | num | + | required | valid | 2 | + | required | invalid | 1 | + | required | null | 1 | + | non-required | valid | 2 | + | non-required | invalid | 1 | + | non-required | null | 2 | diff --git a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_schema_relation_steps.py b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_spark_schema_steps.py similarity index 65% rename from test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_schema_relation_steps.py rename to test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_spark_schema_steps.py index 2e821e059..65433daa8 100644 --- a/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_schema_relation_steps.py +++ b/test/test-mda-models/aissemble-test-data-delivery-pyspark-model/tests/features/steps/pyspark_spark_schema_steps.py @@ -6,6 +6,12 @@ from aissemble_test_data_delivery_pyspark_model.dictionary.integer_with_validation import ( IntegerWithValidation, ) +from aissemble_test_data_delivery_pyspark_model.dictionary.string_with_validation import ( + StringWithValidation, +) +from aissemble_test_data_delivery_pyspark_model.dictionary.state_address 
import ( + StateAddress, +) from aissemble_test_data_delivery_pyspark_model.dictionary.zipcode import Zipcode from aissemble_test_data_delivery_pyspark_model.record.address import Address from aissemble_test_data_delivery_pyspark_model.record.city import ( @@ -26,6 +32,13 @@ from aissemble_test_data_delivery_pyspark_model.record.street import ( Street, ) +from aissemble_test_data_delivery_pyspark_model.record.record_with_required_validation import ( + RecordWithRequiredValidation, +) + +from aissemble_test_data_delivery_pyspark_model.record.record_with_non_required_validation import ( + RecordWithNonRequiredValidation, +) from aissemble_test_data_delivery_pyspark_model.schema.city_schema import ( CitySchema, ) @@ -35,8 +48,11 @@ from aissemble_test_data_delivery_pyspark_model.schema.person_with_one_to_one_relation_schema import ( PersonWithOneToOneRelationSchema, ) -from aissemble_test_data_delivery_pyspark_model.dictionary.state_address import ( - StateAddress, +from aissemble_test_data_delivery_pyspark_model.schema.record_with_required_validation_schema import ( + RecordWithRequiredValidationSchema, +) +from aissemble_test_data_delivery_pyspark_model.schema.record_with_non_required_validation_schema import ( + RecordWithNonRequiredValidationSchema, ) @@ -132,6 +148,75 @@ def step_impl(context, validity): ) +@given('a record with a "{requirement}" field with validation rules') +def step_impl(context, requirement): + context.record_with_validated_field_requirement = requirement + + context.record_with_requirement_validation = ( + RecordWithRequiredValidation() + if requirement == "required" + else RecordWithNonRequiredValidation() + ) + + +@given('the field is set to a "{validity}" value') +def step_impl(context, validity): + # set valid fields to verify validation still works with multiple fields + context.record_with_requirement_validation.string_validation = StringWithValidation( + "Test123" + ) + context.record_with_requirement_validation.string_simple = "Test123" + + if validity == "valid": + context.record_with_requirement_validation.integer_validation = ( + IntegerWithValidation(150) + ) + elif validity == "invalid": + context.record_with_requirement_validation.integer_validation = ( + IntegerWithValidation(50) + ) + else: + pass # Do nothing to keep the field None + + +@given("a dataSet containing the record") +def step_impl(context): + if context.record_with_validated_field_requirement == "required": + row = RecordWithRequiredValidation.as_row( + context.record_with_requirement_validation + ) + else: + row = RecordWithNonRequiredValidation.as_row( + context.record_with_requirement_validation + ) + + context.record_with_requirement_validation_rows = [row] + + +@given("the dataset contains one valid record") +def step_impl(context): + if context.record_with_validated_field_requirement == "required": + # Create other valid row for the data frame to test filtering + valid_record_with_requirement_validation = RecordWithRequiredValidation() + valid_record_with_requirement_validation.integer_validation = ( + IntegerWithValidation(150) + ) + valid_record_with_requirement_validation.string_validation = ( + StringWithValidation("Test123") + ) + valid_record_with_requirement_validation.string_simple = "Test123" + valid_row = RecordWithRequiredValidation.as_row( + valid_record_with_requirement_validation + ) + else: + valid_record_with_non_requirement_validation = RecordWithNonRequiredValidation() + valid_row = RecordWithNonRequiredValidation.as_row( + 
valid_record_with_non_requirement_validation + ) + + context.record_with_requirement_validation_rows.append(valid_row) + + @when('spark schema validation is performed on the "PersonWithMToOneRelation" dataSet') def step_impl(context): person_with_many_to_one_relation_schema = PersonWithMToOneRelationSchema() @@ -169,6 +254,27 @@ def step_impl(context): context.exc = e +@when("the generated spark schema validation is performed on the dataSet") +def step_impl(context): + if context.record_with_validated_field_requirement == "required": + record_with_requirement_validation_schema = RecordWithRequiredValidationSchema() + else: + record_with_requirement_validation_schema = ( + RecordWithNonRequiredValidationSchema() + ) + + record_with_validated_field_dataset = context.test_spark_session.createDataFrame( + context.record_with_requirement_validation_rows, + schema=record_with_requirement_validation_schema.struct_type, + ) + + context.validated_dataframe = ( + record_with_requirement_validation_schema.validate_dataset( + record_with_validated_field_dataset + ) + ) + + @then('the schema data type for "{record}" is "{type}"') def step_impl(context, record, type): nt.assert_equal(str(context.schema.get_data_type(record.upper())), type) @@ -213,6 +319,15 @@ def step_impl(context): ) +@then("the resulting dataSet contains {numRows} row(s)") +def step_impl(context, numRows): + nt.assert_equal( + int(numRows), + context.validated_dataframe.count(), + "The validated dataSet contained the incorrect number of rows", + ) + + def _create_city() -> City: streets: List[Street] = [] street = Street() diff --git a/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithNonRequiredValidation.json b/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithNonRequiredValidation.json new file mode 100644 index 000000000..dfa008ba0 --- /dev/null +++ b/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithNonRequiredValidation.json @@ -0,0 +1,29 @@ +{ + "name": "RecordWithNonRequiredValidation", + "package": "com.boozallen.aiops.mda.pattern.record", + "description": "Example record with non required field that has validation", + "fields": [ + { + "name": "integerValidation", + "type": { + "name": "integerWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + }, + { + "name": "stringValidation", + "type": { + "name": "stringWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + }, + { + "name": "stringSimple", + "type": { + "name": "string", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + } + } + ] + } + \ No newline at end of file diff --git a/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithRequiredValidation.json b/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithRequiredValidation.json new file mode 100644 index 000000000..ebccf4d33 --- /dev/null +++ b/test/test-mda-models/test-data-delivery-spark-model/src/main/resources/records/RecordWithRequiredValidation.json @@ -0,0 +1,31 @@ +{ + "name": "RecordWithRequiredValidation", + "package": "com.boozallen.aiops.mda.pattern.record", + "description": "Example record with a required field that has validation", + "fields": [ + { + "name": "integerValidation", + "type": { + "name": "integerWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + }, + { + "name": "stringValidation", + "type": 
{ + "name": "stringWithValidation", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + }, + { + "name": "stringSimple", + "type": { + "name": "string", + "package": "com.boozallen.aiops.mda.pattern.dictionary" + }, + "required": true + } + ] +} diff --git a/test/test-mda-models/test-data-delivery-spark-model/src/test/java/com/boozallen/aiops/mda/pattern/SparkSchemaTest.java b/test/test-mda-models/test-data-delivery-spark-model/src/test/java/com/boozallen/aiops/mda/pattern/SparkSchemaTest.java index 95c12b2a0..079325cc6 100644 --- a/test/test-mda-models/test-data-delivery-spark-model/src/test/java/com/boozallen/aiops/mda/pattern/SparkSchemaTest.java +++ b/test/test-mda-models/test-data-delivery-spark-model/src/test/java/com/boozallen/aiops/mda/pattern/SparkSchemaTest.java @@ -16,12 +16,12 @@ import static org.junit.Assert.assertTrue; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import com.boozallen.aiops.mda.pattern.dictionary.Zipcode; -import com.boozallen.aiops.mda.pattern.record.PersonWithOneToMRelation; import com.boozallen.aiops.mda.pattern.record.PersonWithOneToMRelationSchema; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang3.NotImplementedException; @@ -30,6 +30,7 @@ import org.apache.spark.sql.SparkSession; import com.boozallen.aiops.mda.pattern.dictionary.IntegerWithValidation; +import com.boozallen.aiops.mda.pattern.dictionary.StringWithValidation; import com.boozallen.aiops.mda.pattern.record.Address; import com.boozallen.aiops.mda.pattern.record.City; import com.boozallen.aiops.mda.pattern.record.CitySchema; @@ -38,6 +39,10 @@ import com.boozallen.aiops.mda.pattern.record.PersonWithMToOneRelationSchema; import com.boozallen.aiops.mda.pattern.record.PersonWithOneToOneRelation; import com.boozallen.aiops.mda.pattern.record.PersonWithOneToOneRelationSchema; +import com.boozallen.aiops.mda.pattern.record.RecordWithNonRequiredValidation; +import com.boozallen.aiops.mda.pattern.record.RecordWithNonRequiredValidationSchema; +import com.boozallen.aiops.mda.pattern.record.RecordWithRequiredValidation; +import com.boozallen.aiops.mda.pattern.record.RecordWithRequiredValidationSchema; import com.boozallen.aiops.mda.pattern.record.State; import com.boozallen.aiops.mda.pattern.record.Street; @@ -47,21 +52,29 @@ import io.cucumber.java.en.When; public class SparkSchemaTest { + String recordWithValidatedFieldRequirement; CitySchema citySchema; PersonWithOneToOneRelationSchema personWithOneToOneRelationSchema; PersonWithMToOneRelationSchema personWithMToOneRelationSchema; PersonWithOneToMRelationSchema personWithOneToMRelationSchema; + RecordWithRequiredValidationSchema recordWithRequiredValidationSchema; + RecordWithNonRequiredValidationSchema recordWithNonRequiredValidationSchema; + RecordWithNonRequiredValidation recordWithNonRequiredValidation; + RecordWithRequiredValidation recordWithRequiredValidation; + List recordWithRequirementValidationRows; SparkSession spark; Dataset cityDataSet; Dataset personWithOneToOneRelationDataSet; Dataset personWithMToOneRelationDataSet; Dataset personWithOneToMRelationDataSet; + Dataset recordWithValidatedFieldDataSet; Dataset validatedDataSet; Exception exception; @Before("@SparkSchema") public void setUp() { this.spark = SparkTestHarness.getSparkSession(); + this.recordWithRequirementValidationRows = new ArrayList<>(); } @Given("the record \"City\" exists with the following relations") @@ -130,6 +143,71 @@ public 
void aCityDataSetWithAnInvalidRelationExists() { this.cityDataSet = spark.createDataFrame(rows, this.citySchema.getStructType()); } + @Given("a record with a {string} field with validation rules") + public void aRecordWithAFieldWithValidationRules(String requirement) { + this.recordWithValidatedFieldRequirement = requirement; + + if (requirement.equals("required")) { + this.recordWithRequiredValidation = new RecordWithRequiredValidation(); + } else { + this.recordWithNonRequiredValidation = new RecordWithNonRequiredValidation(); + } + } + + @Given("the field is set to a {string} value") + public void theFieldIsSetToAValue(String validity) { + if (this.recordWithValidatedFieldRequirement.equals("required")) { + // set valid fields to verify validation still works with multiple fields + this.recordWithRequiredValidation.setStringValidation(new StringWithValidation("Test123")); + this.recordWithRequiredValidation.setStringSimple("Test123"); + + if (validity.equals("valid")) { + this.recordWithRequiredValidation.setIntegerValidation(new IntegerWithValidation(150)); + } else if(validity.equals("invalid")) { + this.recordWithRequiredValidation.setIntegerValidation(new IntegerWithValidation(50)); + } else { + // Do nothing to keep the field null + } + } else { + // set valid fields to verify validation still works with multiple fields + this.recordWithNonRequiredValidation.setStringValidation(new StringWithValidation("Test123")); + this.recordWithNonRequiredValidation.setStringSimple("Test123"); + + if (validity.equals("valid")) { + this.recordWithNonRequiredValidation.setIntegerValidation(new IntegerWithValidation(150)); + } else if(validity.equals("invalid")) { + this.recordWithNonRequiredValidation.setIntegerValidation(new IntegerWithValidation(50)); + } else { + // Do nothing to keep the field null + } + } + } + + @Given("a dataSet containing the record") + public void aDataSetContainingTheRecord() { + if (this.recordWithValidatedFieldRequirement.equals("required")) { + this.recordWithRequirementValidationRows.add(RecordWithRequiredValidationSchema.asRow(this.recordWithRequiredValidation)); + } else { + this.recordWithRequirementValidationRows.add(RecordWithNonRequiredValidationSchema.asRow(this.recordWithNonRequiredValidation)); + } + } + + @Given("the dataset contains one valid record") + public void theDataSetContainsOneValidRecord() { + if (this.recordWithValidatedFieldRequirement.equals("required")) { + RecordWithRequiredValidation validRecordWithRequiredValidation = new RecordWithRequiredValidation(); + validRecordWithRequiredValidation.setIntegerValidation(new IntegerWithValidation(150)); + validRecordWithRequiredValidation.setStringValidation(new StringWithValidation("Test123")); + validRecordWithRequiredValidation.setStringSimple("Test123"); + + this.recordWithRequirementValidationRows.add(RecordWithRequiredValidationSchema.asRow(validRecordWithRequiredValidation)); + } else { + RecordWithNonRequiredValidation validRecordWithNonRequiredValidation = new RecordWithNonRequiredValidation(); + + this.recordWithRequirementValidationRows.add(RecordWithNonRequiredValidationSchema.asRow(validRecordWithNonRequiredValidation)); + } + } + @When("the spark schema is generated for the \"City\" record") public void theSparkSchemaIsGeneratedForTheCityRecord() { this.citySchema = new CitySchema(); @@ -171,6 +249,29 @@ public void sparkSchemaValidationIsPerformedOnTheCityDataSet() { } } + @When("the generated spark schema validation is performed on the dataSet") + public void 
theGeneratedSparkSchemaValidationIsPerformedOnTheDataSet() { + if (this.recordWithValidatedFieldRequirement.equals("required")) { + this.recordWithRequiredValidationSchema = new RecordWithRequiredValidationSchema(); + + this.recordWithValidatedFieldDataSet = this.spark.createDataFrame( + this.recordWithRequirementValidationRows, + this.recordWithRequiredValidationSchema.getStructType() + ); + + this.validatedDataSet = this.recordWithRequiredValidationSchema.validateDataFrame(this.recordWithValidatedFieldDataSet); + } else { + this.recordWithNonRequiredValidationSchema = new RecordWithNonRequiredValidationSchema(); + + this.recordWithValidatedFieldDataSet = this.spark.createDataFrame( + this.recordWithRequirementValidationRows, + this.recordWithNonRequiredValidationSchema.getStructType() + ); + + this.validatedDataSet = this.recordWithNonRequiredValidationSchema.validateDataFrame(this.recordWithValidatedFieldDataSet); + } + } + @Then("the validation fails with NotYetImplementedException") public void theValidationFailsWithNotYetImplementedException() { assertNotNull("No exception was thrown", this.exception); @@ -210,6 +311,11 @@ public void theDataSetValidationIsSuccessful(String succeed) { } } + @Then("the resulting dataSet contains {int} row\\(s)") + public void theResultingDataSetContainsRows(int numRows) { + assertEquals("The validated dataSet contained the incorrect number of rows", numRows, this.validatedDataSet.count()); + } + private City createCity(){ IntegerWithValidation integerWithValidation = new IntegerWithValidation(100); diff --git a/test/test-mda-models/test-data-delivery-spark-model/src/test/resources/specifications/sparkSchema.feature b/test/test-mda-models/test-data-delivery-spark-model/src/test/resources/specifications/sparkSchema.feature index 6127eac2b..b7ea6d6af 100644 --- a/test/test-mda-models/test-data-delivery-spark-model/src/test/resources/specifications/sparkSchema.feature +++ b/test/test-mda-models/test-data-delivery-spark-model/src/test/resources/specifications/sparkSchema.feature @@ -1,5 +1,5 @@ @SparkSchema -Feature: Records with relations are generated correctly and function as expected +Feature: Record spark schemas are generated correctly and function as expected Background: Given the record "City" exists with the following relations @@ -50,3 +50,19 @@ Feature: Records with relations are generated correctly and function as expected And a valid "City" dataSet exists When spark schema validation is performed on the "City" dataSet Then the dataSet validation "passes" + + Scenario Outline: Records with fields with validation rules can be validated using the spark schema + Given a record with a "" field with validation rules + And the field is set to a "" value + And a dataSet containing the record + And the dataset contains one valid record + When the generated spark schema validation is performed on the dataSet + Then the resulting dataSet contains row(s) + Examples: + | requirement | validity | num | + | required | valid | 2 | + | required | invalid | 1 | + | required | null | 1 | + | non-required | valid | 2 | + | non-required | invalid | 1 | + | non-required | null | 2 |
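
For readers following the new filtering logic in the generated `validate_dataset()`/`validateDataFrame()` bodies, the self-contained PySpark sketch below mirrors the same per-field grouping idea; the column names and the minimum-value rule of 100 are illustrative only, not taken from an actual dictionary type:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[*]").getOrCreate()

# One non-required field that carries a minimum-value validation rule (assumed min: 100)
df = spark.createDataFrame([(150,), (50,), (None,)], "integer_validation INT")

checked = (
    df.withColumn(
        "INTEGER_VALIDATION_GREATER_THAN_MIN",
        col("integer_validation").cast("double") >= 100,
    ).withColumn("INTEGER_VALIDATION_IS_NULL", col("integer_validation").isNull())
)

# Per-field filter: every validation check for the field must pass, OR the
# non-required field is null (the _IS_NULL escape), so None no longer fails validation
field_filter = col("INTEGER_VALIDATION_IS_NULL").eqNullSafe(True) | col(
    "INTEGER_VALIDATION_GREATER_THAN_MIN"
).eqNullSafe(True)

valid = checked.filter(field_filter).drop(
    "INTEGER_VALIDATION_GREATER_THAN_MIN", "INTEGER_VALIDATION_IS_NULL"
)
valid.show()  # keeps the 150 row and the None row; drops the 50 row
```

In the generated code, the validation columns are grouped per field by prefix-matching their names, required fields use an `_IS_NOT_NULL` check instead of the `_IS_NULL` escape, and the per-field filters are then AND-ed together into a single row-level filter.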