diff --git a/lua/luasnip/extras/_treesitter.lua b/lua/luasnip/extras/_treesitter.lua index e33c7afd9..107b443f0 100644 --- a/lua/luasnip/extras/_treesitter.lua +++ b/lua/luasnip/extras/_treesitter.lua @@ -281,41 +281,39 @@ function TSParser:get_node_at_pos(pos) return self.parser:named_node_for_range(range) end ----@param root TSNode? -function TSParser:get_range(root) - if root == nil then - -- try first tree's root - local first_tree = self.parser:trees()[1] - if first_tree then - root = first_tree:root() - end - end - if root == nil then - return +---Get the root for the smallest tree containing `pos`. +---@param pos { [1]: number, [2]: number } +---@return TSNode? +function TSParser:root_at(pos) + local tree = self.parser:tree_for_range({pos[1], pos[2], pos[1], pos[2]}, {ignore_injections = false}) + if not tree then + return nil end - local range = { root:range() } - - return { - root = root, - start = range[1], - stop = range[3] + 1, - } + return tree:root() end ---@param match_opts LuaSnip.extra.EffectiveMatchTSNodeOpts ---@param pos { [1]: number, [2]: number } ---@return LuaSnip.extra.NamedTSMatch?, TSNode? function TSParser:match_at(match_opts, pos) - local info = self:get_range() - if info == nil then + -- Since we want to find a match to the left of pos, and if we accept there + -- has to be at least one character (I assume), we should probably not look + -- for the tree containing `pos`, since that might be the wrong one (if + -- injected languages are in play). + local root = self:root_at({pos[1], pos[2]-1}) + if root == nil then return nil, nil end + local root_from_line, _, root_to_line, _ = root:range() local query = match_opts.query local selector = match_opts.selector() local next_ts_match = - query:iter_matches(info.root, self.source, info.start, info.stop) + -- end-line is excluded by iter_matches, if the column of root_to + -- greater than 0, we would erroneously ignore a line that could + -- contain our match. + query:iter_matches(root, self.source, root_from_line, root_to_line+1) for match, node in match_opts.generator(query, next_ts_match) do -- false: don't include bytes.