diff --git a/cmd/avrogo/generate.go b/cmd/avrogo/generate.go index 46a3683..1093a69 100644 --- a/cmd/avrogo/generate.go +++ b/cmd/avrogo/generate.go @@ -29,14 +29,13 @@ const nullType = "avrotypegen.Null" // schema.RecordDefinition by looking at their match within given parsed namespace func shouldImportAvroTypeGen(namespace *parser.Namespace, definitions []schema.QualifiedName) bool { for _, def := range namespace.Definitions { - defToGenerateIdx := sort.Search(len(definitions), func(i int) bool { - return definitions[i].Name == def.AvroName().Name + searchName := def.AvroName().Name + _, found := sort.Find(len(definitions), func(i int) int { + return strings.Compare(searchName, definitions[i].Name) }) - if defToGenerateIdx < len(definitions) && def.AvroName().Name == definitions[defToGenerateIdx].Name { - if _, ok := def.(*schema.RecordDefinition); ok { - return true - } - if _, ok := def.(*schema.FixedDefinition); ok { + if found { + switch def.(type) { + case *schema.RecordDefinition, *schema.FixedDefinition: return true } }