Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize subsetting #220

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
140 changes: 77 additions & 63 deletions lua/r/format/brackets.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@ local parsers = require("nvim-treesitter.parsers")

-- Define the Treesitter query for capturing nodes
local query = [[
(extract_operator
(identifier)
(extract_operator
(identifier)
(extract_operator
(identifier)
)*
)*
) @dollar_operator
(extract_operator) @extract_operator

(subset
(identifier)*
Expand All @@ -28,41 +20,57 @@ local query = [[
(_) )) @single_bracket)
]]

--- Build a replacement string for a given node by traversing its child nodes
---@param node userdata: The Treesitter node to traverse
--- Build a replacement string for extract_operator nodes.
--- This function formats subsetting expressions using the $ operator.
---@param node TSNode: The Treesitter node to process
---@param bufnr number: The buffer number
---@return string: The constructed replacement string
---@return table: Replacement information for the node
local function build_extract_operator_replacement(node, bufnr)
local identifiers = {}

-- Function to recursively collect identifier text
local function collect_identifiers(inner_node)
if inner_node:type() == "identifier" then
local text = vim.treesitter.get_node_text(inner_node, bufnr)
if text ~= "" then table.insert(identifiers, text) end
else
local child_count = inner_node:named_child_count()
for i = 0, child_count - 1 do
local child_node = inner_node:named_child(i)
collect_identifiers(child_node)
end
end
end
local rhs_node = node:field("rhs")[1]

if not rhs_node then return {} end

-- Start collecting identifiers from the node
collect_identifiers(node)
local rhs_text = vim.treesitter.get_node_text(rhs_node, bufnr)
local replacement_rhs = string.format('[["%s"]]', rhs_text)

local start_row_rhs, start_col_rhs, end_row_rhs, end_col_rhs = rhs_node:range()

return {
start_row = start_row_rhs,
start_col = start_col_rhs - 1,
end_row = end_row_rhs,
end_col = end_col_rhs,
text = replacement_rhs,
}
end

-- Construct the replacement string
local replacement = table.remove(identifiers, 1)
for _, id in ipairs(identifiers) do
replacement = replacement .. string.format('[["%s"]]', id)
--- Build a replacement for subset nodes captured as single-bracket expressions.
---@param node TSNode: The Treesitter node to process
---@param bufnr number: The buffer number
---@return table: Replacement information for the node
local function build_subset_replacement(node, bufnr)
local value_node = node:named_child(0)

if not value_node then return {} end

if node:named_child_count() == 1 then
local value = vim.treesitter.get_node_text(value_node, bufnr)
local replacement = string.format("[[%s]]", value)
local start_row, start_col, end_row, end_col = node:range()
return {
start_row = start_row,
start_col = start_col,
end_row = end_row,
end_col = end_col,
text = replacement,
}
end
return replacement

return {}
end

--- Formats subsetting expressions in the current buffer using Treesitter and
--- parses the buffer to find and replace specific patterns defined in a
--- Treesitter query
--- Formats subsetting expressions in the current buffer using Treesitter.
--- Parses the buffer to find and replace specific patterns defined in a Treesitter query.
---@param bufnr number: (optional) The buffer number to operate on; defaults to the current buffer if not provided
M.formatsubsetting = function(bufnr)
bufnr = bufnr or vim.api.nvim_get_current_buf()
Expand All @@ -85,38 +93,44 @@ M.formatsubsetting = function(bufnr)

local replacements = {}

-- Process extract_operator nodes
for id, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
local replacement

if query_obj.captures[id] == "dollar_operator" then
replacement = build_extract_operator_replacement(node, bufnr)
elseif query_obj.captures[id] == "single_bracket" then
local value_node = node:named_child(0)

if not value_node then return end

-- Process only if the value is not a comma. This prevents
-- processing when the brackets are used for subsetting a matrix.
-- We can verify this by checking if the node has a single child.
if node:named_child_count() == 1 then
local value = vim.treesitter.get_node_text(value_node, bufnr)
replacement = string.format("[[%s]]", value)
end
if query_obj.captures[id] == "extract_operator" then
local replacement = build_extract_operator_replacement(node, bufnr)
if next(replacement) then table.insert(replacements, replacement) end
end
end

if replacement then
local start_row, start_col, end_row, end_col = node:range()
table.insert(replacements, {
start_row = start_row,
start_col = start_col,
end_row = end_row,
end_col = end_col,
text = replacement,
})
-- Sort replacements in descending position order (last row first, then right-most column) so later text is edited first
table.sort(replacements, function(a, b)
if a.start_row == b.start_row then return a.start_col > b.start_col end
return a.start_row > b.start_row
end)

-- Apply the replacements for extract_operator
for i = 1, #replacements do
local r = replacements[i]
vim.api.nvim_buf_set_text(
bufnr,
r.start_row,
r.start_col,
r.end_row,
r.end_col,
{ r.text }
)
end

-- Clear replacements and handle subset nodes
replacements = {}

for id, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do
if query_obj.captures[id] == "single_bracket" then
local replacement = build_subset_replacement(node, bufnr)
if next(replacement) then table.insert(replacements, replacement) end
end
end

-- Apply replacements in reverse order
-- Apply the replacements for subset
for i = #replacements, 1, -1 do
local r = replacements[i]
vim.api.nvim_buf_set_text(
Expand Down
Loading