diff --git a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb index 8f8b54b54f..c0d3bd90b3 100644 --- a/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb +++ b/lib/ruby_indexer/lib/ruby_indexer/reference_finder.rb @@ -35,10 +35,14 @@ class InstanceVariableTarget < Target #: String attr_reader :name - #: (String name) -> void - def initialize(name) + #: Array[String] + attr_reader :receiver_ancestors + + #: (String name, Array[String] receiver_ancestors) -> void + def initialize(name, receiver_ancestors) super() @name = name + @receiver_ancestors = receiver_ancestors end end @@ -321,7 +325,10 @@ def collect_constant_references(name, location) def collect_instance_variable_references(name, location, declaration) return unless @target.is_a?(InstanceVariableTarget) && name == @target.name - @references << Reference.new(name, location, declaration: declaration) + receiver_type = @stack.join("::") + if @target.receiver_ancestors.include?(receiver_type) + @references << Reference.new(name, location, declaration: declaration) + end end end end diff --git a/lib/ruby_indexer/test/reference_finder_test.rb b/lib/ruby_indexer/test/reference_finder_test.rb index 4f11f3be4f..31a7521697 100644 --- a/lib/ruby_indexer/test/reference_finder_test.rb +++ b/lib/ruby_indexer/test/reference_finder_test.rb @@ -216,22 +216,43 @@ def foo assert_equal(11, refs[2].location.start_line) end - def test_finds_instance_variable_read_references - refs = find_instance_variable_references("@foo", <<~RUBY) + def test_finds_instance_variable_references + refs = find_instance_variable_references("@name", ["Foo"], <<~RUBY) class Foo - def foo - @foo + def initialize + @name = "foo" + end + def name + @name + end + def name_capital + @name[0] + end + end + + class Bar + def initialize + @name = "foo" + end + def name + @name end end RUBY - assert_equal(1, refs.size) + assert_equal(3, refs.size) - assert_equal("@foo", refs[0].name) + assert_equal("@name", refs[0].name) assert_equal(3, refs[0].location.start_line) + + assert_equal("@name", refs[1].name) + assert_equal(6, refs[1].location.start_line) + + assert_equal("@name", refs[2].name) + assert_equal(9, refs[2].location.start_line) end def test_finds_instance_variable_write_references - refs = find_instance_variable_references("@foo", <<~RUBY) + refs = find_instance_variable_references("@foo", ["Foo"], <<~RUBY) class Foo def write @foo = 1 @@ -252,26 +273,61 @@ def write assert_equal(7, refs[4].location.start_line) end - def test_finds_instance_variable_references_ignore_context - refs = find_instance_variable_references("@name", <<~RUBY) - class Foo + def test_finds_instance_variable_references_in_receiver_ancestors + refs = find_instance_variable_references("@name", ["Foo", "Base", "Parent"], <<~RUBY) + module Base + def change_name(name) + @name = name + end def name + @name + end + end + + class Parent + def initialize + @name = "parent" + end + def name_capital + @name[0] + end + end + + class Foo < Parent + include Base + def initialize @name = "foo" end + def name + @name + end end + class Bar def name @name = "bar" end end RUBY - assert_equal(2, refs.size) + assert_equal(6, refs.size) assert_equal("@name", refs[0].name) assert_equal(3, refs[0].location.start_line) assert_equal("@name", refs[1].name) - assert_equal(8, refs[1].location.start_line) + assert_equal(6, refs[1].location.start_line) + + assert_equal("@name", refs[2].name) + assert_equal(12, refs[2].location.start_line) + + assert_equal("@name", refs[3].name) + assert_equal(15, refs[3].location.start_line) + + assert_equal("@name", refs[4].name) + assert_equal(22, refs[4].location.start_line) + + assert_equal("@name", refs[5].name) + assert_equal(25, refs[5].location.start_line) end def test_accounts_for_reopened_classes @@ -310,8 +366,8 @@ def find_method_references(method_name, source) find_references(target, source) end - def find_instance_variable_references(instance_variable_name, source) - target = ReferenceFinder::InstanceVariableTarget.new(instance_variable_name) + def find_instance_variable_references(instance_variable_name, receiver_ancestors, source) + target = ReferenceFinder::InstanceVariableTarget.new(instance_variable_name, receiver_ancestors) find_references(target, source) end diff --git a/lib/ruby_lsp/requests/references.rb b/lib/ruby_lsp/requests/references.rb index a2593bcf0c..1271ad0546 100644 --- a/lib/ruby_lsp/requests/references.rb +++ b/lib/ruby_lsp/requests/references.rb @@ -115,7 +115,11 @@ def create_reference_target(target_node, node_context) Prism::InstanceVariableReadNode, Prism::InstanceVariableTargetNode, Prism::InstanceVariableWriteNode - RubyIndexer::ReferenceFinder::InstanceVariableTarget.new(target_node.name.to_s) + receiver_type = @global_state.type_inferrer.infer_receiver_type(node_context) + return unless receiver_type + + ancestors = @global_state.index.linearized_ancestors_of(receiver_type.name) + RubyIndexer::ReferenceFinder::InstanceVariableTarget.new(target_node.name.to_s, ancestors) when Prism::CallNode, Prism::DefNode RubyIndexer::ReferenceFinder::MethodTarget.new(target_node.name.to_s) end