Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codegen Lua 5.1 friendly string literals #762

Merged
merged 7 commits into from
Jul 22, 2024
2 changes: 1 addition & 1 deletion spec/api/process_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ describe("tl.process", function()
end)
it("should strip BOM from files", function()

local bom = "\xEF\xBB\xBF"
local bom = "\239\187\191"
local current_dir = lfs.currentdir()
local dir_name = util.write_tmp_dir(finally, {
["main.tl"] = bom .. [[
Expand Down
46 changes: 46 additions & 0 deletions spec/code_gen/string_compatability_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
local util = require("spec.util")

describe("string literal code generation", function()
it("generates Lua 5.1 compatible escape sequences in string literals", util.gen([[
local _hex_bytes = "\xDe\xAD\xbE\xef\x05"
local _unicode = "hello \u{4e16}\u{754C}"
local _whitespace_removal = "hello\z

, world!"
local _source_new_lines_get_preserved = 0
]], [[
local _hex_bytes = "\222\173\190\239\005"
local _unicode = "hello \228\184\150\231\149\140"
local _whitespace_removal = "hello, world!"


local _source_new_lines_get_preserved = 0
]], "5.1"))

it("does not substitute escape sequences in [[strings]]", util.gen([==[
local _literal_string = [[
foo
\000\xee\u{ffffff}
bar
]]
]==], [==[
local _literal_string = [[
foo
\000\xee\u{ffffff}
bar
]]
]==], "5.1"))

for _, version in ipairs { "5.1", "5.3", "5.4" } do
local source = [[local _hex = "\xaa\xbb\xcc"]]
local expected = version == "5.1"
and [[local _hex = "\170\187\204"]]
or source
it(
version == "5.1"
and "does not make substitutions when target is 5.1"
or "does make substitutions when target is not 5.1",
util.gen(source, expected, version)
)
end
end)
6 changes: 3 additions & 3 deletions spec/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,13 @@ local function gen(lax, code, expected, gen_target)
return function()
local ast, syntax_errors = tl.parse(code, "foo.tl")
assert.same({}, syntax_errors, "Code was not expected to have syntax errors")
local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target })
local result = assert(tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target, gen_compat = gen_target == "5.4" and "off" or nil }))
assert.same({}, result.type_errors)
local output_code = tl.pretty_print_ast(ast)
local output_code = tl.pretty_print_ast(ast, gen_target)

local expected_ast, expected_errors = tl.parse(expected, "foo.tl")
assert.same({}, expected_errors, "Code was not expected to have syntax errors")
local expected_code = tl.pretty_print_ast(expected_ast)
local expected_code = tl.pretty_print_ast(expected_ast, gen_target)

assert.same(expected_code, output_code)
end
Expand Down
96 changes: 78 additions & 18 deletions tl.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack
local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack; local utf8 = _tl_compat and _tl_compat.utf8 or utf8
local VERSION = "0.15.3+dev"

local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, }
Expand Down Expand Up @@ -3851,6 +3851,14 @@ function tl.pretty_print_ast(ast, gen_target, mode)
["total"] = " <const>",
}

local function emit_exactly(node, _children)
local out = { y = node.y, h = 0 }
add_string(out, node.tk)
return out
end

local emit_exactly_visitor_cbs = { after = emit_exactly }

visit_node.cbs = {
["statements"] = {
after = function(node, children)
Expand Down Expand Up @@ -4259,13 +4267,6 @@ function tl.pretty_print_ast(ast, gen_target, mode)
return out
end,
},
["variable"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
add_string(out, node.tk)
return out
end,
},
["newtype"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
Expand Down Expand Up @@ -4296,6 +4297,74 @@ function tl.pretty_print_ast(ast, gen_target, mode)
return out
end,
},
["string"] = {
after = function(node, children)






if node.tk:sub(1, 1) == "[" or gen_target ~= "5.1" then
return emit_exactly(node, children)
end

local out = { y = node.y, h = 0 }

local replaced = node.tk
for _ in replaced:gmatch("\n") do
out.h = out.h + 1
end

replaced = replaced:gsub("()\\z(%s*)", function(index_in_disguise, ws)
local index = index_in_disguise - 1
if replaced:sub(index, index) == "\\" then
return "\\z" .. ws
end
for _ in ws:gmatch("\n") do
out.h = out.h - 1
end
return ""
end)

replaced = replaced:gsub("()\\x(..)", function(index_in_disguise, digits)
local index = index_in_disguise - 1
if replaced:sub(index, index) == "\\" then
return "\\x" .. digits
end
local byte = tonumber(digits, 16)
return byte and string.format("\\%03d", byte) or "\\x" .. digits
end)

replaced = replaced:gsub("()\\u{(.-)}", function(index_in_disguise, hex_digits)
local index = index_in_disguise - 1
if replaced:sub(index, index) == "\\" then
return "\\u{" .. hex_digits .. "}"
end
local codepoint = tonumber(hex_digits, 16)
if not codepoint then
return "\\000"
end
local sequence = utf8.char(codepoint)
return (sequence:gsub(".", function(c)
return ("\\%03d"):format(string.byte(c))
end))
end)

out[1] = replaced
return out
end,
},

["variable"] = emit_exactly_visitor_cbs,
["identifier"] = emit_exactly_visitor_cbs,
["number"] = emit_exactly_visitor_cbs,
["integer"] = emit_exactly_visitor_cbs,
["nil"] = emit_exactly_visitor_cbs,
["boolean"] = emit_exactly_visitor_cbs,
["..."] = emit_exactly_visitor_cbs,
["argument"] = emit_exactly_visitor_cbs,
["type_identifier"] = emit_exactly_visitor_cbs,
}

local visit_type = {}
Expand Down Expand Up @@ -4343,15 +4412,6 @@ function tl.pretty_print_ast(ast, gen_target, mode)

visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["identifier"] = visit_node.cbs["variable"]
visit_node.cbs["number"] = visit_node.cbs["variable"]
visit_node.cbs["integer"] = visit_node.cbs["variable"]
visit_node.cbs["string"] = visit_node.cbs["variable"]
visit_node.cbs["nil"] = visit_node.cbs["variable"]
visit_node.cbs["boolean"] = visit_node.cbs["variable"]
visit_node.cbs["..."] = visit_node.cbs["variable"]
visit_node.cbs["argument"] = visit_node.cbs["variable"]
visit_node.cbs["type_identifier"] = visit_node.cbs["variable"]

local out = recurse_node(ast, visit_node, visit_type)
if err then
Expand Down Expand Up @@ -10792,7 +10852,7 @@ end


local function read_full_file(fd)
local bom = "\xEF\xBB\xBF"
local bom = "\239\187\191"
local content, err = fd:read("*a")
if content:sub(1, bom:len()) == bom then
content = content:sub(bom:len() + 1)
Expand Down
94 changes: 77 additions & 17 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -3851,6 +3851,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean |
["total"] = " <const>",
}

local function emit_exactly(node: Node, _children: {Output}): Output
local out: Output = { y = node.y, h = 0 }
add_string(out, node.tk)
return out
end

local emit_exactly_visitor_cbs <const>: VisitorCallbacks<Node, Output> = { after = emit_exactly }

visit_node.cbs = {
["statements"] = {
after = function(node: Node, children: {Output}): Output
Expand Down Expand Up @@ -4259,13 +4267,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean |
return out
end,
},
["variable"] = {
after = function(node: Node, _children: {Output}): Output
local out: Output = { y = node.y, h = 0 }
add_string(out, node.tk)
return out
end,
},
["newtype"] = {
after = function(node: Node, _children: {Output}): Output
local out: Output = { y = node.y, h = 0 }
Expand Down Expand Up @@ -4296,6 +4297,74 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean |
return out
end,
},
["string"] = {
after = function(node: Node, children: {Output}): Output
-- translate escape sequences not supported by Lua 5.1
-- in particular:
-- - \z : removes trailing whitespace
-- - \xXX : hex byte
-- - \u{} : unicode

if node.tk:sub(1, 1) == "[" or gen_target ~= "5.1" then
euclidianAce marked this conversation as resolved.
Show resolved Hide resolved
return emit_exactly(node, children)
end

local out <const>: Output = { y = node.y, h = 0 }

local replaced = node.tk
for _ in replaced:gmatch("\n") do
out.h = out.h + 1
end

replaced = replaced:gsub("()\\z(%s*)", function(index_in_disguise: string, ws: string): string
local index <const> = index_in_disguise as integer - 1
if replaced:sub(index, index) == "\\" then
return "\\z" .. ws
end
for _ in ws:gmatch("\n") do
out.h = out.h - 1
end
return ""
end)

replaced = replaced:gsub("()\\x(..)", function(index_in_disguise: string, digits: string): string
local index <const> = index_in_disguise as integer - 1
if replaced:sub(index, index) == "\\" then
return "\\x" .. digits
end
local byte <const> = tonumber(digits, 16)
return byte and string.format("\\%03d", byte) or "\\x" .. digits
end)

replaced = replaced:gsub("()\\u{(.-)}", function(index_in_disguise: string, hex_digits: string): string
local index <const> = index_in_disguise as integer - 1
if replaced:sub(index, index) == "\\" then
return "\\u{" .. hex_digits .. "}"
end
local codepoint <const> = tonumber(hex_digits, 16)
if not codepoint then
return "\\000"
end
local sequence <const> = utf8.char(codepoint)
return (sequence:gsub(".", function(c: string): string
return ("\\%03d"):format(string.byte(c))
end))
end)

out[1] = replaced
return out
end,
},

["variable"] = emit_exactly_visitor_cbs,
["identifier"] = emit_exactly_visitor_cbs,
["number"] = emit_exactly_visitor_cbs,
["integer"] = emit_exactly_visitor_cbs,
["nil"] = emit_exactly_visitor_cbs,
["boolean"] = emit_exactly_visitor_cbs,
["..."] = emit_exactly_visitor_cbs,
["argument"] = emit_exactly_visitor_cbs,
["type_identifier"] = emit_exactly_visitor_cbs,
}

local visit_type: Visitor<TypeName, Type, Output> = {}
Expand Down Expand Up @@ -4343,15 +4412,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean |

visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["identifier"] = visit_node.cbs["variable"]
visit_node.cbs["number"] = visit_node.cbs["variable"]
visit_node.cbs["integer"] = visit_node.cbs["variable"]
visit_node.cbs["string"] = visit_node.cbs["variable"]
visit_node.cbs["nil"] = visit_node.cbs["variable"]
visit_node.cbs["boolean"] = visit_node.cbs["variable"]
visit_node.cbs["..."] = visit_node.cbs["variable"]
visit_node.cbs["argument"] = visit_node.cbs["variable"]
visit_node.cbs["type_identifier"] = visit_node.cbs["variable"]

local out = recurse_node(ast, visit_node, visit_type)
if err then
Expand Down Expand Up @@ -10792,7 +10852,7 @@ end
--------------------------------------------------------------------------------

local function read_full_file(fd: FILE): string, string
local bom <const> = "\xEF\xBB\xBF"
local bom <const> = "\239\187\191" -- "\xEF\xBB\xBF"
local content, err = fd:read("*a")
if content:sub(1, bom:len()) == bom then
content = content:sub(bom:len() + 1)
Expand Down
Loading