diff --git a/internal/ast/builder.go b/internal/ast/builder.go index fc5681afe..ebdd5d9dc 100644 --- a/internal/ast/builder.go +++ b/internal/ast/builder.go @@ -63,31 +63,51 @@ type Assignment struct { type BuilderGenerator struct { } -func (generator *BuilderGenerator) FromAST(schemas []*Schema) []Builder { +func (generator *BuilderGenerator) FromAST(schemas Schemas) []Builder { builders := make([]Builder, 0, len(schemas)) for _, schema := range schemas { for _, object := range schema.Objects { - // we only want builders for structs - if object.Type.Kind != KindStruct { + // we only want builders for structs or references to structs + if object.Type.Kind == KindRef { + ref := object.Type.AsRef() + referredObj, found := schemas.LocateObject(ref.ReferredPkg, ref.ReferredType) + if !found { + continue + } + + if referredObj.Type.Kind != KindStruct { + continue + } + } + + if object.Type.Kind != KindStruct && object.Type.Kind != KindRef { continue } - builders = append(builders, generator.structObjectToBuilder(schema, object)) + builders = append(builders, generator.structObjectToBuilder(schemas, schema, object)) } } return builders } -func (generator *BuilderGenerator) structObjectToBuilder(schema *Schema, object Object) Builder { +func (generator *BuilderGenerator) structObjectToBuilder(schemas Schemas, schema *Schema, object Object) Builder { builder := Builder{ RootPackage: schema.Package, Package: object.Name, Schema: schema, For: object, } - structType := object.Type.AsStruct() + + var structType StructType + if object.Type.Kind == KindStruct { + structType = object.Type.AsStruct() + } else { + ref := object.Type.AsRef() + referredObj, _ := schemas.LocateObject(ref.ReferredPkg, ref.ReferredType) + structType = referredObj.Type.AsStruct() + } for _, field := range structType.Fields { if generator.fieldHasStaticValue(field) { diff --git a/internal/ast/compiler/disjunctions.go b/internal/ast/compiler/disjunctions.go index 575fb3bad..ba04c00a8 100644 --- a/internal/ast/compiler/disjunctions.go +++ b/internal/ast/compiler/disjunctions.go @@ -289,7 +289,7 @@ func (pass *DisjunctionToType) inferDiscriminatorField(schema *ast.Schema, def a for _, branch := range def.Branches { // FIXME: what if the definition is itself a reference? Resolve recursively? typeName := branch.AsRef().ReferredType - structType := schema.LocateDefinition(typeName).Type.AsStruct() + structType := schema.LocateObject(typeName).Type.AsStruct() candidates[typeName] = make(map[string]any) for _, field := range structType.Fields { @@ -343,7 +343,7 @@ func (pass *DisjunctionToType) buildDiscriminatorMapping(schema *ast.Schema, def for _, branch := range def.Branches { // FIXME: what if the definition is itself a reference? Resolve recursively? typeName := branch.AsRef().ReferredType - structType := schema.LocateDefinition(typeName).Type.AsStruct() + structType := schema.LocateObject(typeName).Type.AsStruct() field, found := structType.FieldByName(def.Discriminator) if !found { diff --git a/internal/ast/schema.go b/internal/ast/schema.go index d624b7fe7..565068de2 100644 --- a/internal/ast/schema.go +++ b/internal/ast/schema.go @@ -15,6 +15,21 @@ const ( type Schemas []*Schema +func (schemas Schemas) LocateObject(pkg string, name string) (Object, bool) { + for _, schema := range schemas { + if schema.Package != pkg { + continue + } + + obj := schema.LocateObject(name) + + // TODO: schema.LocateObject() should return a "found" boolean + return obj, obj.Name != "" + } + + return Object{}, false +} + func (schemas Schemas) DeepCopy() []*Schema { newSchemas := make([]*Schema, 0, len(schemas)) @@ -46,7 +61,7 @@ func (schema *Schema) DeepCopy() Schema { return newSchema } -func (schema *Schema) LocateDefinition(name string) Object { +func (schema *Schema) LocateObject(name string) Object { for _, def := range schema.Objects { if def.Name == name { return def diff --git a/internal/veneers/option/actions.go b/internal/veneers/option/actions.go index c7fd47ce2..100be0853 100644 --- a/internal/veneers/option/actions.go +++ b/internal/veneers/option/actions.go @@ -63,7 +63,7 @@ func StructFieldsAsArgumentsAction(explicitFields ...string) RewriteAction { firstArgType := option.Args[0].Type if firstArgType.Kind == ast.KindRef { - referredObject := builder.Schema.LocateDefinition(firstArgType.AsRef().ReferredType) + referredObject := builder.Schema.LocateObject(firstArgType.AsRef().ReferredType) firstArgType = referredObject.Type }