diff --git a/tl b/tl index 973c0039c..ac3693e12 100755 --- a/tl +++ b/tl @@ -892,7 +892,8 @@ do tlconfig["quiet"] = true tlconfig["gen_compat"] = "off" - local env = setup_env(tlconfig, args["file"][1]) + local filename = args["file"][1] + local env = setup_env(tlconfig, filename) env.keep_going = true env.report_types = true @@ -925,8 +926,9 @@ do local y, x = pos:match("^(%d+):?(%d*)") y = tonumber(y) or 1 x = tonumber(x) or 1 - json_out_table(io.stdout, tl.symbols_in_scope(tr, y, x)) + json_out_table(io.stdout, tl.symbols_in_scope(tr, filename, y, x)) else + tr.symbols = tr.symbols_by_file[filename] json_out_table(io.stdout, tr) end diff --git a/tl.lua b/tl.lua index c6aaae74e..0e153c951 100644 --- a/tl.lua +++ b/tl.lua @@ -628,6 +628,8 @@ local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = { + + @@ -5672,7 +5674,7 @@ function tl.new_type_reporter() tr = { by_pos = {}, types = {}, - symbols = mark_array({}), + symbols_by_file = {}, globals = {}, }, } @@ -5909,6 +5911,9 @@ function TypeReporter:store_result(collector, globals) end end + local symbols = mark_array({}) + tr.symbols_by_file[filename] = symbols + do local stack = {} @@ -5927,11 +5932,11 @@ function TypeReporter:store_result(collector, globals) else local other = stack[level] level = level - 1 - tr.symbols[other][4] = i + symbols[other][4] = i id = other - 1 end local sym = mark_array({ s.y, s.x, s.name, id }) - table.insert(tr.symbols, sym) + table.insert(symbols, sym) end end end @@ -5954,7 +5959,7 @@ end -function tl.symbols_in_scope(tr, y, x) +function tl.symbols_in_scope(tr, filename, y, x) local function find(symbols, at_y, at_x) local function le(a, b) return a[1] < b[1] or @@ -5965,9 +5970,13 @@ function tl.symbols_in_scope(tr, y, x) local ret = {} - local n = find(tr.symbols, y, x) + local symbols = tr.symbols_by_file[filename] + if not symbols then + return ret + end + + local n = find(symbols, y, x) - local symbols = tr.symbols while n >= 1 do local s = symbols[n] if s[3] == "@{" then diff --git a/tl.tl b/tl.tl index 3b31b4fd5..4a94f1db7 100644 --- a/tl.tl +++ b/tl.tl @@ -605,9 +605,11 @@ local record tl end record TypeReport + type Symbol = {integer, integer, string, integer} + by_pos: {string: {integer: {integer: integer}}} types: {integer: TypeInfo} - symbols: {{integer, integer, string, integer}} + symbols_by_file: {string: {Symbol}} globals: {string: integer} end @@ -5672,7 +5674,7 @@ function tl.new_type_reporter(): TypeReporter tr = { by_pos = {}, types = {}, - symbols = mark_array {}, + symbols_by_file = {}, globals = {}, }, } @@ -5909,6 +5911,9 @@ function TypeReporter:store_result(collector: TypeCollector, globals: {string:Va end end + local symbols: {TypeReport.Symbol} = mark_array {} + tr.symbols_by_file[filename] = symbols + -- resolve scope cross references, skipping unneeded scope blocks do local stack = {} @@ -5927,11 +5932,11 @@ function TypeReporter:store_result(collector: TypeCollector, globals: {string:Va else local other = stack[level] level = level - 1 - tr.symbols[other][4] = i -- overwrite id from @{ + symbols[other][4] = i -- overwrite id from @{ id = other - 1 end local sym = mark_array({ s.y, s.x, s.name, id }) - table.insert(tr.symbols, sym) + table.insert(symbols, sym) end end end @@ -5954,8 +5959,8 @@ end -- Report types -------------------------------------------------------------------------------- -function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} - local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer +function tl.symbols_in_scope(tr: TypeReport, filename: string, y: integer, x: integer): {string:integer} + local function find(symbols: {TypeReport.Symbol}, at_y: integer, at_x: integer): integer local function le(a: {integer, integer}, b: {integer, integer}): boolean return a[1] < b[1] or (a[1] == b[1] and a[2] <= b[2]) @@ -5965,9 +5970,13 @@ function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:in local ret: {string:integer} = {} - local n = find(tr.symbols, y, x) + local symbols = tr.symbols_by_file[filename] + if not symbols then + return ret + end + + local n = find(symbols, y, x) - local symbols = tr.symbols while n >= 1 do local s = symbols[n] if s[3] == "@{" then