Skip to content

Commit

Permalink
Fix duration multiplication by integer
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Aug 30, 2023
1 parent f32da1e commit 5be37aa
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 20 deletions.
12 changes: 12 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,18 @@ func TestExpr(t *testing.T) {
`duration("1h") + duration("1m")`,
time.Hour + time.Minute,
},
{
`7 * duration("1h")`,
7 * time.Hour,
},
{
`duration("1h") * 7`,
7 * time.Hour,
},
{
`duration("1s") * .5`,
5e8,
},
{
`1 /* one */ + 2 // two`,
3,
Expand Down
53 changes: 53 additions & 0 deletions vm/runtime/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 44 additions & 20 deletions vm/runtime/helpers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"go/format"
"os"
"strings"
"text/template"
)
Expand All @@ -13,11 +14,14 @@ func main() {
err := template.Must(
template.New("helpers").
Funcs(template.FuncMap{
"cases": func(op string) string { return cases(op, false) },
"cases_int_only": func(op string) string { return cases(op, true) },
"cases": func(op string) string { return cases(op, uints, ints, floats) },
"cases_int_only": func(op string) string { return cases(op, uints, ints) },
"cases_with_duration": func(op string) string {
return cases(op, uints, ints, floats, []string{"time.Duration"})
},
}).
Parse(helpers),
).Execute(&b, types)
).Execute(&b, nil)
if err != nil {
panic(err)
}
Expand All @@ -29,40 +33,48 @@ func main() {
fmt.Print(string(formatted))
}

var types = []string{
"uint",
"uint8",
"uint16",
"uint32",
"uint64",
var ints = []string{
"int",
"int8",
"int16",
"int32",
"int64",
}

var uints = []string{
"uint",
"uint8",
"uint16",
"uint32",
"uint64",
}

var floats = []string{
"float32",
"float64",
}

func cases(op string, noFloat bool) string {
func cases(op string, xs ...[]string) string {
var types []string
for _, x := range xs {
types = append(types, x...)
}

_, _ = fmt.Fprintf(os.Stderr, "Generating %s cases for %v\n", op, types)

var out string
echo := func(s string, xs ...any) {
out += fmt.Sprintf(s, xs...) + "\n"
}
for _, a := range types {
aIsFloat := strings.HasPrefix(a, "float")
if noFloat && aIsFloat {
continue
}
echo(`case %v:`, a)
echo(`switch y := b.(type) {`)
for _, b := range types {
bIsFloat := strings.HasPrefix(b, "float")
if noFloat && bIsFloat {
continue
}
t := "int"
if aIsFloat || bIsFloat {
if isDuration(a) || isDuration(b) {
t = "time.Duration"
}
if isFloat(a) || isFloat(b) {
t = "float64"
}
echo(`case %v:`, b)
Expand All @@ -77,6 +89,18 @@ func cases(op string, noFloat bool) string {
return strings.TrimRight(out, "\n")
}

func isFloat(t string) bool {
return strings.HasPrefix(t, "float")
}

func isInt(t string) bool {
return strings.HasPrefix(t, "int")
}

func isDuration(t string) bool {
return t == "time.Duration"
}

const helpers = `// Code generated by vm/runtime/helpers/main.go. DO NOT EDIT.
package runtime
Expand Down Expand Up @@ -245,7 +269,7 @@ func Subtract(a, b interface{}) interface{} {
func Multiply(a, b interface{}) interface{} {
switch x := a.(type) {
{{ cases "*" }}
{{ cases_with_duration "*" }}
}
panic(fmt.Sprintf("invalid operation: %T * %T", a, b))
}
Expand Down

0 comments on commit 5be37aa

Please sign in to comment.