diff --git a/spec/statement/if_spec.lua b/spec/statement/if_spec.lua index 99f009e12..7930942d1 100644 --- a/spec/statement/if_spec.lua +++ b/spec/statement/if_spec.lua @@ -16,6 +16,15 @@ describe("if", function() end ]])) + it("if expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + if n or s then + local ns: number | string = n or s + print(ns) + end + ]])) + it("accepts boolean expressions", util.check([[ local s = "Hallo, Welt" if string.match(s, "world") or s == "Hallo, Welt" then diff --git a/spec/statement/repeat_spec.lua b/spec/statement/repeat_spec.lua index fcfa43d6a..9bb2160d7 100644 --- a/spec/statement/repeat_spec.lua +++ b/spec/statement/repeat_spec.lua @@ -1,6 +1,29 @@ local util = require("spec.util") describe("repeat", function() + it("accepts a boolean", util.check([[ + local b = true + repeat + print(b) + until b + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + repeat + print(n) + until n + ]])) + + it("until expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + repeat + local ns: number | string = n or s + print(ns) + until n or s + ]])) + it("only closes scope after until", util.check([[ repeat local type R = record diff --git a/spec/statement/while_spec.lua b/spec/statement/while_spec.lua new file mode 100644 index 000000000..07023f603 --- /dev/null +++ b/spec/statement/while_spec.lua @@ -0,0 +1,26 @@ +local util = require("spec.util") + +describe("while", function() + it("accepts a boolean", util.check([[ + local b = true + while b do + print(b) + end + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + while n do + print(n) + end + ]])) + + it("while expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + while n or s do + local ns: number | string = n or s + print(ns) + end + ]])) +end) diff --git a/tl.lua b/tl.lua index 558be5f49..7d4f0b643 100644 --- a/tl.lua +++ b/tl.lua @@ -10764,6 +10764,9 @@ self:expand_type(node, values, elements) }) if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self, node) if node.exp then @@ -10784,6 +10787,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self, node) self:begin_scope(node) @@ -10847,6 +10851,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, after = end_scope_and_none_type, @@ -11707,6 +11712,10 @@ self:expand_type(node, values, elements) }) end t = drop_constant_value(t) end + + if expected and expected.typename == "boolean" then + t = a_type(node, "boolean", {}) + end end if t then diff --git a/tl.tl b/tl.tl index af1c8ccc9..e7e0fa663 100644 --- a/tl.tl +++ b/tl.tl @@ -10764,6 +10764,9 @@ do if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self: TypeChecker, node: Node) if node.exp then @@ -10784,6 +10787,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self: TypeChecker, node: Node) self:begin_scope(node) @@ -10847,6 +10851,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, @@ -11707,6 +11712,10 @@ do end t = drop_constant_value(t) end + + if expected and expected is BooleanType then + t = a_type(node, "boolean", {}) + end end if t then