diff --git a/ast/print.go b/ast/print.go index f5937715..b79048b2 100644 --- a/ast/print.go +++ b/ast/print.go @@ -45,13 +45,16 @@ func (n *ConstantNode) String() string { } func (n *UnaryNode) String() string { - op := "" + op := n.Operator if n.Operator == "not" { op = fmt.Sprintf("%s ", n.Operator) - } else { - op = fmt.Sprintf("%s", n.Operator) } - if _, ok := n.Node.(*BinaryNode); ok { + wrap := false + switch n.Node.(type) { + case *BinaryNode, *ConditionalNode: + wrap = true + } + if wrap { return fmt.Sprintf("%s(%s)", op, n.Node.String()) } return fmt.Sprintf("%s%s", op, n.Node.String()) diff --git a/ast/print_test.go b/ast/print_test.go index 51edd63f..a4b20b0a 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -77,6 +77,7 @@ func TestPrint(t *testing.T) { {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, {`{("a" + "b"): 42}`, `{("a" + "b"): 42}`}, {`(One == 1 ? true : false) && Two == 2`, `(One == 1 ? true : false) && Two == 2`}, + {`not (a == 1 ? b > 1 : b < 2)`, `not (a == 1 ? b > 1 : b < 2)`}, } for _, tt := range tests {