From 11bbfc12efdc8039adba227c3abf597fa4b15b52 Mon Sep 17 00:00:00 2001 From: spinillos Date: Wed, 16 Jul 2025 13:11:14 +0200 Subject: [PATCH 1/2] Add pass to add a value to enum field --- internal/ast/compiler/add_enum_value.go | 103 +++++++++++++ internal/ast/compiler/add_enum_value_test.go | 144 +++++++++++++++++++ internal/yaml/compilerpasses.go | 39 +++++ 3 files changed, 286 insertions(+) create mode 100644 internal/ast/compiler/add_enum_value.go create mode 100644 internal/ast/compiler/add_enum_value_test.go diff --git a/internal/ast/compiler/add_enum_value.go b/internal/ast/compiler/add_enum_value.go new file mode 100644 index 000000000..509589f41 --- /dev/null +++ b/internal/ast/compiler/add_enum_value.go @@ -0,0 +1,103 @@ +package compiler + +import ( + "errors" + + "github.com/grafana/cog/internal/ast" +) + +type AddEnumValue struct { + ObjectRef ObjectReference + FieldRef FieldReference + Name string + Value any +} + +func (pass *AddEnumValue) Process(schemas []*ast.Schema) ([]*ast.Schema, error) { + if pass.Name == "" || pass.Value == nil { + return nil, errors.New("name and value are required") + } + + visitor := &Visitor{ + OnObject: pass.onObject, + OnEnum: pass.onEnum, + } + return visitor.VisitSchemas(schemas) +} + +func (pass *AddEnumValue) onObject(visitor *Visitor, schema *ast.Schema, object ast.Object) (ast.Object, error) { + if object.Type.IsEnum() && pass.ObjectRef.Matches(object) { + enum, err := visitor.OnEnum(visitor, schema, object.Type) + if err != nil { + return ast.Object{}, err + } + + object.Type = enum + return object, nil + } + + if !object.Type.IsStruct() { + return object, nil + } + + for i, field := range object.Type.AsStruct().Fields { + if pass.FieldRef.Matches(object, field) { + if field.Type.IsEnum() { + updatedType, err := visitor.OnEnum(visitor, schema, field.Type) + if err != nil { + return ast.Object{}, err + } + object.Type.AsStruct().Fields[i].Type = updatedType + return object, nil + } + + if field.Type.IsRef() { + if enum, ok := pass.updateEnumObject(visitor, schema, field.Type); ok { + schema.Objects.Set(object.Name, enum) + } + return object, nil + } + } + } + + return object, nil +} + +func (pass *AddEnumValue) onEnum(_ *Visitor, _ *ast.Schema, def ast.Type) (ast.Type, error) { + enumString := def.AsEnum().Values[0].Type.AsScalar().ScalarKind == ast.KindString + _, isString := pass.Value.(string) + + if enumString && !isString { + return ast.Type{}, errors.New("enum value must be of type string") + } + if !enumString && isString { + return ast.Type{}, errors.New("enum value must be of type integer") + } + + def.Enum.Values = append(def.Enum.Values, ast.EnumValue{ + Name: pass.Name, + Type: def.AsEnum().Values[0].Type, + Value: pass.Value, + }) + + return def, nil +} + +func (pass *AddEnumValue) updateEnumObject(visitor *Visitor, schema *ast.Schema, def ast.Type) (ast.Object, bool) { + obj, ok := schema.LocateObject(def.AsRef().ReferredType) + if !ok { + return ast.Object{}, false + } + + if !obj.Type.IsEnum() { + return ast.Object{}, false + } + + enum, err := visitor.OnEnum(visitor, schema, obj.Type) + if err != nil { + return ast.Object{}, false + } + + obj.Type = enum + return obj, true +} diff --git a/internal/ast/compiler/add_enum_value_test.go b/internal/ast/compiler/add_enum_value_test.go new file mode 100644 index 000000000..97885d1b6 --- /dev/null +++ b/internal/ast/compiler/add_enum_value_test.go @@ -0,0 +1,144 @@ +package compiler + +import ( + "testing" + + "github.com/grafana/cog/internal/ast" + "github.com/grafana/cog/internal/testutils" + "github.com/stretchr/testify/require" +) + +func TestAddEnumFieldValueReference(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewRef("test", "MyEnum")), + )), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + {Value: "C", Name: "C", Type: ast.String()}, + })), + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewRef("test", "MyEnum")), + )), + ), + } + + pass := &AddEnumValue{ + FieldRef: FieldReference{ + Package: "test", + Object: "MyStruct", + Field: "enum", + }, + Name: "C", + Value: "C", + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumFieldValueDirectEnum(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{ + {Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)}, + })), + )), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{ + {Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 3, Name: "C", Type: ast.NewScalar(ast.KindInt64)}, + })), + )), + ), + } + + pass := &AddEnumValue{ + FieldRef: FieldReference{ + Package: "test", + Object: "MyStruct", + Field: "enum", + }, + Name: "C", + Value: 3, + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumValueEnum(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + {Value: "C", Name: "C", Type: ast.String()}, + })), + ), + } + + pass := &AddEnumValue{ + ObjectRef: ObjectReference{ + Package: "test", + Object: "MyEnum", + }, + Name: "C", + Value: "C", + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumValueInvalidValueKind(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ), + } + + pass := &AddEnumValue{ + ObjectRef: ObjectReference{ + Package: "test", + Object: "MyEnum", + }, + Name: "C", + Value: 1, + } + + _, err := pass.Process(ast.Schemas{schema}) + require.Error(t, err) +} diff --git a/internal/yaml/compilerpasses.go b/internal/yaml/compilerpasses.go index b14aa6c1d..5f336d079 100644 --- a/internal/yaml/compilerpasses.go +++ b/internal/yaml/compilerpasses.go @@ -29,6 +29,7 @@ type CompilerPass struct { DuplicateObject *DuplicateObject `yaml:"duplicate_object"` TrimEnumValues *TrimEnumValues `yaml:"trim_enum_values"` ConstantToEnum *ConstantToEnum `yaml:"constant_to_enum"` + AddEnumValue *AddEnumValue `yaml:"add_enum_value"` AnonymousStructsToNamed *AnonymousStructsToNamed `yaml:"anonymous_structs_to_named"` @@ -119,6 +120,9 @@ func (pass CompilerPass) AsCompilerPass() (compiler.Pass, error) { if pass.DisjunctionWithConstantToDefault != nil { return pass.DisjunctionWithConstantToDefault.AsCompilerPass() } + if pass.AddEnumValue != nil { + return pass.AddEnumValue.AsCompilerPass() + } return nil, fmt.Errorf("empty compiler pass") } @@ -494,3 +498,38 @@ func (pass ConstantToEnum) AsCompilerPass() (*compiler.ConstantToEnum, error) { return &compiler.ConstantToEnum{Objects: objectRefs}, nil } + +type AddEnumValue struct { + InObject string // Expected format: [package].[object] + InField string // Expected format: [package].[object].[field] + Name string + Value any +} + +func (pass AddEnumValue) AsCompilerPass() (*compiler.AddEnumValue, error) { + if pass.InObject != "" { + objectRef, err := compiler.ObjectReferenceFromString(pass.InObject) + if err != nil { + return nil, err + } + + return &compiler.AddEnumValue{ + ObjectRef: objectRef, + FieldRef: compiler.FieldReference{}, + Name: pass.Name, + Value: pass.Value, + }, nil + } + + fieldRef, err := compiler.FieldReferenceFromString(pass.InField) + if err != nil { + return nil, err + } + + return &compiler.AddEnumValue{ + ObjectRef: compiler.ObjectReference{}, + FieldRef: fieldRef, + Name: pass.Name, + Value: pass.Value, + }, nil +} From 5a3404982a92a177e7a49a7b99a245cafddbbb85 Mon Sep 17 00:00:00 2001 From: spinillos Date: Wed, 16 Jul 2025 13:58:27 +0200 Subject: [PATCH 2/2] Fixes --- internal/ast/compiler/add_enum_value.go | 2 ++ internal/ast/compiler/add_enum_value_test.go | 4 ++-- internal/yaml/compilerpasses.go | 14 +++++++------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/internal/ast/compiler/add_enum_value.go b/internal/ast/compiler/add_enum_value.go index 509589f41..a0cba81c7 100644 --- a/internal/ast/compiler/add_enum_value.go +++ b/internal/ast/compiler/add_enum_value.go @@ -33,6 +33,7 @@ func (pass *AddEnumValue) onObject(visitor *Visitor, schema *ast.Schema, object } object.Type = enum + object.AddToPassesTrail("AddEnumValue") return object, nil } @@ -48,6 +49,7 @@ func (pass *AddEnumValue) onObject(visitor *Visitor, schema *ast.Schema, object return ast.Object{}, err } object.Type.AsStruct().Fields[i].Type = updatedType + object.Type.AsStruct().Fields[i].AddToPassesTrail("AddEnumValue") return object, nil } diff --git a/internal/ast/compiler/add_enum_value_test.go b/internal/ast/compiler/add_enum_value_test.go index 97885d1b6..0a69a166b 100644 --- a/internal/ast/compiler/add_enum_value_test.go +++ b/internal/ast/compiler/add_enum_value_test.go @@ -70,7 +70,7 @@ func TestAddEnumFieldValueDirectEnum(t *testing.T) { {Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)}, {Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)}, {Value: 3, Name: "C", Type: ast.NewScalar(ast.KindInt64)}, - })), + }), ast.PassesTrail("AddEnumValue")), )), ), } @@ -104,7 +104,7 @@ func TestAddEnumValueEnum(t *testing.T) { {Value: "A", Name: "A", Type: ast.String()}, {Value: "B", Name: "B", Type: ast.String()}, {Value: "C", Name: "C", Type: ast.String()}, - })), + }), "AddEnumValue"), ), } diff --git a/internal/yaml/compilerpasses.go b/internal/yaml/compilerpasses.go index 5f336d079..71fe78ac1 100644 --- a/internal/yaml/compilerpasses.go +++ b/internal/yaml/compilerpasses.go @@ -500,15 +500,15 @@ func (pass ConstantToEnum) AsCompilerPass() (*compiler.ConstantToEnum, error) { } type AddEnumValue struct { - InObject string // Expected format: [package].[object] - InField string // Expected format: [package].[object].[field] - Name string - Value any + Enum string // Expected format: [package].[object] + Field string // Expected format: [package].[object].[field] + Name string + Value any } func (pass AddEnumValue) AsCompilerPass() (*compiler.AddEnumValue, error) { - if pass.InObject != "" { - objectRef, err := compiler.ObjectReferenceFromString(pass.InObject) + if pass.Enum != "" { + objectRef, err := compiler.ObjectReferenceFromString(pass.Enum) if err != nil { return nil, err } @@ -521,7 +521,7 @@ func (pass AddEnumValue) AsCompilerPass() (*compiler.AddEnumValue, error) { }, nil } - fieldRef, err := compiler.FieldReferenceFromString(pass.InField) + fieldRef, err := compiler.FieldReferenceFromString(pass.Field) if err != nil { return nil, err }