diff --git a/tools/chapel-py/src/method-tables/uast-methods.h b/tools/chapel-py/src/method-tables/uast-methods.h index 2d6b1aa6b004..2cb39e794396 100644 --- a/tools/chapel-py/src/method-tables/uast-methods.h +++ b/tools/chapel-py/src/method-tables/uast-methods.h @@ -66,6 +66,8 @@ CLASS_BEGIN(AstNode) std::optional, return ScopeObject::tryCreate(contextObject, resolution::scopeForId(context, node->id()))) + PLAIN_GETTER(AstNode, creates_scope, "Returns true if this AST node creates a scope", + bool, return chpl::resolution::createsScope(node->tag())) PLAIN_GETTER(AstNode, type, "Get the type of this AST node, as a 3-tuple of (kind, type, param).", std::optional, diff --git a/tools/chpl-language-server/src/chpl-language-server.py b/tools/chpl-language-server/src/chpl-language-server.py index b526143459da..8b1041f65368 100755 --- a/tools/chpl-language-server/src/chpl-language-server.py +++ b/tools/chpl-language-server/src/chpl-language-server.py @@ -143,6 +143,12 @@ import argparse import configargparse +import sys +import functools + + +def log(*args, **kwargs): + print(*args, **kwargs, file=sys.stderr) class ChplcheckProxy: @@ -229,12 +235,12 @@ def decl_kind(decl: chapel.NamedDecl) -> Optional[SymbolKind]: return SymbolKind.Method else: return SymbolKind.Function - elif isinstance(decl, chapel.Variable): + elif isinstance(decl, chapel.VarLikeDecl): if decl.intent() == "type": return SymbolKind.TypeParameter elif decl.intent() == "param": return SymbolKind.Constant - elif decl.is_field(): + elif isinstance(decl, chapel.Variable) and decl.is_field(): return SymbolKind.Field elif decl.intent() == "": return SymbolKind.Constant @@ -264,28 +270,26 @@ def decl_kind_to_completion_kind(kind: SymbolKind) -> CompletionItemKind: def completion_item_for_decl( - decl: chapel.NamedDecl, override_name: Optional[str] = None + decl: chapel.NamedDecl, + override_name: Optional[str] = None, + override_sort: Optional[str] = None, ) -> Optional[CompletionItem]: kind = decl_kind(decl) if not kind: return None - # For now, we show completion for global symbols (not x.), - # so it seems like we ought to rule out methods. - if kind == SymbolKind.Method: - return None - # We don't want to show operators in completion lists, as they're # not really useful to the user in this context. if kind == SymbolKind.Operator: return None name_to_use = override_name if override_name else decl.name() + sort_text = override_sort if override_sort else name_to_use return CompletionItem( label=name_to_use, kind=decl_kind_to_completion_kind(kind), insert_text=name_to_use, - sort_text=name_to_use, + sort_text=sort_text, ) @@ -427,6 +431,27 @@ class ResolvedPair: resolved_to: NodeAndRange +@dataclass +class ScopedNodeAndRange: + node: chapel.AstNode + scopes: List[chapel.Scope] = field(default_factory=list) + + @staticmethod + def create(node: chapel.AstNode) -> Optional["ScopedNodeAndRange"]: + scopes = [] + scope = node.scope() + while scope: + scopes.append(scope) + scope = scope.parent_scope() + if len(scopes) == 0: + return None + return ScopedNodeAndRange(node, scopes) + + @property + def rng(self): + return location_to_range(self.node.location()) + + @dataclass class References: in_file: "FileInfo" @@ -579,6 +604,7 @@ class FileInfo: use_resolver: bool use_segments: PositionList[ResolvedPair] = field(init=False) def_segments: PositionList[NodeAndRange] = field(init=False) + scope_segments: PositionList[ScopedNodeAndRange] = field(init=False) instantiation_segments: PositionList[ Tuple[NodeAndRange, chapel.TypedSignature] ] = field(init=False) @@ -588,11 +614,11 @@ class FileInfo: Dict[chapel.TypedSignature, CallsInTypeContext], ] = field(init=False) siblings: chapel.SiblingMap = field(init=False) - visible_decls: List[Tuple[str, chapel.AstNode]] = field(init=False) def __post_init__(self): self.use_segments = PositionList(lambda x: x.ident.rng) self.def_segments = PositionList(lambda x: x.rng) + self.scope_segments = PositionList(lambda x: x.rng) self.instantiation_segments = PositionList(lambda x: x[0].rng) self.uses_here = {} self.rebuild_index() @@ -637,6 +663,18 @@ def _note_reference(self, node: Union[chapel.Dot, chapel.Identifier]): ResolvedPair(NodeAndRange(node), NodeAndRange(to)) ) + def _note_scope(self, node: chapel.AstNode): + if not node.creates_scope(): + return + s = ScopedNodeAndRange.create(node) + if not s: + return + self.scope_segments.append(s) + + @enter + def _enter_AstNode(self, node: chapel.AstNode): + self._note_scope(node) + @enter def _enter_Identifier(self, node: chapel.Identifier): self._note_reference(node) @@ -651,6 +689,7 @@ def _enter_Module(self, node: chapel.Module): _ = node.scope_resolve() self.def_segments.append(NodeAndRange(node)) + self._note_scope(node) @enter def _enter_Function(self, node: chapel.Function): @@ -658,66 +697,143 @@ def _enter_Function(self, node: chapel.Function): _ = node.scope_resolve() self.def_segments.append(NodeAndRange(node)) + self._note_scope(node) @enter def _enter_NamedDecl(self, node: chapel.NamedDecl): self.def_segments.append(NodeAndRange(node)) + self._note_scope(node) - def _collect_possibly_visible_decls(self, asts: List[chapel.AstNode]): - self.visible_decls = [] - for ast in asts: - if isinstance(ast, chapel.Comment): - continue + def get_visible_nodes( + self, pos: Position + ) -> List[Tuple[str, chapel.AstNode, int]]: + """ + Returns the visible nodes at a given position. + """ - scope = ast.scope() - if not scope: - continue + def visible_nodes_for_scope( + name: str, nodes: List[chapel.AstNode], in_bundled_module: bool + ) -> Optional[Tuple[str, chapel.AstNode]]: + """ + Narrow the list of visible nodes to those that are actually visible + + The heuristic here is to avoid showing internal symbols to the user, + i.e. those that start with 'chpl_' or '_'. We also avoid showing nodes + with the @chpldoc.nodoc attribute. + """ + # Don't show internal symbols to the user, even if they + # are technically in scope. The exception is if we're currently + # editing a standard file. + skip_prefixes = ["chpl_", "chpldev_", "_"] + if any(name.startswith(prefix) for prefix in skip_prefixes): + if not in_bundled_module: + return None + + # Only show nodes without @chpldoc.nodoc. The exception + # about standard files applies here too. + documented_nodes = [] + for node in nodes: + # apply aforementioned exception + if in_bundled_module: + documented_nodes.append(node) + continue + + # avoid nodes with nodoc attribute. + ag = node.attribute_group() + show = False + if not ag or not ag.get_attribute_named("chpldoc.nodoc"): + show = True + elif name in _ALLOWED_NODOC_DECLS: + # If users declare variables like 'here' themselves, + # we will not show them if they're @chpldoc.nodoc, + # since they're not special. + decl_file = node.location().path() + is_standard_decl = self.context.context.is_bundled_path( + decl_file + ) + show = is_standard_decl - file = ast.location().path() - in_bundled_module = self.context.context.is_bundled_path(file) + if show: + documented_nodes.append(node) - for name, nodes in scope.visible_nodes(): - # Don't show internal symbols to the user, even if they - # are technically in scope. The exception is if we're currently - # editing a standard file. - skip_prefixes = ["chpl_", "chpldev_", "_"] - if any(name.startswith(prefix) for prefix in skip_prefixes): - if not in_bundled_module: - continue + if len(documented_nodes) == 0: + return None - # Only show nodes without @chpldoc.nodoc. The exception - # about standard files applies here too. - documented_nodes = [] - for node in nodes: - # apply aforementioned exception - if in_bundled_module: - documented_nodes.append(node) + # Just take the first value to avoid showing N entries for + # overloaded functions. + return name, documented_nodes[0] + + @functools.cache + def files_named_in_use_or_import(scope: chapel.Scope) -> Set[str]: + files = set() + for m in scope.modules_named_in_use_or_import(): + files.add(m.location().path()) + return files + + def apply_depth_heuristic( + scope: chapel.Scope, + name: str, + node: chapel.AstNode, + original_depth: int, + cur_file: str, + ) -> Tuple[str, chapel.AstNode, int]: + """ + Heuristic to provide results in a more useful order, since + most clients will sort alphabetically. We can provide a + depth that is used to sort the results, so that the most + relevant results are shown first. + """ + depth = original_depth + vis_path = node.location().path() + if vis_path != cur_file: + # if from a different file, increase the depth by 1 + depth += 1 + # if from a bundled path increase the depth by 1 + depth += int(self.context.context.is_bundled_path(vis_path)) + # if not explicitly used, increase the depth by 1 + files_named_in_use = files_named_in_use_or_import(scope) + depth += int(vis_path not in files_named_in_use) + return (name, node, depth) + + def visible_nodes_for_scopes( + node: chapel.AstNode, scopes: List[chapel.Scope] + ): + visible_nodes = [] + cur_file = node.location().path() + in_bundled_module = self.context.context.is_bundled_path(cur_file) + # for each scope of the node + for depth, scope in enumerate(scopes): + # for all of the visible nodes in the scope + for name, nodes in scope.visible_nodes(): + # narrow the list of visible nodes to those that are + # actually visible to the user (i.e. not nodoc/internal) + visible_node = visible_nodes_for_scope( + name, nodes, in_bundled_module + ) + if visible_node is None: continue + vn = apply_depth_heuristic( + scope, *visible_node, depth, cur_file + ) + visible_nodes.append(vn) + return visible_nodes - # avoid nodes with nodoc attribute. - ag = node.attribute_group() - show = False - if not ag or not ag.get_attribute_named("chpldoc.nodoc"): - show = True - elif name in _ALLOWED_NODOC_DECLS: - # If users declare variables like 'here' themselves, - # we will not show them if they're @chpldoc.nodoc, - # since they're not special. - decl_file = node.location().path() - is_standard_decl = self.context.context.is_bundled_path( - decl_file - ) - show = is_standard_decl - - if show: - documented_nodes.append(node) - - if len(documented_nodes) == 0: + visible_nodes = [] + segment = self.scope_at_position(pos) + + if segment: + vns = visible_nodes_for_scopes(segment.node, segment.scopes) + visible_nodes.extend(vns) + else: + # no segment found, use the top level nodes + for a in self.get_asts(): + if isinstance(a, chapel.Comment): continue + s = a.scope() + if s: + visible_nodes.extend(visible_nodes_for_scopes(a, [s])) - # Just take the first value to avoid showing N entries for - # overloaded functions. - self.visible_decls.append((name, documented_nodes[0])) + return visible_nodes def _search_instantiations( self, @@ -769,12 +885,12 @@ def rebuild_index(self): refs.clear() self.use_segments.clear() self.def_segments.clear() + self.scope_segments.clear() self.visit(asts) self.use_segments.sort() self.def_segments.sort() self.siblings = chapel.SiblingMap(asts) - self._collect_possibly_visible_decls(asts) if self.use_resolver: # TODO: suppress resolution errors due to false-positives @@ -863,6 +979,20 @@ def get_use_or_def_segment_at_position( return None + def scope_at_position( + self, position: Position + ) -> Optional[ScopedNodeAndRange]: + """ + Given a position, return the scope that contains it. + """ + found = None + for s in self.scope_segments.elts: + if s.rng.start <= position <= s.rng.end: + found = s + if s.rng.start > position: + break + return found + def file_lines(self) -> List[str]: file_text = self.context.context.get_file_text( self.uri[len("file://") :] @@ -1723,18 +1853,32 @@ async def hover(ls: ChapelLanguageServer, params: HoverParams): content = MarkupContent(MarkupKind.Markdown, text) return Hover(content, range=segment.get_location().range) + # TODO: can we make use of 'trigger_character' to provide completions? + # since we can't parse 'foo.', can we use the presence of a trigger '.' to + # read a identifier from the file buffer, lookup the scope for that name, + # and provide completions based on that scope? @server.feature(TEXT_DOCUMENT_COMPLETION, CompletionOptions()) async def complete(ls: ChapelLanguageServer, params: CompletionParams): text_doc = ls.workspace.get_text_document(params.text_document.uri) fi, _ = ls.get_file_info(text_doc.uri) + names = set() items = [] - for name, decl in fi.visible_decls: - if isinstance(decl, chapel.NamedDecl): - items.append(completion_item_for_decl(decl, override_name=name)) - - items = [item for item in items if item] + for name, node, depth in fi.get_visible_nodes(params.position): + if not isinstance(node, chapel.NamedDecl): + continue + # if name is already suggested, skip it + if name in names: + continue + # use the depth to sort the suggestions, lower depths first + sort_name = f"{depth:03d}{name}" + item = completion_item_for_decl( + node, override_name=name, override_sort=sort_name + ) + if item: + items.append(item) + names.add(name) return CompletionList(is_incomplete=False, items=items) diff --git a/tools/chpl-language-server/test/completion.py b/tools/chpl-language-server/test/completion.py new file mode 100644 index 000000000000..c795a2821730 --- /dev/null +++ b/tools/chpl-language-server/test/completion.py @@ -0,0 +1,196 @@ +""" +Test that completion works properly +""" + +import sys + +from lsprotocol.types import ClientCapabilities +from lsprotocol.types import Position, TextDocumentIdentifier +from lsprotocol.types import CompletionItem, CompletionList, CompletionParams +from lsprotocol.types import InitializeParams +import pytest +import pytest_lsp +from pytest_lsp import ClientServerConfig, LanguageClient + +from util.utils import * +from util.config import CLS_PATH + + +@pytest_lsp.fixture( + config=ClientServerConfig( + server_command=[sys.executable, CLS_PATH()], + client_factory=get_base_client, + ) +) +async def client(lsp_client: LanguageClient): + # Setup + params = InitializeParams(capabilities=ClientCapabilities()) + await lsp_client.initialize_session(params) + + yield + + # Teardown + await lsp_client.shutdown_session() + + +async def check_completion_items( + client: LanguageClient, + doc: TextDocumentIdentifier, + pos: Position, + expected: typing.List[str], +) -> typing.List[CompletionItem]: + """ + Check that the names returned by the completion items match what is expected. + + Expected is a list strings of length N that should match the first N completion items. + """ + items = await client.text_document_completion_async( + params=CompletionParams(text_document=doc, position=pos) + ) + assert items is not None + items = items.items if isinstance(items, CompletionList) else items + assert len(items) >= len(expected) + + sorted_items = sorted(items, key=lambda x: (x.sort_text or x.label).lower()) + print(" ") + for expect, item in zip(expected, sorted_items): + print(f"Expected: {expect}, Actual: {item.label}") + assert item.label == expect + return sorted_items + + +@pytest.mark.asyncio +async def test_empty(client: LanguageClient): + """ + Test that an empty file returns something + """ + test = ";" + async with source_file(client, test) as doc: + items = await check_completion_items(client, doc, pos((0, 0)), []) + assert len(items) > 0 + + +@pytest.mark.asyncio +async def test_basic_completion(client: LanguageClient): + """ + Test basic features + """ + file = """ + var myGlobal: int; + record R { + var abc: int; + proc foo() { } + proc bar(myFormal) { + var localVar = 2; + forall myIndex in 1..10 + with (var myTaskPrivate = abc, var myGlobal = 1) { + + } + begin { + + } + } + } + var a = new R(); + """ + + async with source_file(client, file) as doc: + global_scope = (pos((0, 0)), ["a", "myGlobal", "R"]) + r_scope = (pos((2, 0)), ["abc", "bar", "foo"] + global_scope[1]) + bar_scope = (pos((5, 0)), ["localVar", "myFormal", "this"] + r_scope[1]) + + # forall_scope changes where myGlobal is + forall_parent_scope_elms = bar_scope[1].copy() + forall_parent_scope_elms.remove("myGlobal") + forall_scope = ( + pos((8, 0)), + ["myGlobal", "myIndex", "myTaskPrivate"] + forall_parent_scope_elms, + ) + + begin_scope = (pos((11, 0)), bar_scope[1]) + + expected = [global_scope, r_scope, bar_scope, forall_scope, begin_scope] + for p, exp in expected: + await check_completion_items(client, doc, p, exp) + + +@pytest.mark.asyncio +async def test_std_lib(client: LanguageClient): + """ + Test that modules can be imported and used + """ + A = """ + use IO; + use B; + var mySymbol = 1; + """ + B = """ + proc foo() { } + """ + + async with source_files(client, A=A, B=B) as docs: + a_scope = (pos((0, 0)), ["mySymbol", "foo"]) + b_scope = (pos((1, 0)), ["foo"]) + + a_completion = await check_completion_items( + client, docs("A"), a_scope[0], a_scope[1] + ) + # a should have IO and writef somewhere in the list + assert any( + [x.label == "IO" or x.label == "writef" for x in a_completion] + ) + + await check_completion_items(client, docs("B"), b_scope[0], b_scope[1]) + + await save_file(client, docs("A"), docs("B")) + assert len(client.diagnostics[docs("A").uri]) == 0 + assert len(client.diagnostics[docs("B").uri]) == 0 + + +@pytest.mark.asyncio +async def test_use_in_module(client: LanguageClient): + """ + Test that modules can be imported and used inside of a module + """ + + file = """ + module file { + var mySymbol = 1; + module M { + use Random; + } + } + """ + + async with source_file(client, file) as doc: + file_scope = (pos((0, 0)), ["M", "mySymbol"]) + m_scope = (pos((3, 0)), ["M", "mySymbol"]) + + await check_completion_items(client, doc, file_scope[0], file_scope[1]) + m_scope_items = await check_completion_items( + client, doc, m_scope[0], m_scope[1] + ) + assert any([x.label == "Random" for x in m_scope_items]) + + +@pytest.mark.asyncio +async def test_use_in_scope(client: LanguageClient): + """ + Test that modules can be imported and used inside of an arbitrary scope + """ + + file = """ + proc bar() { + use Sort; + } + ; + """ + + async with source_file(client, file) as doc: + # bar and Sort are at the same "depth", so we can't rely on ordering + items = await check_completion_items(client, doc, pos((0, 0)), []) + assert any([x.label == "bar" or x.label == "sort" for x in items]) + + # outside of the scope, we should only see bar + items = await check_completion_items(client, doc, pos((3, 0)), ["bar"]) + assert all([x.label != "sort" for x in items])