Skip to content

Commit

Permalink
fix(language_server): improved completions for builtins and globals
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNuclearNexus committed Jan 31, 2025
1 parent b42e876 commit a6a785e
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 57 deletions.
120 changes: 81 additions & 39 deletions language_server/server/features/completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import builtins
from functools import reduce
import inspect
import logging
import types
from typing import Any

from bolt import AstAttribute, AstIdentifier, UndefinedIdentifier, Variable
from bolt import AstAttribute, AstIdentifier, Runtime, UndefinedIdentifier, Variable
from lsprotocol import types as lsp
from mecha import (
AstItemSlot,
Expand All @@ -20,6 +25,7 @@
from ..indexing import get_type_annotation
from ..utils.reflection import (
UNKNOWN_TYPE,
FunctionInfo,
format_function_hints,
get_name_of_type,
get_type_info,
Expand Down Expand Up @@ -107,7 +113,7 @@ def get_completions(

items = []
if len(diagnostics) > 0:
items = get_diag_completions(pos, mecha, diagnostics)
items = get_diag_completions(pos, mecha, ctx.inject(Runtime), diagnostics)
elif ast is not None:
current_node = get_node_at_position(ast, pos)

Expand Down Expand Up @@ -149,6 +155,30 @@ def add_registry_items(
)


def get_variable_description(name: str, value: Any):
doc_string = "\n---\n" + value.__doc__ if value.__doc__ is not None else ""
return f"```python\n(variable) {name}: {get_name_of_type(type(value))}\n```{doc_string}"


def get_class_description(name: str, value: type):
doc_string = "\n---\n" + value.__doc__ if value.__doc__ is not None else ""

return f"```python\nclass {name}()\n```{doc_string}"


def get_function_description(name: str, function: Any):
function_info = None
if isinstance(function, FunctionInfo):
function_info = function
else:
function_info = FunctionInfo.extract(function)

doc_string = "\n---\n" + function_info.doc if function_info.doc is not None else ""


return f"```py\n{format_function_hints(name, function_info)}\n```{doc_string}"


def get_bolt_completions(
node: AstNode,
items: list[lsp.CompletionItem],
Expand All @@ -166,22 +196,7 @@ def get_bolt_completions(
logging.debug(type_info)

for name, type in type_info.fields.items():
kind = (
lsp.CompletionItemKind.Property
if not name.isupper()
else lsp.CompletionItemKind.Constant
)

items.append(
lsp.CompletionItem(
name,
kind=kind,
documentation=lsp.MarkupContent(
kind=lsp.MarkupKind.Markdown,
value=f"```py\n{name}: {get_name_of_type(type)}\n```\n{type.__doc__ or ''}",
),
)
)
add_variable_completion(items, name, type)

for name, function_info in type_info.functions.items():
items.append(
Expand All @@ -190,14 +205,14 @@ def get_bolt_completions(
kind=lsp.CompletionItemKind.Function,
documentation=lsp.MarkupContent(
kind=lsp.MarkupKind.Markdown,
value=f"```py\n(function) {format_function_hints(name, function_info)}\n```\n{function_info.doc or ''}",
value=get_function_description(name, function_info),
),
)
)


def get_diag_completions(
pos: lsp.Position, mecha: Mecha, diagnostics: list[InvalidSyntax]
pos: lsp.Position, mecha: Mecha, runtime: Runtime, diagnostics: list[InvalidSyntax]
):
items = []
for diagnostic in diagnostics:
Expand All @@ -219,35 +234,62 @@ def get_diag_completions(

if isinstance(diagnostic, UndefinedIdentifier):
for name, variable in diagnostic.lexical_scope.variables.items():
add_variable(items, name, variable)
add_variable_definition(items, name, variable)

for name, value in runtime.globals.items():
add_raw_definition(items, name, value)

for name in runtime.builtins:
add_raw_definition(items, name, getattr(builtins, name))

break
return items


def add_variable(items: list[lsp.CompletionItem], name: str, variable: Variable):

def add_variable_definition(
items: list[lsp.CompletionItem], name: str, variable: Variable
):
possible_types = set()
documentation = None

for binding in variable.bindings:
origin = binding.origin
if type_annotations := origin.__dict__.get("type_annotations"):
for t in type_annotations:
match t:
case AstIdentifier() as identifer:
possible_types.add(identifer.value)
case type() as _type:
possible_types.add(_type.__name__)
case _:
possible_types.add(str(type_annotations))
if annotation := get_type_annotation(origin):
possible_types.add(annotation)

if len(possible_types) > 0:
description = f"```python\n{name}: {' | '.join(possible_types)}\n```"
documentation = lsp.MarkupContent(lsp.MarkupKind.Markdown, description)
_type = reduce(lambda a, b: a | b, possible_types)
add_variable_completion(items, name, _type)

items.append(
lsp.CompletionItem(
name, documentation=documentation, kind=lsp.CompletionItemKind.Variable
)

def add_raw_definition(items: list[lsp.CompletionItem], name: str, value: Any):
if inspect.isclass(value):
add_class_completion(items, name, value)
elif inspect.isfunction(value) or inspect.isbuiltin(value):
add_function_completion(items, name, value)
else:
add_variable_completion(items, name, type(value))


def add_class_completion(items: list[lsp.CompletionItem], name: str, _type):
description = get_class_description(name, _type)
documentation = lsp.MarkupContent(lsp.MarkupKind.Markdown, description)

items.append(lsp.CompletionItem(name, documentation=documentation, kind=lsp.CompletionItemKind.Class))

def add_function_completion(items: list[lsp.CompletionItem], name: str, function):
description = get_function_description(name, function)
documentation = lsp.MarkupContent(lsp.MarkupKind.Markdown, description)

items.append(lsp.CompletionItem(name, documentation=documentation, kind=lsp.CompletionItemKind.Function))

def add_variable_completion(items: list[lsp.CompletionItem], name: str, _type):
kind = (
lsp.CompletionItemKind.Property
if not name.isupper()
else lsp.CompletionItemKind.Constant
)

description = get_variable_description(name, _type)
documentation = lsp.MarkupContent(lsp.MarkupKind.Markdown, description)

items.append(lsp.CompletionItem(name, documentation=documentation, kind=kind))
6 changes: 4 additions & 2 deletions language_server/server/features/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from dataclasses import dataclass, field
from types import NoneType
from typing import Any, Literal, Union, get_args
from typing import Any, Literal, Union, get_args, get_origin

from beet import Context
from beet.core.utils import required_field
Expand Down Expand Up @@ -258,8 +258,10 @@ def assignment(self, assignment: AstAssignment):
def generic_identifier(self, identifier: Any):
annotation = get_type_annotation(identifier)

if annotation is not None and inspect.isfunction(annotation):
if annotation is not None and (inspect.isfunction(annotation) or inspect.isbuiltin(annotation)):
self.nodes.append((identifier, TOKEN_TYPES["function"], 0))
elif annotation is not None and get_origin(annotation) is type:
self.nodes.append((identifier, TOKEN_TYPES["class"], 0))
else:
kind = TOKEN_TYPES["variable"]
modifiers = 0
Expand Down
6 changes: 3 additions & 3 deletions language_server/server/features/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,17 @@ def parse_function(
dependents = set()

compiled_module = None
runtime = ctx.inject(Runtime)
if location in COMPILATION_RESULTS:
prev_compilation = COMPILATION_RESULTS[location]
dependents = prev_compilation.dependents
compiled_module = prev_compilation.compiled_module

if len(diagnostics) == 0 and Module in ctx.data.extend_namespace:
runtime = ctx.inject(Runtime)

if fresh_module := runtime.modules.get(function):
fresh_module.ast = index_function_ast(
fresh_module.ast, location, fresh_module
fresh_module.ast, location, runtime, fresh_module
)

for dependency in fresh_module.dependencies:
Expand All @@ -224,7 +224,7 @@ def parse_function(

compiled_module = fresh_module

ast = index_function_ast(ast, location, module=compiled_module)
ast = index_function_ast(ast, location, runtime=runtime, module=compiled_module)
logging.debug(compiled_module)

for dependent in dependents:
Expand Down
58 changes: 49 additions & 9 deletions language_server/server/indexing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import builtins
import inspect
import logging
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Optional, TypeVar

from beet.core.utils import extra_field
from beet.core.utils import extra_field, required_field
from bolt import (
AstAssignment,
AstAttribute,
Expand All @@ -18,6 +19,7 @@
AstTuple,
AstValue,
CompiledModule,
Runtime,
)
from mecha import (
AstBlock,
Expand Down Expand Up @@ -85,16 +87,43 @@ def was_referenced(references: list[AstNode], identifier: AstNode):
return True


def get_referenced_type(module: CompiledModule, identifier: AstIdentifier):
def is_builtin(identifier: AstIdentifier):
if identifier.value.startswith("_"):
return None

return (
getattr(builtins, identifier.value)
if hasattr(builtins, identifier.value)
else None
)


def annotate_types(annotation):
if isinstance(annotation, type):
return type[annotation]
elif inspect.isfunction(annotation) or inspect.isbuiltin(annotation):
return annotation
else:
return type(annotation)


def get_referenced_type(
runtime: Runtime, module: CompiledModule, identifier: AstIdentifier
):
var_name = identifier.value
defined_variables = module.lexical_scope.variables

if variable := defined_variables.get(var_name):
for binding in variable.bindings:
if was_referenced(binding.references, identifier) and (
type_annotation := get_type_annotation(binding.origin)
annotation := get_type_annotation(binding.origin)
):
return type_annotation
return annotation
elif identifier.value in module.globals:
return annotate_types(runtime.globals[identifier.value])

elif annotation := is_builtin(identifier):
return annotate_types(annotation)

return UNKNOWN_TYPE

Expand All @@ -103,12 +132,15 @@ def get_referenced_type(module: CompiledModule, identifier: AstIdentifier):


def index_function_ast(
ast: T, function_location: str, module: CompiledModule | None = None
ast: T,
function_location: str,
runtime: Runtime | None = None,
module: CompiledModule | None = None,
) -> T:
resolve_paths(ast, path="/".join(function_location.split(":")[1:]))
try:
initial_values = InitialValues()
bindings = Bindings(module=module)
bindings = Bindings(module=module, runtime=runtime)

return bindings(initial_values(ast))
except Exception as e:
Expand Down Expand Up @@ -173,16 +205,24 @@ def value(self, value: AstValue):

@dataclass
class Bindings(Reducer):
module: Optional[CompiledModule] = extra_field(default=None)
module: Optional[CompiledModule] = required_field()
runtime: Optional[Runtime] = required_field()

@rule(AstIdentifier)
def identifier(self, identifier):
logging.debug(identifier)
logging.debug(self.module)
if get_type_annotation(identifier) or self.module is None:
logging.debug(self.runtime)
if (
get_type_annotation(identifier)
or self.module is None
or self.runtime is None
):
return identifier

set_type_annotation(identifier, get_referenced_type(self.module, identifier))
set_type_annotation(
identifier, get_referenced_type(self.runtime, self.module, identifier)
)

return identifier

Expand Down
11 changes: 7 additions & 4 deletions language_server/server/utils/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,22 @@ def get_name_of_type(annotation):
return annotation.__name__ if hasattr(annotation, "__name__") else annotation


def format_function_hints(name: str, signature: FunctionInfo):
hint = f"def {name}("
def format_function_hints(name: str, signature: FunctionInfo, keyword: str = "def"):
hint = f"{keyword} {name}("

return_type = signature.return_annotation

parameters = []

for name, parameter in signature.parameters:
annotation = get_name_of_type(parameter.annotation or parameter.default)
annotation = get_name_of_type(parameter.annotation)

if annotation is None and parameter.default is not inspect.Parameter.empty:
annotation = get_name_of_type(type(parameter.default))

annotation_string = ": " + str(annotation) if annotation else ""
default_string = (
" = " + str(parameter.default)
" = " + parameter.default.__repr__()
if parameter.default is not inspect.Parameter.empty
else ""
)
Expand Down
4 changes: 4 additions & 0 deletions tests/data/test/modules/use.bolt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@ TEST = @s

TEST.positioned((1,1,1)).tag("foo").tag("bar", False)

print()
Exception
ctx
Selector

0 comments on commit a6a785e

Please sign in to comment.