From add0bd711a6443f2988143ad1fd3eb3e4b788f6c Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Fri, 19 Jul 2024 15:33:19 -0300 Subject: [PATCH] accept 'or' as a boolean in if/while/repeat expressions --- spec/statement/if_spec.lua | 9 +++++++++ spec/statement/repeat_spec.lua | 23 +++++++++++++++++++++++ spec/statement/while_spec.lua | 26 ++++++++++++++++++++++++++ tl.lua | 9 +++++++++ tl.tl | 9 +++++++++ 5 files changed, 76 insertions(+) create mode 100644 spec/statement/while_spec.lua diff --git a/spec/statement/if_spec.lua b/spec/statement/if_spec.lua index 99f009e12..7930942d1 100644 --- a/spec/statement/if_spec.lua +++ b/spec/statement/if_spec.lua @@ -16,6 +16,15 @@ describe("if", function() end ]])) + it("if expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + if n or s then + local ns: number | string = n or s + print(ns) + end + ]])) + it("accepts boolean expressions", util.check([[ local s = "Hallo, Welt" if string.match(s, "world") or s == "Hallo, Welt" then diff --git a/spec/statement/repeat_spec.lua b/spec/statement/repeat_spec.lua index fcfa43d6a..9bb2160d7 100644 --- a/spec/statement/repeat_spec.lua +++ b/spec/statement/repeat_spec.lua @@ -1,6 +1,29 @@ local util = require("spec.util") describe("repeat", function() + it("accepts a boolean", util.check([[ + local b = true + repeat + print(b) + until b + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + repeat + print(n) + until n + ]])) + + it("until expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + repeat + local ns: number | string = n or s + print(ns) + until n or s + ]])) + it("only closes scope after until", util.check([[ repeat local type R = record diff --git a/spec/statement/while_spec.lua b/spec/statement/while_spec.lua new file mode 100644 index 000000000..07023f603 --- /dev/null +++ b/spec/statement/while_spec.lua @@ -0,0 +1,26 @@ +local util = require("spec.util") + +describe("while", function() + it("accepts a boolean", util.check([[ + local b = true + while b do + print(b) + end + ]])) + + it("accepts a non-boolean", util.check([[ + local n = 123 + while n do + print(n) + end + ]])) + + it("while expression propagates a boolean context", util.check([[ + local n = 123 + local s = "hello" + while n or s do + local ns: number | string = n or s + print(ns) + end + ]])) +end) diff --git a/tl.lua b/tl.lua index a8f68ad2a..471bf88cc 100644 --- a/tl.lua +++ b/tl.lua @@ -10824,6 +10824,9 @@ self:expand_type(node, values, elements) }) if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self, node) if node.exp then @@ -10844,6 +10847,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self, node) self:begin_scope(node) @@ -10907,6 +10911,7 @@ self:expand_type(node, values, elements) }) before = function(self, node) self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, after = end_scope_and_none_type, @@ -11767,6 +11772,10 @@ self:expand_type(node, values, elements) }) end t = drop_constant_value(t) end + + if expected and expected.typename == "boolean" then + t = a_type(node, "boolean", {}) + end end if t then diff --git a/tl.tl b/tl.tl index 34de82b3b..0c61ce5b7 100644 --- a/tl.tl +++ b/tl.tl @@ -10824,6 +10824,9 @@ do if node.if_block_n > 1 then self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end + if node.exp then + node.exp.expected = a_type(node, "boolean", {}) + end end, before_statements = function(self: TypeChecker, node: Node) if node.exp then @@ -10844,6 +10847,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, before_statements = function(self: TypeChecker, node: Node) self:begin_scope(node) @@ -10907,6 +10911,7 @@ do before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet self:widen_all_unions(node) + node.exp.expected = a_type(node, "boolean", {}) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, @@ -11767,6 +11772,10 @@ do end t = drop_constant_value(t) end + + if expected and expected is BooleanType then + t = a_type(node, "boolean", {}) + end end if t then