diff --git a/lua/lspconfig/server_configurations/elmls.lua b/lua/lspconfig/server_configurations/elmls.lua index e3e63bcdb3..62be55516e 100644 --- a/lua/lspconfig/server_configurations/elmls.lua +++ b/lua/lspconfig/server_configurations/elmls.lua @@ -10,7 +10,7 @@ return { filetypes = { 'elm' }, root_dir = function(fname) local filetype = api.nvim_buf_get_option(0, 'filetype') - if filetype == 'elm' or (filetype == 'json' and fname:match 'elm%.json$') then + if util.ft_matches(filetype, 'elm') or (util.ft_matches(filetype, 'json') and fname:match 'elm%.json$') then return elm_root_pattern(fname) end end, diff --git a/lua/lspconfig/util.lua b/lua/lspconfig/util.lua index 2fb5e72ba5..1c0707cf5c 100644 --- a/lua/lspconfig/util.lua +++ b/lua/lspconfig/util.lua @@ -22,6 +22,21 @@ M.default_config = { -- global on_setup hook M.on_setup = nil +---@param filetype string the filetype to check (can be a compound, dot-separated filetype; see |'filetype'|) +---@param expected string|string[] the filetype(s) to match against +---@return boolean +function M.ft_matches(filetype, expected) + expected = type(expected) == 'table' and expected or { expected } + for ft in filetype:gmatch '([^.]+)' do + for _, expected_ft in ipairs(expected) do + if ft == expected_ft then + return true + end + end + end + return false +end + function M.bufname_valid(bufname) if bufname:match '^/' or bufname:match '^[a-zA-Z]:' or bufname:match '^zipfile://' or bufname:match '^tarfile:' then return true @@ -348,10 +363,8 @@ function M.get_active_clients_list_by_ft(filetype) local clients_list = {} for _, client in pairs(clients) do local filetypes = client.config.filetypes or {} - for _, ft in pairs(filetypes) do - if ft == filetype then - table.insert(clients_list, client.name) - end + if M.ft_matches(filetype, filetypes) then + table.insert(clients_list, client.name) end end return clients_list @@ -364,10 +377,8 @@ function M.get_other_matching_providers(filetype) for _, config in pairs(configs) do if not vim.tbl_contains(active_clients_list, config.name) then local filetypes = config.filetypes or {} - for _, ft in pairs(filetypes) do - if ft == filetype then - table.insert(other_matching_configs, config) - end + if M.ft_matches(filetype, filetypes) then + table.insert(other_matching_configs, config) end end end @@ -379,10 +390,8 @@ function M.get_config_by_ft(filetype) local matching_configs = {} for _, config in pairs(configs) do local filetypes = config.filetypes or {} - for _, ft in pairs(filetypes) do - if ft == filetype then - table.insert(matching_configs, config) - end + if M.ft_matches(filetype, filetypes) then + table.insert(matching_configs, config) end end return matching_configs