diff --git a/flytepropeller/pkg/compiler/validators/typing.go b/flytepropeller/pkg/compiler/validators/typing.go index 388cb32123..669bca0e5d 100644 --- a/flytepropeller/pkg/compiler/validators/typing.go +++ b/flytepropeller/pkg/compiler/validators/typing.go @@ -356,9 +356,27 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker { } } -func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool { +func isTypeAny(t *flyte.LiteralType) bool { + if t.GetMetadata() != nil { + if t.GetMetadata().GetFields() != nil { + pythonClassName := t.GetMetadata().GetFields()["python_class_name"] + if pythonClassName != nil { + if strVal, ok := pythonClassName.GetKind().(*structpb.Value_StringValue); ok && strVal.StringValue == "typing.Any" { + return true + } + } + } + } + return false +} + +func AreTypesCastable(upstreamType *flyte.LiteralType, downstreamType *flyte.LiteralType) bool { typeChecker := getTypeChecker(downstreamType) + if isTypeAny(upstreamType) || isTypeAny(downstreamType) { + return true + } + // if upstream is a singular union we check if the downstream type is castable from the union variant if upstreamType.GetUnionType() != nil && len(upstreamType.GetUnionType().GetVariants()) == 1 { variants := upstreamType.GetUnionType().GetVariants()