diff --git a/ast.go b/ast.go index 43d8d43..b6008db 100644 --- a/ast.go +++ b/ast.go @@ -122,6 +122,18 @@ func (l *StringSliceLiteral) Type() string { return "string slice" } +type DurationLiteral struct { + Val time.Duration +} + +func (l *DurationLiteral) String() string { + return l.Val.String() +} + +func (l *DurationLiteral) Type() string { + return "duration" +} + type WalkFunc func(expr Expr, err error) error func Walk(expr Expr, fn WalkFunc) error { diff --git a/evaluator.go b/evaluator.go index 2d61bdd..2486310 100644 --- a/evaluator.go +++ b/evaluator.go @@ -138,6 +138,16 @@ func computeEQ(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: (l.Val == dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val == v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) @@ -175,6 +185,16 @@ func computeNEQ(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: (l.Val != dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val != v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) @@ -207,6 +227,16 @@ func computeLT(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: l.Val.Before(dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val < v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) @@ -239,6 +269,16 @@ func computeLTE(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: l.Val.Before(dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val <= v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) @@ -271,6 +311,16 @@ func computeGT(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: l.Val.After(dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val > v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) @@ -303,6 +353,16 @@ func computeGTE(lhs, rhs Expr) (*BoolLiteral, error) { return &BoolLiteral{Val: l.Val.After(dt)}, nil } + case *DurationLiteral: + rv, ok := rhs.(*StringLiteral) + if ok { + v, err := time.ParseDuration(rv.Val) + if err != nil { + return nil, err + } + + return &BoolLiteral{Val: l.Val >= v}, nil + } } return nil, fmt.Errorf(`cannot convert "%s" to %s`, rhs.String(), lhs.Type()) diff --git a/evaluator_test.go b/evaluator_test.go index 866397f..6f01757 100644 --- a/evaluator_test.go +++ b/evaluator_test.go @@ -25,6 +25,15 @@ func TestEvaluator(t *testing.T) { dt, err := time.Parse(time.RFC3339, "2019-03-28T11:39:43+07:00") require.NoError(t, err) + dur2m, err := time.ParseDuration("2m") + require.NoError(t, err) + + dur1m30s, err := time.ParseDuration("1m30s") + require.NoError(t, err) + + dur45s, err := time.ParseDuration("45s") + require.NoError(t, err) + tests := []TestCase{ { `{ "comparator": "||", "rules": [ { "comparator": "&&", "rules": [ { "var": "a", "op": "==", "val": 1 }, { "var": "b", "op": "==", "val": 2 } ] }, { "comparator": "&&", "rules": [ { "var": "c", "op": "==", "val": 3 }, { "var": "d", "op": "==", "val": 4 } ] } ] }`, @@ -342,6 +351,127 @@ func TestEvaluator(t *testing.T) { }, }, }, + { + `{ "var": "a", "op": "==", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur2m, + }, + false, + false, + }, + { + map[string]interface{}{ + "a": dur1m30s, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": 1, + }, + false, + true, + }, + }, + }, + { + `{ "var": "a", "op": "!=", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur2m, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": dur1m30s, + }, + false, + false, + }, + }, + }, + { + `{ "var": "a", "op": ">", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur2m, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": dur1m30s, + }, + false, + false, + }, + }, + }, + { + `{ "var": "a", "op": ">=", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur2m, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": dur1m30s, + }, + true, + false, + }, + }, + }, + { + `{ "var": "a", "op": "<", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur45s, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": dur2m, + }, + false, + false, + }, + }, + }, + { + `{ "var": "a", "op": "<=", "val": "1m30s" }`, + []Evaluation{ + { + map[string]interface{}{ + "a": dur1m30s, + }, + true, + false, + }, + { + map[string]interface{}{ + "a": dur2m, + }, + false, + false, + }, + }, + }, } for _, test := range tests { diff --git a/parser.go b/parser.go index 27e3013..6af81bc 100644 --- a/parser.go +++ b/parser.go @@ -116,6 +116,11 @@ func toLiteral(i interface{}) (Expr, error) { case reflect.Int32: return &NumberLiteral{Val: float64(i.(int32))}, nil case reflect.Int64: + dur, isDuration := i.(time.Duration) + if isDuration { + return &DurationLiteral{Val: dur}, nil + } + return &NumberLiteral{Val: float64(i.(int64))}, nil case reflect.Float32: return &NumberLiteral{Val: float64(i.(float32))}, nil