From 581cfea1635fc03c73ceeddf8b6080965cbbb7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Gomez?= Date: Tue, 19 Sep 2023 16:10:44 +0200 Subject: [PATCH] Properly handle required/nullable types and fields --- internal/ast/builder.go | 6 +- internal/ast/compiler/disjunctions.go | 82 +++++++----- internal/ast/compiler/disjunctions_test.go | 54 ++++---- .../ast/compiler/not_required_as_nullable.go | 87 ++++++++++++ .../compiler/not_required_as_nullable_test.go | 86 ++++++++++++ internal/ast/types.go | 124 ++++++++++++++---- internal/jennies/golang/builder.go | 8 +- internal/jennies/golang/jennies.go | 1 + internal/jennies/golang/rawtypes.go | 29 ++-- internal/jennies/golang/tmpl.go | 2 + ...tion_of_scalars.types.json_marshal.go.tmpl | 2 +- internal/veneers/option/actions.go | 6 +- .../struct_with_optional_fields.txtar | 2 +- 13 files changed, 375 insertions(+), 114 deletions(-) create mode 100644 internal/ast/compiler/not_required_as_nullable.go create mode 100644 internal/ast/compiler/not_required_as_nullable_test.go diff --git a/internal/ast/builder.go b/internal/ast/builder.go index 05b5083c9..fe8d0601d 100644 --- a/internal/ast/builder.go +++ b/internal/ast/builder.go @@ -49,7 +49,7 @@ type Assignment struct { Constraints []TypeConstraint // Some more context on the what - IntoOptionalField bool + IntoNullableField bool } type BuilderGenerator struct { @@ -106,7 +106,7 @@ func (generator *BuilderGenerator) structFieldToStaticInitialization(field Struc Path: field.Name, Value: field.Type.AsScalar().Value, ValueType: field.Type, - IntoOptionalField: !field.Required, + IntoNullableField: field.Type.Nullable, } } @@ -131,7 +131,7 @@ func (generator *BuilderGenerator) structFieldToOption(field StructField) Option ArgumentName: field.Name, ValueType: field.Type, Constraints: constraints, - IntoOptionalField: !field.Required, + IntoNullableField: field.Type.Nullable, }, }, } diff --git a/internal/ast/compiler/disjunctions.go b/internal/ast/compiler/disjunctions.go index 79f9e956f..338cdebc8 100644 --- a/internal/ast/compiler/disjunctions.go +++ b/internal/ast/compiler/disjunctions.go @@ -77,22 +77,22 @@ func (pass *DisjunctionToType) processObject(file *ast.File, object ast.Object) func (pass *DisjunctionToType) processType(file *ast.File, def ast.Type) (ast.Type, error) { if def.Kind == ast.KindArray { - return pass.processArray(file, def.AsArray()) + return pass.processArray(file, def) } if def.Kind == ast.KindStruct { - return pass.processStruct(file, def.AsStruct()) + return pass.processStruct(file, def) } if def.Kind == ast.KindDisjunction { - return pass.processDisjunction(file, def.AsDisjunction()) + return pass.processDisjunction(file, def) } return def, nil } -func (pass *DisjunctionToType) processArray(file *ast.File, def ast.ArrayType) (ast.Type, error) { - processedType, err := pass.processType(file, def.ValueType) +func (pass *DisjunctionToType) processArray(file *ast.File, def ast.Type) (ast.Type, error) { + processedType, err := pass.processType(file, def.AsArray().ValueType) if err != nil { return ast.Type{}, err } @@ -100,32 +100,33 @@ func (pass *DisjunctionToType) processArray(file *ast.File, def ast.ArrayType) ( return ast.NewArray(processedType), nil } -func (pass *DisjunctionToType) processStruct(file *ast.File, def ast.StructType) (ast.Type, error) { - processedFields := make([]ast.StructField, 0, len(def.Fields)) - for _, field := range def.Fields { +func (pass *DisjunctionToType) processStruct(file *ast.File, def ast.Type) (ast.Type, error) { + processedFields := make([]ast.StructField, 0, len(def.AsStruct().Fields)) + for _, field := range def.AsStruct().Fields { processedType, err := pass.processType(file, field.Type) if err != nil { return ast.Type{}, err } - processedFields = append(processedFields, ast.StructField{ - Name: field.Name, - Comments: field.Comments, - Type: processedType, - Required: field.Required, - Default: field.Default, - }) + newField := field + newField.Type = processedType + + processedFields = append(processedFields, newField) } - return ast.NewStruct(processedFields...), nil + newStruct := def + newStruct.Struct.Fields = processedFields + + return newStruct, nil } -func (pass *DisjunctionToType) processDisjunction(file *ast.File, def ast.DisjunctionType) (ast.Type, error) { +func (pass *DisjunctionToType) processDisjunction(file *ast.File, def ast.Type) (ast.Type, error) { + disjunction := def.AsDisjunction() + // Ex: type | null - if len(def.Branches) == 2 && def.Branches.HasNullType() { - finalType := def.Branches.NonNullTypes()[0] - // FIXME: this should be propagated - // finalType.Nullable = true + if len(disjunction.Branches) == 2 && disjunction.Branches.HasNullType() { + finalType := disjunction.Branches.NonNullTypes()[0] + finalType.Nullable = true return finalType, nil } @@ -133,12 +134,17 @@ func (pass *DisjunctionToType) processDisjunction(file *ast.File, def ast.Disjun // type | otherType | something (| null)? // generate a type with a nullable field for every branch of the disjunction, // add it to preprocessor.types, and use it instead. - newTypeName := pass.disjunctionTypeName(def) + newTypeName := pass.disjunctionTypeName(disjunction) // if we already generated a new object for this disjunction, let's return // a reference to it. if _, ok := pass.newObjects[newTypeName]; ok { - return ast.NewRef(newTypeName), nil + ref := ast.NewRef(newTypeName) + if disjunction.Branches.HasNullType() { + ref.Nullable = true + } + + return ref, nil } /* @@ -148,27 +154,30 @@ func (pass *DisjunctionToType) processDisjunction(file *ast.File, def ast.Disjun } */ - fields := make([]ast.StructField, 0, len(def.Branches)) - for _, branch := range def.Branches { - // FIXME: should ignore this completely. - // ie: if there was a nullable branch, the whole resulting type should be nullable. + fields := make([]ast.StructField, 0, len(disjunction.Branches)) + for _, branch := range disjunction.Branches { + // Handled below, by allowing the reference to the disjunction struct + // to be null. if branch.IsNull() { continue } + processedBranch := branch + processedBranch.Nullable = true + fields = append(fields, ast.StructField{ - Name: "Val" + tools.UpperCamelCase(pass.typeName(branch)), - Type: branch, + Name: "Val" + tools.UpperCamelCase(pass.typeName(processedBranch)), + Type: processedBranch, Required: false, }) } structType := ast.NewStruct(fields...) - if def.Branches.HasOnlyScalarOrArray() { - structType.Struct.Hint[ast.HintDisjunctionOfScalars] = def + if disjunction.Branches.HasOnlyScalarOrArray() { + structType.Struct.Hint[ast.HintDisjunctionOfScalars] = disjunction } - if def.Branches.HasOnlyRefs() { - newDisjunctionDef, err := pass.ensureDiscriminator(file, def) + if disjunction.Branches.HasOnlyRefs() { + newDisjunctionDef, err := pass.ensureDiscriminator(file, disjunction) if err != nil { return ast.Type{}, err } @@ -181,7 +190,12 @@ func (pass *DisjunctionToType) processDisjunction(file *ast.File, def ast.Disjun Type: structType, } - return ast.NewRef(newTypeName), nil + ref := ast.NewRef(newTypeName) + if disjunction.Branches.HasNullType() { + ref.Nullable = true + } + + return ref, nil } func (pass *DisjunctionToType) disjunctionTypeName(def ast.DisjunctionType) string { diff --git a/internal/ast/compiler/disjunctions_test.go b/internal/ast/compiler/disjunctions_test.go index fb08addf3..5e1be9bda 100644 --- a/internal/ast/compiler/disjunctions_test.go +++ b/internal/ast/compiler/disjunctions_test.go @@ -47,8 +47,8 @@ func TestDisjunctionToType_WithDisjunctionOfScalars_AsAnObject(t *testing.T) { // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValString", ast.NewScalar(ast.KindString)), - ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool)), + ast.NewStructField("ValString", ast.NewScalar(ast.KindString, ast.Nullable())), + ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool, ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionStructType.Struct.Hint[ast.HintDisjunctionOfScalars] = objects[0].Type.AsDisjunction() @@ -76,8 +76,8 @@ func TestDisjunctionToType_WithDisjunctionOfScalars_AsAStructField(t *testing.T) // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValString", ast.NewScalar(ast.KindString)), - ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool)), + ast.NewStructField("ValString", ast.NewScalar(ast.KindString, ast.Nullable())), + ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool, ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionStructType.Struct.Hint[ast.HintDisjunctionOfScalars] = disjunctionType.AsDisjunction() @@ -105,8 +105,8 @@ func TestDisjunctionToType_WithDisjunctionOfScalars_AsAnArrayValueType(t *testin // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValString", ast.NewScalar(ast.KindString)), - ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool)), + ast.NewStructField("ValString", ast.NewScalar(ast.KindString, ast.Nullable())), + ast.NewStructField("ValBool", ast.NewScalar(ast.KindBool, ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionStructType.Struct.Hint[ast.HintDisjunctionOfScalars] = disjunctionType.AsDisjunction() @@ -131,11 +131,11 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_NoDiscriminatorMetad })), ast.NewObject("SomeStruct", ast.NewStruct( - ast.NewStructField("Kind", ast.NewConcreteScalar(ast.KindString, "other-struct")), // No equivalent in OtherStruct + ast.NewStructField("Kind", ast.NewScalar(ast.KindString, ast.Value("some-struct"))), // No equivalent in OtherStruct ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), ast.NewStructField("FieldBar", ast.NewScalar(ast.KindBool)), )), } @@ -199,7 +199,7 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_NoDiscriminatorMetad ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), ast.NewStructField("FieldBar", ast.NewScalar(ast.KindBool)), )), } @@ -227,11 +227,11 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_NoDiscriminatorMetad ast.NewObject("ADisjunctionOfRefs", disjunctionType), ast.NewObject("SomeStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "some-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("some-struct"))), ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), ast.NewStructField("FieldBar", ast.NewScalar(ast.KindBool)), )), } @@ -254,19 +254,19 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_NoDiscriminatorMetad })), ast.NewObject("SomeStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "some-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("some-struct"))), ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( ast.NewStructField("FieldBar", ast.NewMap(ast.NewScalar(ast.KindString), ast.NewScalar(ast.KindString))), - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), )), } // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct")), - ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct")), + ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct", ast.Nullable())), + ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct", ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionTypeWithDiscriminatorMeta := objects[0].Type.AsDisjunction() @@ -305,21 +305,21 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_WithDiscriminatorFie ast.NewObject("ADisjunctionOfRefs", disjunctionType), ast.NewObject("SomeStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "some-struct")), - ast.NewStructField("Kind", ast.NewConcreteScalar(ast.KindString, "some-kind")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("some-struct"))), + ast.NewStructField("Kind", ast.NewScalar(ast.KindString, ast.Value("some-kind"))), ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), - ast.NewStructField("Kind", ast.NewConcreteScalar(ast.KindString, "other-kind")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), + ast.NewStructField("Kind", ast.NewScalar(ast.KindString, ast.Value("other-kind"))), ast.NewStructField("FieldBar", ast.NewScalar(ast.KindBool)), )), } // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct")), - ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct")), + ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct", ast.Nullable())), + ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct", ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionTypeWithDiscriminatorMeta := objects[0].Type.AsDisjunction() @@ -360,21 +360,21 @@ func TestDisjunctionToType_WithDisjunctionOfRefs_AsAnObject_WithDiscriminatorFie ast.NewObject("ADisjunctionOfRefs", disjunctionType), ast.NewObject("SomeStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "some-struct")), - ast.NewStructField("Kind", ast.NewConcreteScalar(ast.KindString, "some-kind")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("some-struct"))), + ast.NewStructField("Kind", ast.NewScalar(ast.KindString, ast.Value("some-kind"))), ast.NewStructField("FieldFoo", ast.NewScalar(ast.KindString)), )), ast.NewObject("OtherStruct", ast.NewStruct( - ast.NewStructField("Type", ast.NewConcreteScalar(ast.KindString, "other-struct")), - ast.NewStructField("Kind", ast.NewConcreteScalar(ast.KindString, "other-kind")), + ast.NewStructField("Type", ast.NewScalar(ast.KindString, ast.Value("other-struct"))), + ast.NewStructField("Kind", ast.NewScalar(ast.KindString, ast.Value("other-kind"))), ast.NewStructField("FieldBar", ast.NewScalar(ast.KindBool)), )), } // Prepare expected output disjunctionStructType := ast.NewStruct( - ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct")), - ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct")), + ast.NewStructField("ValSomeStruct", ast.NewRef("SomeStruct", ast.Nullable())), + ast.NewStructField("ValOtherStruct", ast.NewRef("OtherStruct", ast.Nullable())), ) // The original disjunction definition is preserved as a hint disjunctionTypeWithDiscriminatorMeta := objects[0].Type.AsDisjunction() diff --git a/internal/ast/compiler/not_required_as_nullable.go b/internal/ast/compiler/not_required_as_nullable.go new file mode 100644 index 000000000..da5334823 --- /dev/null +++ b/internal/ast/compiler/not_required_as_nullable.go @@ -0,0 +1,87 @@ +package compiler + +import ( + "github.com/grafana/cog/internal/ast" +) + +var _ Pass = (*NotRequiredFieldAsNullableType)(nil) + +type NotRequiredFieldAsNullableType struct { +} + +func (pass *NotRequiredFieldAsNullableType) Process(files []*ast.File) ([]*ast.File, error) { + newFiles := make([]*ast.File, 0, len(files)) + + for _, file := range files { + newFiles = append(newFiles, pass.processFile(file)) + } + + return newFiles, nil +} + +func (pass *NotRequiredFieldAsNullableType) processFile(file *ast.File) *ast.File { + processedObjects := make([]ast.Object, 0, len(file.Definitions)) + for _, object := range file.Definitions { + processedObjects = append(processedObjects, pass.processObject(object)) + } + + return &ast.File{ + Package: file.Package, + Definitions: processedObjects, + } +} + +func (pass *NotRequiredFieldAsNullableType) processObject(object ast.Object) ast.Object { + if object.Type.Kind != ast.KindStruct { + return object + } + + newObject := object + newObject.Type = pass.processType(object.Type) + + return newObject +} + +func (pass *NotRequiredFieldAsNullableType) processType(def ast.Type) ast.Type { + if def.Kind == ast.KindArray { + return pass.processArray(def.AsArray()) + } + + if def.Kind == ast.KindMap { + return pass.processMap(def.AsMap()) + } + + if def.Kind == ast.KindStruct { + return pass.processStruct(def.AsStruct()) + } + + return def +} + +func (pass *NotRequiredFieldAsNullableType) processArray(def ast.ArrayType) ast.Type { + return ast.NewArray(pass.processType(def.ValueType)) +} + +func (pass *NotRequiredFieldAsNullableType) processMap(def ast.MapType) ast.Type { + return ast.NewMap( + pass.processType(def.IndexType), + pass.processType(def.ValueType), + ) +} + +func (pass *NotRequiredFieldAsNullableType) processStruct(def ast.StructType) ast.Type { + processedFields := make([]ast.StructField, 0, len(def.Fields)) + for _, field := range def.Fields { + fieldType := pass.processType(field.Type) + if !field.Required { + fieldType.Nullable = true + } + + newField := field + newField.Type = fieldType + + processedFields = append(processedFields, newField) + } + + return ast.NewStruct(processedFields...) +} diff --git a/internal/ast/compiler/not_required_as_nullable_test.go b/internal/ast/compiler/not_required_as_nullable_test.go new file mode 100644 index 000000000..f4c1a93dd --- /dev/null +++ b/internal/ast/compiler/not_required_as_nullable_test.go @@ -0,0 +1,86 @@ +package compiler + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/grafana/cog/internal/ast" + "github.com/stretchr/testify/require" +) + +func TestNotRequiredFieldAsNullableType(t *testing.T) { + // Prepare test input + objects := []ast.Object{ + ast.NewObject("NotAStruct", ast.String()), + + ast.NewObject("AStruct", ast.NewStruct( + ast.NewStructField("RequiredString", ast.String(), ast.Required()), + ast.NewStructField("RequiredNullableString", ast.String(ast.Nullable()), ast.Required()), + ast.NewStructField("NotRequiredString", ast.String()), + + ast.NewStructField("RequiredRef", ast.NewRef("SomeStruct"), ast.Required()), + ast.NewStructField("RequiredNullableRef", ast.NewRef("SomeStruct", ast.Nullable()), ast.Required()), + ast.NewStructField("NotRequiredRef", ast.NewRef("SomeStruct")), + + ast.NewStructField("NotRequiredArray", ast.NewArray(ast.String())), + ast.NewStructField("RequiredArray", ast.NewArray(ast.String()), ast.Required()), + + ast.NewStructField("NotRequiredMap", ast.NewMap( + ast.String(), + ast.Bool(), + )), + ast.NewStructField("RequiredMap", ast.NewMap( + ast.String(), + ast.Bool(), + ), ast.Required()), + )), + } + + // Prepare expected output + expected := []ast.Object{ + ast.NewObject("NotAStruct", ast.String()), + + ast.NewObject("AStruct", ast.NewStruct( + ast.NewStructField("RequiredString", ast.String(), ast.Required()), + ast.NewStructField("RequiredNullableString", ast.String(ast.Nullable()), ast.Required()), + ast.NewStructField("NotRequiredString", ast.String(ast.Nullable())), // should become nullable + + ast.NewStructField("RequiredRef", ast.NewRef("SomeStruct"), ast.Required()), + ast.NewStructField("RequiredNullableRef", ast.NewRef("SomeStruct", ast.Nullable()), ast.Required()), + ast.NewStructField("NotRequiredRef", ast.NewRef("SomeStruct", ast.Nullable())), // should become nullable + + ast.NewStructField("NotRequiredArray", ast.NewArray(ast.String(), ast.Nullable())), // should become nullable + ast.NewStructField("RequiredArray", ast.NewArray(ast.String()), ast.Required()), + + ast.NewStructField("NotRequiredMap", ast.NewMap( + ast.String(), + ast.Bool(), + ast.Nullable(), // should become nullable + )), + ast.NewStructField("RequiredMap", ast.NewMap( + ast.String(), + ast.Bool(), + ), ast.Required()), + )), + } + + // Run the compiler pass + runNotRequiredAsNullablePass(t, objects, expected) +} + +func runNotRequiredAsNullablePass(t *testing.T, input []ast.Object, expectedOutput []ast.Object) { + t.Helper() + + req := require.New(t) + + compilerPass := &NotRequiredFieldAsNullableType{} + processedFiles, err := compilerPass.Process([]*ast.File{ + { + Package: "test", + Definitions: input, + }, + }) + req.NoError(err) + req.Len(processedFiles, 1) + req.Empty(cmp.Diff(expectedOutput, processedFiles[0].Definitions)) +} diff --git a/internal/ast/types.go b/internal/ast/types.go index 541ceb9aa..4d40405df 100644 --- a/internal/ast/types.go +++ b/internal/ast/types.go @@ -48,7 +48,8 @@ type TypeConstraint struct { // Bonus: in a way that can be (un)marshaled to/from JSON, // which is useful for unit tests. type Type struct { - Kind Kind + Kind Kind + Nullable bool Disjunction *DisjunctionType `json:",omitempty"` Array *ArrayType `json:",omitempty"` @@ -59,6 +60,24 @@ type Type struct { Scalar *ScalarType `json:",omitempty"` } +type TypeOption func(def *Type) + +func Nullable() TypeOption { + return func(def *Type) { + def.Nullable = true + } +} + +func Value(value any) TypeOption { + return func(def *Type) { + if def.Kind != KindScalar { + return + } + + def.Scalar.Value = value + } +} + func Any() Type { return NewScalar(KindAny) } @@ -67,54 +86,78 @@ func Null() Type { return NewScalar(KindNull) } -func Bool() Type { - return NewScalar(KindBool) +func Bool(opts ...TypeOption) Type { + return NewScalar(KindBool, opts...) } -func Bytes() Type { - return NewScalar(KindBytes) +func Bytes(opts ...TypeOption) Type { + return NewScalar(KindBytes, opts...) } -func String() Type { - return NewScalar(KindString) +func String(opts ...TypeOption) Type { + return NewScalar(KindString, opts...) } -func NewDisjunction(branches Types) Type { - return Type{ +func NewDisjunction(branches Types, opts ...TypeOption) Type { + def := Type{ Kind: KindDisjunction, Disjunction: &DisjunctionType{ Branches: branches, DiscriminatorMapping: make(map[string]any), }, } + + for _, opt := range opts { + opt(&def) + } + + return def } -func NewArray(valueType Type) Type { - return Type{ +func NewArray(valueType Type, opts ...TypeOption) Type { + def := Type{ Kind: KindArray, Array: &ArrayType{ ValueType: valueType, }, } + + for _, opt := range opts { + opt(&def) + } + + return def } -func NewEnum(values []EnumValue) Type { - return Type{ +func NewEnum(values []EnumValue, opts ...TypeOption) Type { + def := Type{ Kind: KindEnum, Enum: &EnumType{ Values: values, }, } + + for _, opt := range opts { + opt(&def) + } + + return def } -func NewMap(indexType Type, valueType Type) Type { - return Type{ +func NewMap(indexType Type, valueType Type, opts ...TypeOption) Type { + def := Type{ Kind: KindMap, Map: &MapType{ IndexType: indexType, ValueType: valueType, }, } + + for _, opt := range opts { + opt(&def) + } + + return def } func NewStruct(fields ...StructField) Type { @@ -127,32 +170,41 @@ func NewStruct(fields ...StructField) Type { } } -func NewRef(referredTypeName string) Type { - return Type{ +func NewNullableStruct(fields ...StructField) Type { + def := NewStruct(fields...) + def.Nullable = true + + return def +} + +func NewRef(referredTypeName string, opts ...TypeOption) Type { + def := Type{ Kind: KindRef, Ref: &RefType{ ReferredType: referredTypeName, }, } -} -func NewScalar(kind ScalarKind) Type { - return Type{ - Kind: KindScalar, - Scalar: &ScalarType{ - ScalarKind: kind, - }, + for _, opt := range opts { + opt(&def) } + + return def } -func NewConcreteScalar(kind ScalarKind, value any) Type { - return Type{ +func NewScalar(kind ScalarKind, opts ...TypeOption) Type { + def := Type{ Kind: KindScalar, Scalar: &ScalarType{ ScalarKind: kind, - Value: value, }, } + + for _, opt := range opts { + opt(&def) + } + + return def } func (t Type) IsNull() bool { @@ -344,11 +396,25 @@ type StructField struct { Default any } -func NewStructField(name string, fieldType Type) StructField { - return StructField{ +type StructFieldOption func(field *StructField) + +func Required() StructFieldOption { + return func(field *StructField) { + field.Required = true + } +} + +func NewStructField(name string, fieldType Type, opts ...StructFieldOption) StructField { + field := StructField{ Name: name, Type: fieldType, } + + for _, opt := range opts { + opt(&field) + } + + return field } type RefType struct { diff --git a/internal/jennies/golang/builder.go b/internal/jennies/golang/builder.go index 55889caf1..4bcaba56f 100644 --- a/internal/jennies/golang/builder.go +++ b/internal/jennies/golang/builder.go @@ -164,7 +164,7 @@ func (jenny *Builder) generateInitAssignment(builders ast.Builders, builder ast. asPointer := "" // FIXME: this condition is probably wrong - if valueType.Kind != ast.KindArray && valueType.Kind != ast.KindStruct && assignment.IntoOptionalField { + if valueType.Kind != ast.KindArray && valueType.Kind != ast.KindStruct && assignment.IntoNullableField { asPointer = "&" } @@ -229,7 +229,7 @@ func (jenny *Builder) typeHasBuilder(builders ast.Builders, builder ast.Builder, } func (jenny *Builder) generateArgument(builders ast.Builders, builder ast.Builder, arg ast.Argument) string { - typeName := formatType(arg.Type, true, "types") + typeName := strings.Trim(formatType(arg.Type, "types"), "*") if builderPkg, found := jenny.typeHasBuilder(builders, builder, arg.Type); found { return fmt.Sprintf(`opts ...%[1]s.Option`, builderPkg) @@ -246,7 +246,7 @@ func (jenny *Builder) generateAssignment(builders ast.Builders, builder ast.Buil if builderPkg, found := jenny.typeHasBuilder(builders, builder, assignment.ValueType); found { intoPointer := "*" - if assignment.IntoOptionalField { + if assignment.IntoNullableField { intoPointer = "" } @@ -267,7 +267,7 @@ func (jenny *Builder) generateAssignment(builders ast.Builders, builder ast.Buil asPointer := "" // FIXME: this condition is probably wrong - if valueType.Kind != ast.KindArray && valueType.Kind != ast.KindStruct && assignment.IntoOptionalField { + if valueType.Kind != ast.KindArray && valueType.Kind != ast.KindMap && assignment.IntoNullableField { asPointer = "&" } diff --git a/internal/jennies/golang/jennies.go b/internal/jennies/golang/jennies.go index c09f30520..4382c2bb1 100644 --- a/internal/jennies/golang/jennies.go +++ b/internal/jennies/golang/jennies.go @@ -34,6 +34,7 @@ func Jennies() *codejen.JennyList[[]*ast.File] { func CompilerPasses() []compiler.Pass { return []compiler.Pass{ &compiler.AnonymousEnumToExplicitType{}, + &compiler.NotRequiredFieldAsNullableType{}, &compiler.DisjunctionToType{}, } } diff --git a/internal/jennies/golang/rawtypes.go b/internal/jennies/golang/rawtypes.go index 336bd2d01..e8f08dd17 100644 --- a/internal/jennies/golang/rawtypes.go +++ b/internal/jennies/golang/rawtypes.go @@ -74,10 +74,10 @@ func (jenny RawTypes) formatObject(def ast.Object) ([]byte, error) { } else if scalarType.ScalarKind == ast.KindBytes { buffer.WriteString(fmt.Sprintf("type %s %s", defName, "[]byte")) } else { - buffer.WriteString(fmt.Sprintf("type %s %s", defName, formatType(def.Type, true, ""))) + buffer.WriteString(fmt.Sprintf("type %s %s", defName, formatType(def.Type, ""))) } case ast.KindMap, ast.KindRef, ast.KindArray, ast.KindStruct: - buffer.WriteString(fmt.Sprintf("type %s %s", defName, formatType(def.Type, true, ""))) + buffer.WriteString(fmt.Sprintf("type %s %s", defName, formatType(def.Type, ""))) default: return nil, fmt.Errorf("unhandled type def kind: %s", def.Type.Kind) } @@ -93,7 +93,7 @@ func (jenny RawTypes) formatEnumDef(def ast.Object) string { enumName := tools.UpperCamelCase(def.Name) enumType := def.Type.AsEnum() - buffer.WriteString(fmt.Sprintf("type %s %s\n", enumName, formatType(enumType.Values[0].Type, true, ""))) + buffer.WriteString(fmt.Sprintf("type %s %s\n", enumName, formatType(enumType.Values[0].Type, ""))) buffer.WriteString("const (\n") for _, val := range enumType.Values { @@ -188,7 +188,7 @@ func formatField(def ast.StructField, typesPkg string) string { buffer.WriteString(fmt.Sprintf( "%s %s `json:\"%s%s\"`\n", tools.UpperCamelCase(def.Name), - formatType(def.Type, def.Required, typesPkg), + formatType(def.Type, typesPkg), def.Name, jsonOmitEmpty, )) @@ -196,7 +196,7 @@ func formatField(def ast.StructField, typesPkg string) string { return buffer.String() } -func formatType(def ast.Type, fieldIsRequired bool, typesPkg string) string { +func formatType(def ast.Type, typesPkg string) string { if def.IsAny() { return "any" } @@ -211,7 +211,7 @@ func formatType(def ast.Type, fieldIsRequired bool, typesPkg string) string { if def.Kind == ast.KindScalar { typeName := def.AsScalar().ScalarKind - if !fieldIsRequired { + if def.Nullable { typeName = "*" + typeName } @@ -225,16 +225,21 @@ func formatType(def ast.Type, fieldIsRequired bool, typesPkg string) string { typeName = typesPkg + "." + typeName } - if !fieldIsRequired { + if def.Nullable { typeName = "*" + typeName } return typeName } - // anonymous struct + // anonymous struct or struct body if def.Kind == ast.KindStruct { - return formatStructBody(def.AsStruct(), typesPkg) + output := formatStructBody(def.AsStruct(), typesPkg) + if def.Nullable { + output = "*" + output + } + + return output } // FIXME: we should never be here @@ -242,14 +247,14 @@ func formatType(def ast.Type, fieldIsRequired bool, typesPkg string) string { } func formatArray(def ast.ArrayType, typesPkg string) string { - subTypeString := formatType(def.ValueType, true, typesPkg) + subTypeString := formatType(def.ValueType, typesPkg) return fmt.Sprintf("[]%s", subTypeString) } func formatMap(def ast.MapType, typesPkg string) string { - keyTypeString := formatType(def.IndexType, true, typesPkg) - valueTypeString := formatType(def.ValueType, true, typesPkg) + keyTypeString := formatType(def.IndexType, typesPkg) + valueTypeString := formatType(def.ValueType, typesPkg) return fmt.Sprintf("map[%s]%s", keyTypeString, valueTypeString) } diff --git a/internal/jennies/golang/tmpl.go b/internal/jennies/golang/tmpl.go index fdd3463b6..b2a85f75a 100644 --- a/internal/jennies/golang/tmpl.go +++ b/internal/jennies/golang/tmpl.go @@ -3,6 +3,7 @@ package golang import ( "embed" "html/template" + "strings" "github.com/grafana/cog/internal/tools" ) @@ -20,6 +21,7 @@ func init() { base.Funcs(map[string]any{ "formatIdentifier": tools.UpperCamelCase, "formatType": formatType, + "trimPrefix": strings.TrimPrefix, }) templates = template.Must(base.ParseFS(veneersFS, "veneers/*.tmpl")) } diff --git a/internal/jennies/golang/veneers/disjunction_of_scalars.types.json_marshal.go.tmpl b/internal/jennies/golang/veneers/disjunction_of_scalars.types.json_marshal.go.tmpl index 9a99a67a1..91f797869 100644 --- a/internal/jennies/golang/veneers/disjunction_of_scalars.types.json_marshal.go.tmpl +++ b/internal/jennies/golang/veneers/disjunction_of_scalars.types.json_marshal.go.tmpl @@ -16,7 +16,7 @@ func (resource *{{ .def.Name|formatIdentifier }}) UnmarshalJSON(raw []byte) erro var errList []error {{ range .def.Type.Struct.Fields }} // {{ .Name|formatIdentifier }} - var {{ .Name }} {{ formatType .Type true "" }} + var {{ .Name }} {{ trimPrefix (formatType .Type "") "*" }} if err := json.Unmarshal(raw, &{{ .Name }}); err != nil { errList = append(errList, err) resource.{{ .Name|formatIdentifier }} = nil diff --git a/internal/veneers/option/actions.go b/internal/veneers/option/actions.go index 8442e3ca2..d4fefcb7a 100644 --- a/internal/veneers/option/actions.go +++ b/internal/veneers/option/actions.go @@ -91,7 +91,7 @@ func StructFieldsAsArgumentsAction(explicitFields ...string) RewriteAction { ArgumentName: field.Name, ValueType: field.Type, Constraints: constraints, - IntoOptionalField: !field.Required, + IntoNullableField: field.Type.Nullable, }) } @@ -120,7 +120,7 @@ func UnfoldBooleanAction(unfoldOpts BooleanUnfold) RewriteAction { { Path: option.Assignments[0].Path, ValueType: option.Assignments[0].ValueType, - IntoOptionalField: option.Assignments[0].IntoOptionalField, + IntoNullableField: option.Assignments[0].IntoNullableField, Value: true, }, }, @@ -135,7 +135,7 @@ func UnfoldBooleanAction(unfoldOpts BooleanUnfold) RewriteAction { { Path: option.Assignments[0].Path, ValueType: option.Assignments[0].ValueType, - IntoOptionalField: option.Assignments[0].IntoOptionalField, + IntoNullableField: option.Assignments[0].IntoNullableField, Value: false, }, }, diff --git a/testdata/jennies/rawtypes/struct_with_optional_fields.txtar b/testdata/jennies/rawtypes/struct_with_optional_fields.txtar index 30060f451..b1fc42beb 100644 --- a/testdata/jennies/rawtypes/struct_with_optional_fields.txtar +++ b/testdata/jennies/rawtypes/struct_with_optional_fields.txtar @@ -134,7 +134,7 @@ type SomeStruct struct { FieldString *string `json:"FieldString,omitempty"` FieldAnonymousEnum *FieldAnonymousEnumEnum `json:"FieldAnonymousEnum,omitempty"` FieldArrayOfStrings []string `json:"FieldArrayOfStrings,omitempty"` - FieldAnonymousStruct struct { + FieldAnonymousStruct *struct { FieldAny any `json:"FieldAny"` } `json:"FieldAnonymousStruct,omitempty"` }