diff --git a/ast/print.go b/ast/print.go index 063e9eb2..394e0c1a 100644 --- a/ast/print.go +++ b/ast/print.go @@ -65,8 +65,7 @@ func (n *BinaryNode) String() string { var lhs, rhs string var lwrap, rwrap bool - lb, ok := n.Left.(*BinaryNode) - if ok { + if lb, ok := n.Left.(*BinaryNode); ok { if operator.Less(lb.Operator, n.Operator) { lwrap = true } @@ -77,9 +76,7 @@ func (n *BinaryNode) String() string { lwrap = true } } - - rb, ok := n.Right.(*BinaryNode) - if ok { + if rb, ok := n.Right.(*BinaryNode); ok { if operator.Less(rb.Operator, n.Operator) { rwrap = true } @@ -88,6 +85,13 @@ func (n *BinaryNode) String() string { } } + if _, ok := n.Left.(*ConditionalNode); ok { + lwrap = true + } + if _, ok := n.Right.(*ConditionalNode); ok { + rwrap = true + } + if lwrap { lhs = fmt.Sprintf("(%s)", n.Left.String()) } else { @@ -108,20 +112,25 @@ func (n *ChainNode) String() string { } func (n *MemberNode) String() string { + node := n.Node.String() + if _, ok := n.Node.(*BinaryNode); ok { + node = fmt.Sprintf("(%s)", node) + } + if n.Optional { if str, ok := n.Property.(*StringNode); ok && utils.IsValidIdentifier(str.Value) { - return fmt.Sprintf("%s?.%s", n.Node.String(), str.Value) + return fmt.Sprintf("%s?.%s", node, str.Value) } else { - return fmt.Sprintf("%s?.[%s]", n.Node.String(), n.Property.String()) + return fmt.Sprintf("%s?.[%s]", node, n.Property.String()) } } if str, ok := n.Property.(*StringNode); ok && utils.IsValidIdentifier(str.Value) { if _, ok := n.Node.(*PointerNode); ok { return fmt.Sprintf(".%s", str.Value) } - return fmt.Sprintf("%s.%s", n.Node.String(), str.Value) + return fmt.Sprintf("%s.%s", node, str.Value) } - return fmt.Sprintf("%s[%s]", n.Node.String(), n.Property.String()) + return fmt.Sprintf("%s[%s]", node, n.Property.String()) } func (n *SliceNode) String() string { diff --git a/ast/print_test.go b/ast/print_test.go index d9e55c2e..373030a4 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -72,6 +72,7 @@ func TestPrint(t *testing.T) { {`a[:]`, `a[:]`}, {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, {`{("a" + "b"): 42}`, `{("a" + "b"): 42}`}, + {`(One == 1 ? true : false) && Two == 2`, `(One == 1 ? true : false) && Two == 2`}, } for _, tt := range tests { diff --git a/expr_test.go b/expr_test.go index df4eba50..4b4e2edf 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1298,6 +1298,15 @@ func TestExpr(t *testing.T) { require.NoError(t, err, "eval") assert.Equal(t, tt.want, got, "eval") } + { + program, err := expr.Compile(tt.code, expr.Env(mock.Env{}), expr.Optimize(false)) + require.NoError(t, err) + + code := program.Node().String() + got, err := expr.Eval(code, env) + require.NoError(t, err, code) + assert.Equal(t, tt.want, got, code) + } }) } }