diff --git a/example/tree.go b/example/tree.go index 07d11f49..a99f2778 100644 --- a/example/tree.go +++ b/example/tree.go @@ -30,10 +30,24 @@ func (s sumVisitor) VisitLeaf(v *Leaf[int]) any { return v.Value } -//go:tag mkmatch:"MyTriesMatch" -//go:tag mkmatch:"MyTriesMatch" -type MyTriesMatch[T0, T1 Tree[A], A any] interface { - MatchLeafs(*Leaf[A], *Leaf[A]) - MatchBranches(*Branch[A], any) +//go:tag mkmatch:"MyName" +type MyTriesMatch[T0, T1 Tree[any]] interface { + MatchLeafs(*Leaf[any], *Leaf[any]) + MatchBranches(*Branch[any], any) MatchMixed(any, any) } + +func treeDoNumbers(a, b Tree[any]) int { + return MyNameR1( + a, b, + func(x0 *Leaf[any], x1 *Leaf[any]) int { + return x0.Value.(int) + x1.Value.(int) + }, + func(x0 *Branch[any], x1 any) int { + return -1 + }, + func(x0 any, x1 any) int { + return -10 + }, + ) +} diff --git a/example/tree_test.go b/example/tree_test.go index 04d3a948..1442032c 100644 --- a/example/tree_test.go +++ b/example/tree_test.go @@ -75,3 +75,16 @@ func TestTreeSchema(t *testing.T) { result := schema.ToGo[Tree[int]](sch) assert.Equal(t, tree, result) } + +func TestMyNameMatch(t *testing.T) { + leaf1 := &Leaf[any]{Value: 1} + leaf2 := &Leaf[any]{Value: 2} + + result := treeDoNumbers(leaf1, leaf2) + assert.Equal(t, 3, result) + + branch1 := &Branch[any]{L: leaf1, R: leaf2} + + result = treeDoNumbers(branch1, leaf1) + assert.Equal(t, -1, result) +} diff --git a/x/generators/mkmatch_builder.go b/x/generators/mkmatch_builder.go index f86fb446..a9083bfc 100644 --- a/x/generators/mkmatch_builder.go +++ b/x/generators/mkmatch_builder.go @@ -140,10 +140,8 @@ func (b *MkMatchBuilder) extractPackageName(node ast.Node) { } func (b *MkMatchBuilder) SetName(name string) error { - if b.name == "" { + if b.name == "" || b.name == "-" { b.name = name - } else { - return fmt.Errorf("match.SetName cannot declare name more than once") } return nil @@ -151,13 +149,13 @@ func (b *MkMatchBuilder) SetName(name string) error { func (b *MkMatchBuilder) SetInputs(types ...string) error { if len(types) == 0 { - return fmt.Errorf("match.SetInputs is empty") + return fmt.Errorf("mkmatch: list of type parameters is required") } if b.inputTypes == nil { b.inputTypes = types } else { - return fmt.Errorf("match.SetInputs cannot declare inputs more than once") + return fmt.Errorf("mkmatch: cannot declare type parameters more than once") } return nil @@ -165,11 +163,11 @@ func (b *MkMatchBuilder) SetInputs(types ...string) error { func (b *MkMatchBuilder) AddCase(name string, inputs ...string) error { if len(inputs) == 0 { - return fmt.Errorf("match.AddCase is empty; case name: %s", name) + return fmt.Errorf("mkmatch: matching case %s must have at least %d arguments", name, len(b.inputTypes)) } - if len(inputs) > len(b.inputTypes) { - return fmt.Errorf("match.AddCase function must have at least same number of arguments as number of type params; case name: %s", name) + if len(inputs) != len(b.inputTypes) { + return fmt.Errorf("mkmatch: matching case %s must have same number of function arguments as number of type params", name) } // check if there are no duplicates in other cases @@ -181,7 +179,7 @@ func (b *MkMatchBuilder) AddCase(name string, inputs ...string) error { } } if same == 0 { - return fmt.Errorf("match.AddCase cannot have duplicate; cases name: %s", b.names[cid]) + return fmt.Errorf("mkmatch: matching case %s cannot have duplicate argument names", b.names[cid]) } } b.cases = append(b.cases, inputs) @@ -189,7 +187,7 @@ func (b *MkMatchBuilder) AddCase(name string, inputs ...string) error { // check if there are no duplicates in names for _, caseName := range b.names { if caseName == name { - return fmt.Errorf("match.AddCase cannot have duplicate; case name: %s", caseName) + return fmt.Errorf("mkmatch: cannot have duplicate; case name: %s", caseName) } } b.names = append(b.names, name) @@ -206,13 +204,21 @@ type MatchSpec struct { } func (b *MkMatchBuilder) Build() (*MatchSpec, error) { + if b.name == "" { + return nil, fmt.Errorf("mkmatch: type match must have name") + } + + if len(b.cases) == 0 { + return nil, fmt.Errorf("mkmatch: type match must have at least one case") + } + pkgMap := make(PkgMap) for pkgName := range b.usePackageNames { if pkg, ok := b.knownPkgMap[pkgName]; ok { pkgMap[pkgName] = pkg } else { - return nil, fmt.Errorf("match.Build cannot find package %s", pkgName) + return nil, fmt.Errorf("mkmatch: cannot find package import path for name %s", pkgName) } } diff --git a/x/generators/mkmatch_visitor.go b/x/generators/mkmatch_visitor.go index e3ec76dc..f3e82a04 100644 --- a/x/generators/mkmatch_visitor.go +++ b/x/generators/mkmatch_visitor.go @@ -36,16 +36,13 @@ func (f *MkMatchTaggedNodeVisitor) Specs() []*MatchSpec { return specs } -func (f *MkMatchTaggedNodeVisitor) withNameWalk(value string, node ast.Node) { +func (f *MkMatchTaggedNodeVisitor) visitTaggedNode(node *shape.NodeAndTag) { b := NewMkMatchBuilder() b.InitPkgMap(f.pkgMap) + b.name = node.Tag.Value - ast.Walk(b, node) - f.matchBuilder[value] = b -} - -func (f *MkMatchTaggedNodeVisitor) visitTaggedNode(node *shape.NodeAndTag) { - f.withNameWalk(node.Tag.Value, node.Node) + ast.Walk(b, node.Node) + f.matchBuilder[node.Tag.Value] = b } func typeToString(t ast.Expr) string { diff --git a/x/schema/location_typed.go b/x/schema/location_typed.go index 8a6acb5e..bf523b36 100644 --- a/x/schema/location_typed.go +++ b/x/schema/location_typed.go @@ -75,7 +75,7 @@ func (location *TypedLocation) WrapLocation(loc []Location) ([]Location, error) return loc, nil } -//go:tag mkmatch:"MatchDifference" +//go:tag mkmatch:"-" type MatchDifference[A, B shape.Shape] interface { StructLikes(x *shape.StructLike, y *shape.StructLike) UnionLikes(x *shape.UnionLike, y *shape.UnionLike)