Skip to content

Commit

Permalink
accept 'or' as a boolean in if/while/repeat expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm committed Jul 23, 2024
1 parent 32d2c3b commit add0bd7
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 0 deletions.
9 changes: 9 additions & 0 deletions spec/statement/if_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ describe("if", function()
end
]]))

it("if expression propagates a boolean context", util.check([[
local n = 123
local s = "hello"
if n or s then
local ns: number | string = n or s
print(ns)
end
]]))

it("accepts boolean expressions", util.check([[
local s = "Hallo, Welt"
if string.match(s, "world") or s == "Hallo, Welt" then
Expand Down
23 changes: 23 additions & 0 deletions spec/statement/repeat_spec.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
local util = require("spec.util")

describe("repeat", function()
it("accepts a boolean", util.check([[
local b = true
repeat
print(b)
until b
]]))

it("accepts a non-boolean", util.check([[
local n = 123
repeat
print(n)
until n
]]))

it("until expression propagates a boolean context", util.check([[
local n = 123
local s = "hello"
repeat
local ns: number | string = n or s
print(ns)
until n or s
]]))

it("only closes scope after until", util.check([[
repeat
local type R = record
Expand Down
26 changes: 26 additions & 0 deletions spec/statement/while_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
local util = require("spec.util")

describe("while", function()
it("accepts a boolean", util.check([[
local b = true
while b do
print(b)
end
]]))

it("accepts a non-boolean", util.check([[
local n = 123
while n do
print(n)
end
]]))

it("while expression propagates a boolean context", util.check([[
local n = 123
local s = "hello"
while n or s do
local ns: number | string = n or s
print(ns)
end
]]))
end)
9 changes: 9 additions & 0 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10824,6 +10824,9 @@ self:expand_type(node, values, elements) })
if node.if_block_n > 1 then
self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1)
end
if node.exp then
node.exp.expected = a_type(node, "boolean", {})
end
end,
before_statements = function(self, node)
if node.exp then
Expand All @@ -10844,6 +10847,7 @@ self:expand_type(node, values, elements) })
before = function(self, node)

self:widen_all_unions(node)
node.exp.expected = a_type(node, "boolean", {})
end,
before_statements = function(self, node)
self:begin_scope(node)
Expand Down Expand Up @@ -10907,6 +10911,7 @@ self:expand_type(node, values, elements) })
before = function(self, node)

self:widen_all_unions(node)
node.exp.expected = a_type(node, "boolean", {})
end,

after = end_scope_and_none_type,
Expand Down Expand Up @@ -11767,6 +11772,10 @@ self:expand_type(node, values, elements) })
end
t = drop_constant_value(t)
end

if expected and expected.typename == "boolean" then
t = a_type(node, "boolean", {})
end
end

if t then
Expand Down
9 changes: 9 additions & 0 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -10824,6 +10824,9 @@ do
if node.if_block_n > 1 then
self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1)
end
if node.exp then
node.exp.expected = a_type(node, "boolean", {})
end
end,
before_statements = function(self: TypeChecker, node: Node)
if node.exp then
Expand All @@ -10844,6 +10847,7 @@ do
before = function(self: TypeChecker, node: Node)
-- widen all narrowed variables because we don't calculate a fixpoint yet
self:widen_all_unions(node)
node.exp.expected = a_type(node, "boolean", {})
end,
before_statements = function(self: TypeChecker, node: Node)
self:begin_scope(node)
Expand Down Expand Up @@ -10907,6 +10911,7 @@ do
before = function(self: TypeChecker, node: Node)
-- widen all narrowed variables because we don't calculate a fixpoint yet
self:widen_all_unions(node)
node.exp.expected = a_type(node, "boolean", {})
end,
-- only end scope after checking `until`, `statements` in repeat body has is_repeat == true
after = end_scope_and_none_type,
Expand Down Expand Up @@ -11767,6 +11772,10 @@ do
end
t = drop_constant_value(t)
end

if expected and expected is BooleanType then
t = a_type(node, "boolean", {})
end
end

if t then
Expand Down

0 comments on commit add0bd7

Please sign in to comment.