Skip to content

Commit

Permalink
mkmatch: make allow to set custom name, otherwise use interface name
Browse files Browse the repository at this point in the history
  • Loading branch information
widmogrod committed Jul 13, 2024
1 parent 751ffdb commit 0eb1df7
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 24 deletions.
24 changes: 19 additions & 5 deletions example/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,24 @@ func (s sumVisitor) VisitLeaf(v *Leaf[int]) any {
return v.Value
}

//go:tag mkmatch:"MyTriesMatch"
//go:tag mkmatch:"MyTriesMatch"
type MyTriesMatch[T0, T1 Tree[A], A any] interface {
MatchLeafs(*Leaf[A], *Leaf[A])
MatchBranches(*Branch[A], any)
//go:tag mkmatch:"MyName"
type MyTriesMatch[T0, T1 Tree[any]] interface {
MatchLeafs(*Leaf[any], *Leaf[any])
MatchBranches(*Branch[any], any)
MatchMixed(any, any)
}

func treeDoNumbers(a, b Tree[any]) int {
return MyNameR1(
a, b,
func(x0 *Leaf[any], x1 *Leaf[any]) int {
return x0.Value.(int) + x1.Value.(int)
},
func(x0 *Branch[any], x1 any) int {
return -1
},
func(x0 any, x1 any) int {
return -10
},
)
}
13 changes: 13 additions & 0 deletions example/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,16 @@ func TestTreeSchema(t *testing.T) {
result := schema.ToGo[Tree[int]](sch)
assert.Equal(t, tree, result)
}

func TestMyNameMatch(t *testing.T) {
leaf1 := &Leaf[any]{Value: 1}
leaf2 := &Leaf[any]{Value: 2}

result := treeDoNumbers(leaf1, leaf2)
assert.Equal(t, 3, result)

branch1 := &Branch[any]{L: leaf1, R: leaf2}

result = treeDoNumbers(branch1, leaf1)
assert.Equal(t, -1, result)
}
28 changes: 17 additions & 11 deletions x/generators/mkmatch_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,36 +140,34 @@ func (b *MkMatchBuilder) extractPackageName(node ast.Node) {
}

func (b *MkMatchBuilder) SetName(name string) error {
if b.name == "" {
if b.name == "" || b.name == "-" {
b.name = name
} else {
return fmt.Errorf("match.SetName cannot declare name more than once")
}

return nil
}

func (b *MkMatchBuilder) SetInputs(types ...string) error {
if len(types) == 0 {
return fmt.Errorf("match.SetInputs is empty")
return fmt.Errorf("mkmatch: list of type parameters is required")
}

if b.inputTypes == nil {
b.inputTypes = types
} else {
return fmt.Errorf("match.SetInputs cannot declare inputs more than once")
return fmt.Errorf("mkmatch: cannot declare type parameters more than once")
}

return nil
}

func (b *MkMatchBuilder) AddCase(name string, inputs ...string) error {
if len(inputs) == 0 {
return fmt.Errorf("match.AddCase is empty; case name: %s", name)
return fmt.Errorf("mkmatch: matching case %s must have at least %d arguments", name, len(b.inputTypes))
}

if len(inputs) > len(b.inputTypes) {
return fmt.Errorf("match.AddCase function must have at least same number of arguments as number of type params; case name: %s", name)
if len(inputs) != len(b.inputTypes) {
return fmt.Errorf("mkmatch: matching case %s must have same number of function arguments as number of type params", name)
}

// check if there are no duplicates in other cases
Expand All @@ -181,15 +179,15 @@ func (b *MkMatchBuilder) AddCase(name string, inputs ...string) error {
}
}
if same == 0 {
return fmt.Errorf("match.AddCase cannot have duplicate; cases name: %s", b.names[cid])
return fmt.Errorf("mkmatch: matching case %s cannot have duplicate argument names", b.names[cid])
}
}
b.cases = append(b.cases, inputs)

// check if there are no duplicates in names
for _, caseName := range b.names {
if caseName == name {
return fmt.Errorf("match.AddCase cannot have duplicate; case name: %s", caseName)
return fmt.Errorf("mkmatch: cannot have duplicate; case name: %s", caseName)
}
}
b.names = append(b.names, name)
Expand All @@ -206,13 +204,21 @@ type MatchSpec struct {
}

func (b *MkMatchBuilder) Build() (*MatchSpec, error) {
if b.name == "" {
return nil, fmt.Errorf("mkmatch: type match must have name")
}

if len(b.cases) == 0 {
return nil, fmt.Errorf("mkmatch: type match must have at least one case")
}

pkgMap := make(PkgMap)

for pkgName := range b.usePackageNames {
if pkg, ok := b.knownPkgMap[pkgName]; ok {
pkgMap[pkgName] = pkg
} else {
return nil, fmt.Errorf("match.Build cannot find package %s", pkgName)
return nil, fmt.Errorf("mkmatch: cannot find package import path for name %s", pkgName)
}
}

Expand Down
11 changes: 4 additions & 7 deletions x/generators/mkmatch_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,13 @@ func (f *MkMatchTaggedNodeVisitor) Specs() []*MatchSpec {
return specs
}

func (f *MkMatchTaggedNodeVisitor) withNameWalk(value string, node ast.Node) {
func (f *MkMatchTaggedNodeVisitor) visitTaggedNode(node *shape.NodeAndTag) {
b := NewMkMatchBuilder()
b.InitPkgMap(f.pkgMap)
b.name = node.Tag.Value

ast.Walk(b, node)
f.matchBuilder[value] = b
}

func (f *MkMatchTaggedNodeVisitor) visitTaggedNode(node *shape.NodeAndTag) {
f.withNameWalk(node.Tag.Value, node.Node)
ast.Walk(b, node.Node)
f.matchBuilder[node.Tag.Value] = b
}

func typeToString(t ast.Expr) string {
Expand Down
2 changes: 1 addition & 1 deletion x/schema/location_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (location *TypedLocation) WrapLocation(loc []Location) ([]Location, error)
return loc, nil
}

//go:tag mkmatch:"MatchDifference"
//go:tag mkmatch:"-"
type MatchDifference[A, B shape.Shape] interface {
StructLikes(x *shape.StructLike, y *shape.StructLike)
UnionLikes(x *shape.UnionLike, y *shape.UnionLike)
Expand Down

0 comments on commit 0eb1df7

Please sign in to comment.