From d6468b6dc6735dc47ed80da995bef30e29b7786d Mon Sep 17 00:00:00 2001 From: mao3267 Date: Fri, 8 Nov 2024 21:40:47 +0800 Subject: [PATCH] fix: handle nested pydantic basemodel Signed-off-by: mao3267 --- .../pkg/compiler/validators/typing.go | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index 44502ce504..43fbb61760 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -18,17 +18,36 @@ type trivialChecker struct { literalType *flyte.LiteralType } -func removeTitleFieldFromProperties(schema map[string]interface{}) { - properties, ok := schema["properties"].(*structpb.Value) +func removeTitleFieldFromProperties(schema map[string]*structpb.Value) { + properties, ok := schema["properties"] if !ok { return } for _, p := range properties.GetStructValue().Fields { + if _, ok := p.GetStructValue().Fields["properties"]; ok { + removeTitleFieldFromProperties(p.GetStructValue().Fields) + } delete(p.GetStructValue().Fields, "title") } } +func resolveRef(schema, defs map[string]*structpb.Value) { + // Schema from Pydantic BaseModel includes a $def field, which is a reference to the actual schema. + // We need to resolve the reference to compare the schema with those from marshumaro. + // https://github.com/flyteorg/flytekit/blob/3475ddc41f2ba31d23dd072362be704d7c2470a0/flytekit/core/type_engine.py#L632-L641 + for _, p := range schema["properties"].GetStructValue().Fields { + if _, ok := p.GetStructValue().Fields["$ref"]; ok { + propName := strings.TrimPrefix(p.GetStructValue().Fields["$ref"].GetStringValue(), "#/$defs/") + p.GetStructValue().Fields = defs[propName].GetStructValue().Fields + resolveRef(p.GetStructValue().Fields, defs) + delete(p.GetStructValue().Fields, "$ref") + } + } + + delete(schema, "$defs") +} + func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool { // Since there are lots of field differences between draft-07 and draft 2020-12, // we only support json schema with 2020-12 draft, which is generated here: https://github.com/flyteorg/flytekit/blob/ff2d0da686c82266db4dbf764a009896cf062349/flytekit/core/type_engine.py#L630-L639 @@ -40,8 +59,8 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool { return false } - copySrcSchema := make(map[string]interface{}) - copyTgtSchema := make(map[string]interface{}) + copySrcSchema := make(map[string]*structpb.Value) + copyTgtSchema := make(map[string]*structpb.Value) for k, v := range sourceMetaData.Fields { copySrcSchema[k] = v @@ -51,6 +70,13 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool { copyTgtSchema[k] = v } + // For nested Pydantic BaseModel, we need to resolve the reference to compare the schema. + if _, ok := copySrcSchema["$defs"]; ok { + resolveRef(copySrcSchema, copySrcSchema["$defs"].GetStructValue().Fields) + } + if _, ok := copyTgtSchema["$defs"]; ok { + resolveRef(copyTgtSchema, copyTgtSchema["$defs"].GetStructValue().Fields) + } // The JSON schema generated by Pydantic.BaseModel includes a title field in its properties, repeatedly recording the property name. // Since this title field is absent in the JSON schema generated for dataclass, we need to remove the title field from the properties to ensure equivalence. removeTitleFieldFromProperties(copySrcSchema) @@ -63,7 +89,7 @@ func isSuperTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool { for _, p := range patch { // If additionalProperties is false, the field is not present in the schema from Pydantic.BaseModel. // We handle this case by checking the relationships by ourselves. - if p.Type != jsondiff.OperationAdd && p.Path == "/additionalProperties" { + if p.Type != jsondiff.OperationAdd && strings.Contains(p.Path, "additionalProperties") { if p.Type == jsondiff.OperationRemove || p.Type == jsondiff.OperationReplace { if p.OldValue != false { return false @@ -89,8 +115,8 @@ func isSameTypeInJSON(sourceMetaData, targetMetaData *structpb.Struct) bool { return false } - copySrcSchema := make(map[string]interface{}) - copyTgtSchema := make(map[string]interface{}) + copySrcSchema := make(map[string]*structpb.Value) + copyTgtSchema := make(map[string]*structpb.Value) for k, v := range sourceMetaData.Fields { copySrcSchema[k] = v