Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions internal/ast/compiler/add_enum_value.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package compiler

import (
"errors"

"github.com/grafana/cog/internal/ast"
)

type AddEnumValue struct {
ObjectRef ObjectReference
FieldRef FieldReference
Name string
Value any
}

func (pass *AddEnumValue) Process(schemas []*ast.Schema) ([]*ast.Schema, error) {
if pass.Name == "" || pass.Value == nil {
return nil, errors.New("name and value are required")
}

visitor := &Visitor{
OnObject: pass.onObject,
OnEnum: pass.onEnum,
}
return visitor.VisitSchemas(schemas)
}

func (pass *AddEnumValue) onObject(visitor *Visitor, schema *ast.Schema, object ast.Object) (ast.Object, error) {
if object.Type.IsEnum() && pass.ObjectRef.Matches(object) {
enum, err := visitor.OnEnum(visitor, schema, object.Type)
if err != nil {
return ast.Object{}, err
}

object.Type = enum
object.AddToPassesTrail("AddEnumValue")
return object, nil
}

if !object.Type.IsStruct() {
return object, nil
}

for i, field := range object.Type.AsStruct().Fields {
if pass.FieldRef.Matches(object, field) {
if field.Type.IsEnum() {
updatedType, err := visitor.OnEnum(visitor, schema, field.Type)
if err != nil {
return ast.Object{}, err
}
object.Type.AsStruct().Fields[i].Type = updatedType
object.Type.AsStruct().Fields[i].AddToPassesTrail("AddEnumValue")
return object, nil
}

if field.Type.IsRef() {
if enum, ok := pass.updateEnumObject(visitor, schema, field.Type); ok {
schema.Objects.Set(object.Name, enum)
}
return object, nil
}
}
}

return object, nil
}

func (pass *AddEnumValue) onEnum(_ *Visitor, _ *ast.Schema, def ast.Type) (ast.Type, error) {
enumString := def.AsEnum().Values[0].Type.AsScalar().ScalarKind == ast.KindString
_, isString := pass.Value.(string)

if enumString && !isString {
return ast.Type{}, errors.New("enum value must be of type string")
}
if !enumString && isString {
return ast.Type{}, errors.New("enum value must be of type integer")
}

def.Enum.Values = append(def.Enum.Values, ast.EnumValue{
Name: pass.Name,
Type: def.AsEnum().Values[0].Type,
Value: pass.Value,
})

return def, nil
}

func (pass *AddEnumValue) updateEnumObject(visitor *Visitor, schema *ast.Schema, def ast.Type) (ast.Object, bool) {
obj, ok := schema.LocateObject(def.AsRef().ReferredType)
if !ok {
return ast.Object{}, false
}

if !obj.Type.IsEnum() {
return ast.Object{}, false
}

enum, err := visitor.OnEnum(visitor, schema, obj.Type)
if err != nil {
return ast.Object{}, false
}

obj.Type = enum
return obj, true
}
144 changes: 144 additions & 0 deletions internal/ast/compiler/add_enum_value_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package compiler

import (
"testing"

"github.com/grafana/cog/internal/ast"
"github.com/grafana/cog/internal/testutils"
"github.com/stretchr/testify/require"
)

func TestAddEnumFieldValueReference(t *testing.T) {
schema := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{
{Value: "A", Name: "A", Type: ast.String()},
{Value: "B", Name: "B", Type: ast.String()},
})),
ast.NewObject("test", "MyStruct", ast.NewStruct(
ast.NewStructField("enum", ast.NewRef("test", "MyEnum")),
)),
),
}

expected := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{
{Value: "A", Name: "A", Type: ast.String()},
{Value: "B", Name: "B", Type: ast.String()},
{Value: "C", Name: "C", Type: ast.String()},
})),
ast.NewObject("test", "MyStruct", ast.NewStruct(
ast.NewStructField("enum", ast.NewRef("test", "MyEnum")),
)),
),
}

pass := &AddEnumValue{
FieldRef: FieldReference{
Package: "test",
Object: "MyStruct",
Field: "enum",
},
Name: "C",
Value: "C",
}

runPassOnSchema(t, pass, schema, expected)
}

func TestAddEnumFieldValueDirectEnum(t *testing.T) {
schema := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "MyStruct", ast.NewStruct(
ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{
{Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)},
{Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)},
})),
)),
),
}

expected := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "MyStruct", ast.NewStruct(
ast.NewStructField("enum", ast.NewEnum([]ast.EnumValue{
{Value: 1, Name: "A", Type: ast.NewScalar(ast.KindInt64)},
{Value: 2, Name: "B", Type: ast.NewScalar(ast.KindInt64)},
{Value: 3, Name: "C", Type: ast.NewScalar(ast.KindInt64)},
}), ast.PassesTrail("AddEnumValue")),
)),
),
}

pass := &AddEnumValue{
FieldRef: FieldReference{
Package: "test",
Object: "MyStruct",
Field: "enum",
},
Name: "C",
Value: 3,
}

runPassOnSchema(t, pass, schema, expected)
}

func TestAddEnumValueEnum(t *testing.T) {
schema := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{
{Value: "A", Name: "A", Type: ast.String()},
{Value: "B", Name: "B", Type: ast.String()},
})),
),
}

expected := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{
{Value: "A", Name: "A", Type: ast.String()},
{Value: "B", Name: "B", Type: ast.String()},
{Value: "C", Name: "C", Type: ast.String()},
}), "AddEnumValue"),
),
}

pass := &AddEnumValue{
ObjectRef: ObjectReference{
Package: "test",
Object: "MyEnum",
},
Name: "C",
Value: "C",
}

runPassOnSchema(t, pass, schema, expected)
}

func TestAddEnumValueInvalidValueKind(t *testing.T) {
schema := &ast.Schema{
Package: "add_enum_value",
Objects: testutils.ObjectsMap(ast.NewObject("test", "MyEnum", ast.NewEnum([]ast.EnumValue{
{Value: "A", Name: "A", Type: ast.String()},
{Value: "B", Name: "B", Type: ast.String()},
})),
),
}

pass := &AddEnumValue{
ObjectRef: ObjectReference{
Package: "test",
Object: "MyEnum",
},
Name: "C",
Value: 1,
}

_, err := pass.Process(ast.Schemas{schema})
require.Error(t, err)
}
39 changes: 39 additions & 0 deletions internal/yaml/compilerpasses.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type CompilerPass struct {
DuplicateObject *DuplicateObject `yaml:"duplicate_object"`
TrimEnumValues *TrimEnumValues `yaml:"trim_enum_values"`
ConstantToEnum *ConstantToEnum `yaml:"constant_to_enum"`
AddEnumValue *AddEnumValue `yaml:"add_enum_value"`

AnonymousStructsToNamed *AnonymousStructsToNamed `yaml:"anonymous_structs_to_named"`

Expand Down Expand Up @@ -119,6 +120,9 @@ func (pass CompilerPass) AsCompilerPass() (compiler.Pass, error) {
if pass.DisjunctionWithConstantToDefault != nil {
return pass.DisjunctionWithConstantToDefault.AsCompilerPass()
}
if pass.AddEnumValue != nil {
return pass.AddEnumValue.AsCompilerPass()
}

return nil, fmt.Errorf("empty compiler pass")
}
Expand Down Expand Up @@ -494,3 +498,38 @@ func (pass ConstantToEnum) AsCompilerPass() (*compiler.ConstantToEnum, error) {

return &compiler.ConstantToEnum{Objects: objectRefs}, nil
}

type AddEnumValue struct {
Enum string // Expected format: [package].[object]
Field string // Expected format: [package].[object].[field]
Name string
Value any
}

func (pass AddEnumValue) AsCompilerPass() (*compiler.AddEnumValue, error) {
if pass.Enum != "" {
objectRef, err := compiler.ObjectReferenceFromString(pass.Enum)
if err != nil {
return nil, err
}

return &compiler.AddEnumValue{
ObjectRef: objectRef,
FieldRef: compiler.FieldReference{},
Name: pass.Name,
Value: pass.Value,
}, nil
}

fieldRef, err := compiler.FieldReferenceFromString(pass.Field)
if err != nil {
return nil, err
}

return &compiler.AddEnumValue{
ObjectRef: compiler.ObjectReference{},
FieldRef: fieldRef,
Name: pass.Name,
Value: pass.Value,
}, nil
}