diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 4b636314..b008d661 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -68,6 +68,7 @@ func (c *Compiler) Compile(query string) (program *runtime.Program, err error) { program.Arguments = l.arguments program.Constants = l.constants program.Locations = l.locations + program.CatchTable = l.catchTable return program, err } diff --git a/pkg/compiler/compiler_func_test.go b/pkg/compiler/compiler_func_test.go index bf185bdd..316b6cff 100644 --- a/pkg/compiler/compiler_func_test.go +++ b/pkg/compiler/compiler_func_test.go @@ -1,26 +1,27 @@ package compiler_test import ( + . "github.com/smartystreets/goconvey/convey" "testing" ) func TestFunctionCall(t *testing.T) { RunUseCases(t, []UseCase{ - //{ - // "RETURN TYPENAME(1)", - // "int", - // nil, - //}, - //{ - // "WAIT(10) RETURN 1", - // 1, - // nil, - //}, - //{ - // "LET duration = 10 WAIT(duration) RETURN 1", - // 1, - // nil, - //}, + { + "RETURN TYPENAME(1)", + "int", + nil, + }, + { + "WAIT(10) RETURN 1", + 1, + nil, + }, + { + "LET duration = 10 WAIT(duration) RETURN 1", + 1, + nil, + }, { "RETURN (FALSE OR T::FAIL())?", nil, @@ -31,88 +32,23 @@ func TestFunctionCall(t *testing.T) { nil, nil, }, + { + `FOR i IN [1, 2, 3, 4] + LET duration = 10 + + WAIT(duration) + + RETURN i * 2`, + []int{2, 4, 6, 8}, + ShouldEqualJSON, + }, //{ - // `FOR i IN [1, 2, 3, 4] - // LET duration = 10 - // - // WAIT(duration) - // - // RETURN i * 2`, - // []int{2, 4, 6, 8}, - // ShouldEqualJSON, + // `RETURN FIRST((FOR i IN 1..10 RETURN i * 2))`, + // 2, + // nil, //}, }) - // - //Convey("Should handle errors when ? is used", t, func() { - // c := compiler.New() - // c.RegisterFunction("ERROR", func(ctx context.Context, args ...core.Value) (core.Value, error) { - // return values.None, errors.New("test error") - // }) - // - // p, err := c.Compile(` - // RETURN ERROR()? - // `) - // - // So(err, ShouldBeNil) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // - // So(string(out), ShouldEqual, `null`) - //}) - // - //Convey("Should handle errors when ? is used within a group", t, func() { - // c := compiler.New() - // - // p, err := c.Compile(` - // RETURN (FALSE OR T::FAIL())? - // `) - // - // So(err, ShouldBeNil) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // - // So(string(out), ShouldEqual, `null`) - //}) - // - //Convey("Should return NONE when error is handled", t, func() { - // c := compiler.New() - // c.RegisterFunction("ERROR", func(ctx context.Context, args ...core.Value) (core.Value, error) { - // return values.NewString("booo"), errors.New("test error") - // }) - // - // p, err := c.Compile(` - // RETURN ERROR()? - // `) - // - // So(err, ShouldBeNil) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // - // So(string(out), ShouldEqual, `null`) - //}) - // - //Convey("Should be able to use FOR as an argument", t, func() { - // c := compiler.New() - // - // p, err := c.Compile(` - // RETURN FIRST((FOR i IN 1..10 RETURN i * 2)) - // `) - // - // So(err, ShouldBeNil) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // - // So(string(out), ShouldEqual, `2`) - //}) // //Convey("Should be able to use FOR as arguments", t, func() { // c := compiler.New() diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index 2c51d11f..305030d3 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -37,6 +37,7 @@ type ( loops []*loopScope globals map[string]int locals []variable + catchTable [][2]int } ) @@ -63,6 +64,7 @@ func newVisitor(src string) *visitor { v.loops = make([]*loopScope, 0) v.globals = make(map[string]int) v.locals = make([]variable, 0) + v.catchTable = make([][2]int, 0) return v } @@ -721,7 +723,13 @@ func (v *visitor) VisitPredicate(ctx *fql.PredicateContext) interface{} { v.emit(runtime.OpLike) } } else if c := ctx.ExpressionAtom(); c != nil { + startCatch := len(v.bytecode) c.Accept(v) + + if c.ErrorOperator() != nil { + endCatch := len(v.bytecode) + v.catchTable = append(v.catchTable, [2]int{startCatch, endCatch}) + } } return nil diff --git a/pkg/runtime/program.go b/pkg/runtime/program.go index 691ce64c..e0702bd7 100644 --- a/pkg/runtime/program.go +++ b/pkg/runtime/program.go @@ -9,11 +9,12 @@ import ( ) type Program struct { - Source *core.Source - Locations []core.Location - Bytecode []Opcode - Arguments []int - Constants []core.Value + Source *core.Source + Locations []core.Location + Bytecode []Opcode + Arguments []int + Constants []core.Value + CatchTable [][2]int } func (program *Program) Disassemble() string { diff --git a/pkg/runtime/vm.go b/pkg/runtime/vm.go index 166edd7d..9004c520 100644 --- a/pkg/runtime/vm.go +++ b/pkg/runtime/vm.go @@ -2,7 +2,6 @@ package runtime import ( "context" - "errors" "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/operators" "github.com/MontFerret/ferret/pkg/runtime/values" @@ -24,21 +23,16 @@ func NewVM(opts ...EnvironmentOption) *VM { return vm } -func (vm *VM) Run(ctx context.Context, program *Program) (res []byte, err error) { - defer func() { - if r := recover(); r != nil { - switch x := r.(type) { - case string: - err = errors.New(x) - case error: - err = x - default: - err = errors.New("unknown panic") +func (vm *VM) Run(ctx context.Context, program *Program) ([]byte, error) { + tryCatch := func(pos int) bool { + for _, pair := range program.CatchTable { + if pos >= pair[0] && pos <= pair[1] { + return true } - - program = nil } - }() + + return false + } // TODO: Add code analysis to calculate the number of operands and variables stack := NewStack(len(program.Bytecode), 8) @@ -289,7 +283,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCallOptional { + } else if op == OpCallOptional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err @@ -302,7 +296,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCall1Optional { + } else if op == OpCall1Optional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err @@ -316,7 +310,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCall2Optional { + } else if op == OpCall2Optional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err @@ -331,7 +325,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCall3Optional { + } else if op == OpCall3Optional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err @@ -347,7 +341,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCall4Optional { + } else if op == OpCall4Optional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err @@ -371,7 +365,7 @@ loop: if err == nil { stack.Push(res) - } else if op == OpCallNOptional { + } else if op == OpCallNOptional || tryCatch(vm.ip) { stack.Push(values.None) } else { return nil, err