From 7b5f72bce989cb6329123ccb9ecfa333915588f3 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Fri, 3 Feb 2023 15:53:43 +0100 Subject: [PATCH] Add nil coalescing operator --- checker/checker.go | 18 +++++++++++++--- checker/checker_test.go | 2 ++ compiler/compiler.go | 7 +++++++ compiler/compiler_test.go | 19 +++++++++++++++++ docs/Language-Definition.md | 11 +++++++++- expr_test.go | 36 ++++++++++++++++++++++++++++++++ parser/lexer/lexer_test.go | 9 ++++++++ parser/lexer/state.go | 2 +- parser/parser.go | 30 +++++++++++++++++++-------- parser/parser_test.go | 41 +++++++++++++++++++++++++++++++++++++ vm/opcodes.go | 1 + vm/program.go | 3 +++ vm/vm.go | 5 +++++ 13 files changed, 170 insertions(+), 14 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 9b778d76c..76e3d0be8 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -38,9 +38,6 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { } default: if t != nil { - if t.Kind() == reflect.Interface { - t = t.Elem() - } if t.Kind() == v.config.Expect { return t, nil } @@ -358,6 +355,21 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { return ret, info{} } + case "??": + if l == nil && r != nil { + return r, info{} + } + if l != nil && r == nil { + return l, info{} + } + if l == nil && r == nil { + return nilType, info{} + } + if r.AssignableTo(l) { + return l, info{} + } + return anyType, info{} + default: return v.error(node, "unknown operator (%v)", node.Operator) diff --git a/checker/checker_test.go b/checker/checker_test.go index d00c74320..ba612a4c4 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -121,6 +121,8 @@ var successTests = []string{ "Duration + Any == Time", "Any + Duration == Time", "Any.A?.B == nil", + "(Any.Bool ?? Bool) > 0", + "Bool ?? Bool", } func TestCheck(t *testing.T) { diff --git a/compiler/compiler.go b/compiler/compiler.go index aeefb5289..3cd32af0f 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -424,6 +424,13 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { c.compile(node.Right) c.emit(OpRange) + case "??": + c.compile(node.Left) + end := c.emit(OpJumpIfNotNil, placeholder) + c.emit(OpPop) + c.compile(node.Right) + c.patchJump(end) + default: panic(fmt.Sprintf("unknown operator (%v)", node.Operator)) diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 858a2b340..81da17316 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -230,6 +230,25 @@ func TestCompile(t *testing.T) { Arguments: []int{0, 1, 0, 2}, }, }, + { + `A ?? 1`, + vm.Program{ + Constants: []interface{}{ + &runtime.Field{ + Index: []int{0}, + Path: []string{"A"}, + }, + 1, + }, + Bytecode: []vm.Opcode{ + vm.OpLoadField, + vm.OpJumpIfNotNil, + vm.OpPop, + vm.OpPush, + }, + Arguments: []int{0, 2, 0, 1}, + }, + }, } for _, test := range tests { diff --git a/docs/Language-Definition.md b/docs/Language-Definition.md index 17dffdf4b..450c351e4 100644 --- a/docs/Language-Definition.md +++ b/docs/Language-Definition.md @@ -79,7 +79,7 @@ d> Conditional - ?: (ternary) + ?: (ternary), ?? (nil coalescing) @@ -147,6 +147,15 @@ without checking if the struct or the map is `nil`. If the struct or the map is author?.User?.Name ``` +#### Nil coalescing + +The `??` operator can be used to return the left-hand side if it is not `nil`, +otherwise the right-hand side is returned. + +```c++ +author?.User?.Name ?? "Anonymous" +``` + ### Slice Operator The slice operator `[:]` can be used to access a slice of an array. diff --git a/expr_test.go b/expr_test.go index f474f76fc..b80010fd5 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1710,6 +1710,42 @@ func TestFunction(t *testing.T) { assert.Equal(t, 20, out) } +// Nil coalescing operator +func TestRun_NilCoalescingOperator(t *testing.T) { + env := map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": "value", + }, + } + + t.Run("value", func(t *testing.T) { + p, err := expr.Compile(`foo.bar ?? "default"`, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(p, env) + assert.NoError(t, err) + assert.Equal(t, "value", out) + }) + + t.Run("default", func(t *testing.T) { + p, err := expr.Compile(`foo.baz ?? "default"`, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(p, env) + assert.NoError(t, err) + assert.Equal(t, "default", out) + }) + + t.Run("default with chain", func(t *testing.T) { + p, err := expr.Compile(`foo?.bar ?? "default"`, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(p, map[string]interface{}{}) + assert.NoError(t, err) + assert.Equal(t, "default", out) + }) +} + // Mock types type mockEnv struct { diff --git a/parser/lexer/lexer_test.go b/parser/lexer/lexer_test.go index d5f2cc2c7..03ccbd14f 100644 --- a/parser/lexer/lexer_test.go +++ b/parser/lexer/lexer_test.go @@ -180,6 +180,15 @@ var lexTests = []lexTest{ {Kind: EOF}, }, }, + { + `foo ?? bar`, + []Token{ + {Kind: Identifier, Value: "foo"}, + {Kind: Operator, Value: "??"}, + {Kind: Identifier, Value: "bar"}, + {Kind: EOF}, + }, + }, } func compareTokens(i1, i2 []Token) bool { diff --git a/parser/lexer/state.go b/parser/lexer/state.go index 41885fe95..1212aa321 100644 --- a/parser/lexer/state.go +++ b/parser/lexer/state.go @@ -156,7 +156,7 @@ func not(l *lexer) stateFn { } func questionMark(l *lexer) stateFn { - l.accept(".") + l.accept(".?") l.emit(Operator) return root } diff --git a/parser/parser.go b/parser/parser.go index f9584858b..fd26fe18b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -58,6 +58,7 @@ var binaryOperators = map[string]operator{ "%": {60, left}, "**": {100, right}, "^": {100, right}, + "??": {500, left}, } var builtins = map[string]builtin{ @@ -113,9 +114,13 @@ func Parse(input string) (*Tree, error) { } func (p *parser) error(format string, args ...interface{}) { + p.errorAt(p.current, format, args...) +} + +func (p *parser) errorAt(token Token, format string, args ...interface{}) { if p.err == nil { // show first error p.err = &file.Error{ - Location: p.current.Location, + Location: token.Location, Message: fmt.Sprintf(format, args...), } } @@ -143,22 +148,28 @@ func (p *parser) expect(kind Kind, values ...string) { func (p *parser) parseExpression(precedence int) Node { nodeLeft := p.parsePrimary() - token := p.current - for token.Is(Operator) && p.err == nil { + lastOperator := "" + opToken := p.current + for opToken.Is(Operator) && p.err == nil { negate := false var notToken Token - if token.Is(Operator, "not") { + if opToken.Is(Operator, "not") { p.next() notToken = p.current negate = true - token = p.current + opToken = p.current } - if op, ok := binaryOperators[token.Value]; ok { + if op, ok := binaryOperators[opToken.Value]; ok { if op.precedence >= precedence { p.next() + if lastOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { + p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) + break + } + var nodeRight Node if op.associativity == left { nodeRight = p.parseExpression(op.precedence + 1) @@ -167,11 +178,11 @@ func (p *parser) parseExpression(precedence int) Node { } nodeLeft = &BinaryNode{ - Operator: token.Value, + Operator: opToken.Value, Left: nodeLeft, Right: nodeRight, } - nodeLeft.SetLocation(token.Location) + nodeLeft.SetLocation(opToken.Location) if negate { nodeLeft = &UnaryNode{ @@ -181,7 +192,8 @@ func (p *parser) parseExpression(precedence int) Node { nodeLeft.SetLocation(notToken.Location) } - token = p.current + lastOperator = opToken.Value + opToken = p.current continue } } diff --git a/parser/parser_test.go b/parser/parser_test.go index 71ce323c4..a93ecdc2f 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -408,6 +408,42 @@ func TestParse(t *testing.T) { "[]", &ArrayNode{}, }, + { + "foo ?? bar", + &BinaryNode{Operator: "??", + Left: &IdentifierNode{Value: "foo"}, + Right: &IdentifierNode{Value: "bar"}}, + }, + { + "foo ?? bar ?? baz", + &BinaryNode{Operator: "??", + Left: &BinaryNode{Operator: "??", + Left: &IdentifierNode{Value: "foo"}, + Right: &IdentifierNode{Value: "bar"}}, + Right: &IdentifierNode{Value: "baz"}}, + }, + { + "foo ?? (bar || baz)", + &BinaryNode{Operator: "??", + Left: &IdentifierNode{Value: "foo"}, + Right: &BinaryNode{Operator: "||", + Left: &IdentifierNode{Value: "bar"}, + Right: &IdentifierNode{Value: "baz"}}}, + }, + { + "foo || bar ?? baz", + &BinaryNode{Operator: "||", + Left: &IdentifierNode{Value: "foo"}, + Right: &BinaryNode{Operator: "??", + Left: &IdentifierNode{Value: "bar"}, + Right: &IdentifierNode{Value: "baz"}}}, + }, + { + "foo ?? bar()", + &BinaryNode{Operator: "??", + Left: &IdentifierNode{Value: "foo"}, + Right: &CallNode{Callee: &IdentifierNode{Value: "bar"}}}, + }, } for _, test := range parseTests { actual, err := parser.Parse(test.input) @@ -479,6 +515,11 @@ a map key must be a quoted string, a number, a identifier, or an expression encl unexpected token Operator(",") (1:16) | {foo:1, bar:2, ,} | ...............^ + +foo ?? bar || baz +Operator (||) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses. (1:12) + | foo ?? bar || baz + | ...........^ ` func TestParse_error(t *testing.T) { diff --git a/vm/opcodes.go b/vm/opcodes.go index 8e47f1d3c..b3117e73c 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -26,6 +26,7 @@ const ( OpJumpIfTrue OpJumpIfFalse OpJumpIfNil + OpJumpIfNotNil OpJumpIfEnd OpJumpBackward OpIn diff --git a/vm/program.go b/vm/program.go index f7da12374..7a417903c 100644 --- a/vm/program.go +++ b/vm/program.go @@ -138,6 +138,9 @@ func (program *Program) Disassemble() string { case OpJumpIfNil: jump("OpJumpIfNil") + case OpJumpIfNotNil: + jump("OpJumpIfNotNil") + case OpJumpIfEnd: jump("OpJumpIfEnd") diff --git a/vm/vm.go b/vm/vm.go index a8bda2f30..46f7628f2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -174,6 +174,11 @@ func (vm *VM) Run(program *Program, env interface{}) (_ interface{}, err error) vm.ip += arg } + case OpJumpIfNotNil: + if !runtime.IsNil(vm.current()) { + vm.ip += arg + } + case OpJumpIfEnd: scope := vm.Scope() if scope.It >= scope.Len {