diff --git a/internal/ast/compiler/add_enum_value.go b/internal/ast/compiler/add_enum_value.go new file mode 100644 index 000000000..a0cba81c7 --- /dev/null +++ b/internal/ast/compiler/add_enum_value.go @@ -0,0 +1,105 @@ +package compiler + +import ( + "errors" + + "github.com/grafana/cog/internal/ast" +) + +type AddEnumValue struct { + ObjectRef ObjectReference + FieldRef FieldReference + Name string + Value any +} + +func (pass *AddEnumValue) Process(schemas []*ast.Schema) ([]*ast.Schema, error) { + if pass.Name == "" || pass.Value == nil { + return nil, errors.New("name and value are required") + } + + visitor := &Visitor{ + OnObject: pass.onObject, + OnEnum: pass.onEnum, + } + return visitor.VisitSchemas(schemas) +} + +func (pass *AddEnumValue) onObject(visitor *Visitor, schema *ast.Schema, object ast.Object) (ast.Object, error) { + if object.Type.IsEnum() && pass.ObjectRef.Matches(object) { + enum, err := visitor.OnEnum(visitor, schema, object.Type) + if err != nil { + return ast.Object{}, err + } + + object.Type = enum + object.AddToPassesTrail("AddEnumValue") + return object, nil + } + + if !object.Type.IsStruct() { + return object, nil + } + + for i, field := range object.Type.AsStruct().Fields { + if pass.FieldRef.Matches(object, field) { + if field.Type.IsEnum() { + updatedType, err := visitor.OnEnum(visitor, schema, field.Type) + if err != nil { + return ast.Object{}, err + } + object.Type.AsStruct().Fields[i].Type = updatedType + object.Type.AsStruct().Fields[i].AddToPassesTrail("AddEnumValue") + return object, nil + } + + if field.Type.IsRef() { + if enum, ok := pass.updateEnumObject(visitor, schema, field.Type); ok { + schema.Objects.Set(object.Name, enum) + } + return object, nil + } + } + } + + return object, nil +} + +func (pass *AddEnumValue) onEnum(_ *Visitor, _ *ast.Schema, def ast.Type) (ast.Type, error) { + enumString := def.AsEnum().Values[0].Type.AsScalar().ScalarKind == ast.KindString + _, isString := pass.Value.(string) + + if enumString && !isString { + return ast.Type{}, errors.New("enum value must be of type string") + } + if !enumString && isString { + return ast.Type{}, errors.New("enum value must be of type integer") + } + + def.Enum.Values = append(def.Enum.Values, ast.EnumValue{ + Name: pass.Name, + Type: def.AsEnum().Values[0].Type, + Value: pass.Value, + }) + + return def, nil +} + +func (pass *AddEnumValue) updateEnumObject(visitor *Visitor, schema *ast.Schema, def ast.Type) (ast.Object, bool) { + obj, ok := schema.LocateObject(def.AsRef().ReferredType) + if !ok { + return ast.Object{}, false + } + + if !obj.Type.IsEnum() { + return ast.Object{}, false + } + + enum, err := visitor.OnEnum(visitor, schema, obj.Type) + if err != nil { + return ast.Object{}, false + } + + obj.Type = enum + return obj, true +} diff --git a/internal/ast/compiler/add_enum_value_test.go b/internal/ast/compiler/add_enum_value_test.go new file mode 100644 index 000000000..0a69a166b --- /dev/null +++ b/internal/ast/compiler/add_enum_value_test.go @@ -0,0 +1,144 @@ +package compiler + +import ( + "testing" + + "github.com/grafana/cog/internal/ast" + "github.com/grafana/cog/internal/testutils" + "github.com/stretchr/testify/require" +) + +func TestAddEnumFieldValueReference(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewRef("test", "MyEnum")), + )), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + {Value: "C", Name: "C", Type: ast.String()}, + })), + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewRef("test", "MyEnum")), + )), + ), + } + + pass := &AddEnumValue{ + FieldRef: FieldReference{ + Package: "test", + Object: "MyStruct", + Field: "enum", + }, + Name: "C", + Value: "C", + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumFieldValueDirectEnum(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{ + {Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)}, + })), + )), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "MyStruct", ast.NewStruct( + ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{ + {Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)}, + {Value: 3, Name: "C", Type: ast.NewScalar(ast.KindInt64)}, + }), ast.PassesTrail("AddEnumValue")), + )), + ), + } + + pass := &AddEnumValue{ + FieldRef: FieldReference{ + Package: "test", + Object: "MyStruct", + Field: "enum", + }, + Name: "C", + Value: 3, + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumValueEnum(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ), + } + + expected := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + {Value: "C", Name: "C", Type: ast.String()}, + }), "AddEnumValue"), + ), + } + + pass := &AddEnumValue{ + ObjectRef: ObjectReference{ + Package: "test", + Object: "MyEnum", + }, + Name: "C", + Value: "C", + } + + runPassOnSchema(t, pass, schema, expected) +} + +func TestAddEnumValueInvalidValueKind(t *testing.T) { + schema := &ast.Schema{ + Package: "add_enum_value", + Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{ + {Value: "A", Name: "A", Type: ast.String()}, + {Value: "B", Name: "B", Type: ast.String()}, + })), + ), + } + + pass := &AddEnumValue{ + ObjectRef: ObjectReference{ + Package: "test", + Object: "MyEnum", + }, + Name: "C", + Value: 1, + } + + _, err := pass.Process(ast.Schemas{schema}) + require.Error(t, err) +} diff --git a/internal/yaml/compilerpasses.go b/internal/yaml/compilerpasses.go index b14aa6c1d..71fe78ac1 100644 --- a/internal/yaml/compilerpasses.go +++ b/internal/yaml/compilerpasses.go @@ -29,6 +29,7 @@ type CompilerPass struct { DuplicateObject *DuplicateObject `yaml:"duplicate_object"` TrimEnumValues *TrimEnumValues `yaml:"trim_enum_values"` ConstantToEnum *ConstantToEnum `yaml:"constant_to_enum"` + AddEnumValue *AddEnumValue `yaml:"add_enum_value"` AnonymousStructsToNamed *AnonymousStructsToNamed `yaml:"anonymous_structs_to_named"` @@ -119,6 +120,9 @@ func (pass CompilerPass) AsCompilerPass() (compiler.Pass, error) { if pass.DisjunctionWithConstantToDefault != nil { return pass.DisjunctionWithConstantToDefault.AsCompilerPass() } + if pass.AddEnumValue != nil { + return pass.AddEnumValue.AsCompilerPass() + } return nil, fmt.Errorf("empty compiler pass") } @@ -494,3 +498,38 @@ func (pass ConstantToEnum) AsCompilerPass() (*compiler.ConstantToEnum, error) { return &compiler.ConstantToEnum{Objects: objectRefs}, nil } + +type AddEnumValue struct { + Enum string // Expected format: [package].[object] + Field string // Expected format: [package].[object].[field] + Name string + Value any +} + +func (pass AddEnumValue) AsCompilerPass() (*compiler.AddEnumValue, error) { + if pass.Enum != "" { + objectRef, err := compiler.ObjectReferenceFromString(pass.Enum) + if err != nil { + return nil, err + } + + return &compiler.AddEnumValue{ + ObjectRef: objectRef, + FieldRef: compiler.FieldReference{}, + Name: pass.Name, + Value: pass.Value, + }, nil + } + + fieldRef, err := compiler.FieldReferenceFromString(pass.Field) + if err != nil { + return nil, err + } + + return &compiler.AddEnumValue{ + ObjectRef: compiler.ObjectReference{}, + FieldRef: fieldRef, + Name: pass.Name, + Value: pass.Value, + }, nil +}