Skip to content

Commit

Permalink
test: add end-to-end tests for the different diagnostic approaches
Browse files Browse the repository at this point in the history
  • Loading branch information
alcarney committed Apr 3, 2024
1 parent f77ca27 commit 3c1064f
Show file tree
Hide file tree
Showing 5 changed files with 664 additions and 68 deletions.
97 changes: 97 additions & 0 deletions examples/servers/publish_diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
import logging
import re

from lsprotocol import types

from pygls.server import LanguageServer
from pygls.workspace import TextDocument

ADDITION = re.compile(r"^\s*(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)?$")


class PublishDiagnosticServer(LanguageServer):
"""Language server demonstrating "push-model" diagnostics."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.diagnostics = {}

def parse(self, document: TextDocument):
diagnostics = []

for idx, line in enumerate(document.lines):
match = ADDITION.match(line)
if match is not None:
left = int(match.group(1))
right = int(match.group(2))

expected_answer = left + right
actual_answer = match.group(3)

if actual_answer is not None and expected_answer == int(actual_answer):
continue

if actual_answer is None:
message = "Missing answer"
severity = types.DiagnosticSeverity.Warning
else:
message = f"Incorrect answer: {actual_answer}"
severity = types.DiagnosticSeverity.Error

diagnostics.append(
types.Diagnostic(
message=message,
severity=severity,
range=types.Range(
start=types.Position(line=idx, character=0),
end=types.Position(line=idx, character=len(line) - 1),
),
)
)

self.diagnostics[document.uri] = (document.version, diagnostics)
# logging.info("%s", self.diagnostics)


server = PublishDiagnosticServer("diagnostic-server", "v1")


@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
def did_open(ls: PublishDiagnosticServer, params: types.DidOpenTextDocumentParams):
"""Parse each document when it is opened"""
doc = ls.workspace.get_text_document(params.text_document.uri)
ls.parse(doc)

for uri, (version, diagnostics) in ls.diagnostics.items():
ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)


@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
def did_change(ls: PublishDiagnosticServer, params: types.DidOpenTextDocumentParams):
"""Parse each document when it is changed"""
doc = ls.workspace.get_text_document(params.text_document.uri)
ls.parse(doc)

for uri, (version, diagnostics) in ls.diagnostics.items():
ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(message)s")
server.start_io()
150 changes: 150 additions & 0 deletions examples/servers/pull_diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
import logging
import re

from lsprotocol import types

from pygls.server import LanguageServer
from pygls.workspace import TextDocument

ADDITION = re.compile(r"^\s*(\d+)\s*\+\s*(\d+)\s*=\s*(\d+)?$")


class PublishDiagnosticServer(LanguageServer):
"""Language server demonstrating "push-model" diagnostics."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.diagnostics = {}

def parse(self, document: TextDocument):
_, previous = self.diagnostics.get(document.uri, (0, []))
diagnostics = []

for idx, line in enumerate(document.lines):
match = ADDITION.match(line)
if match is not None:
left = int(match.group(1))
right = int(match.group(2))

expected_answer = left + right
actual_answer = match.group(3)

if actual_answer is not None and expected_answer == int(actual_answer):
continue

if actual_answer is None:
message = "Missing answer"
severity = types.DiagnosticSeverity.Warning
else:
message = f"Incorrect answer: {actual_answer}"
severity = types.DiagnosticSeverity.Error

diagnostics.append(
types.Diagnostic(
message=message,
severity=severity,
range=types.Range(
start=types.Position(line=idx, character=0),
end=types.Position(line=idx, character=len(line) - 1),
),
)
)

# Only update if the list has changed
if previous != diagnostics:
self.diagnostics[document.uri] = (document.version, diagnostics)

# logging.info("%s", self.diagnostics)


server = PublishDiagnosticServer("diagnostic-server", "v1")


@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
def did_open(ls: PublishDiagnosticServer, params: types.DidOpenTextDocumentParams):
"""Parse each document when it is opened"""
doc = ls.workspace.get_text_document(params.text_document.uri)
ls.parse(doc)


@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
def did_change(ls: PublishDiagnosticServer, params: types.DidOpenTextDocumentParams):
"""Parse each document when it is changed"""
doc = ls.workspace.get_text_document(params.text_document.uri)
ls.parse(doc)


@server.feature(
types.TEXT_DOCUMENT_DIAGNOSTIC,
types.DiagnosticOptions(
identifier="pull-diagnostics",
inter_file_dependencies=False,
workspace_diagnostics=True,
),
)
def document_diagnostic(
ls: PublishDiagnosticServer, params: types.DocumentDiagnosticParams
):
"""Return diagnostics for the requested document"""
# logging.info("%s", params)

if (uri := params.text_document.uri) not in ls.diagnostics:
return

version, diagnostics = ls.diagnostics[uri]
result_id = f"{uri}@{version}"

if result_id == params.previous_result_id:
return types.UnchangedDocumentDiagnosticReport(result_id)

return types.FullDocumentDiagnosticReport(items=diagnostics, result_id=result_id)


@server.feature(types.WORKSPACE_DIAGNOSTIC)
def workspace_diagnostic(
ls: PublishDiagnosticServer, params: types.WorkspaceDiagnosticParams
):
"""Return diagnostics for the workspace."""
# logging.info("%s", params)
items = []
previous_ids = {result.value for result in params.previous_result_ids}

for uri, (version, diagnostics) in ls.diagnostics.items():
result_id = f"{uri}@{version}"
if result_id in previous_ids:
items.append(
types.WorkspaceUnchangedDocumentDiagnosticReport(
uri=uri, result_id=result_id, version=version
)
)
else:
items.append(
types.WorkspaceFullDocumentDiagnosticReport(
uri=uri,
version=version,
items=diagnostics,
)
)

return types.WorkspaceDiagnosticReport(items=items)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(message)s")
server.start_io()
146 changes: 146 additions & 0 deletions tests/e2e/test_publish_diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
from __future__ import annotations

import asyncio
import typing

import pytest_asyncio
from lsprotocol import types

if typing.TYPE_CHECKING:
from typing import Tuple

from pygls.lsp.client import BaseLanguageClient


@pytest_asyncio.fixture()
async def push_diagnostics(get_client_for):
async for client, response in get_client_for("publish_diagnostics.py"):
# Setup a diagnostics handler
client.diagnostics = {}

@client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
def publish_diagnostics(params: types.PublishDiagnosticsParams):
client.diagnostics[params.uri] = params.diagnostics

yield client, response


async def test_publish_diagnostics(
push_diagnostics: Tuple[BaseLanguageClient, types.InitializeResult],
path_for,
uri_for,
):
"""Ensure that the publish diagnostics server is working as expected."""
client, initialize_result = push_diagnostics

test_uri = uri_for("sums.txt")
test_path = path_for("sums.txt")

client.text_document_did_open(
types.DidOpenTextDocumentParams(
types.TextDocumentItem(
uri=test_uri,
language_id="plaintext",
version=0,
text=test_path.read_text(),
)
)
)

await asyncio.sleep(0.5)
assert test_uri in client.diagnostics

expected = [
types.Diagnostic(
message="Missing answer",
severity=types.DiagnosticSeverity.Warning,
range=types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=7),
),
),
types.Diagnostic(
message="Missing answer",
severity=types.DiagnosticSeverity.Warning,
range=types.Range(
start=types.Position(line=3, character=0),
end=types.Position(line=3, character=7),
),
),
types.Diagnostic(
message="Missing answer",
severity=types.DiagnosticSeverity.Warning,
range=types.Range(
start=types.Position(line=6, character=0),
end=types.Position(line=6, character=7),
),
),
]

assert expected == client.diagnostics[test_uri]

# Write an incorrect answer...
client.text_document_did_change(
types.DidChangeTextDocumentParams(
text_document=types.VersionedTextDocumentIdentifier(
uri=test_uri, version=1
),
content_changes=[
types.TextDocumentContentChangeEvent_Type1(
text=" 12",
range=types.Range(
start=types.Position(line=0, character=7),
end=types.Position(line=0, character=7),
),
)
],
)
)

await asyncio.sleep(0.5)
assert test_uri in client.diagnostics

expected = [
types.Diagnostic(
message="Incorrect answer: 12",
severity=types.DiagnosticSeverity.Error,
range=types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=10),
),
),
types.Diagnostic(
message="Missing answer",
severity=types.DiagnosticSeverity.Warning,
range=types.Range(
start=types.Position(line=3, character=0),
end=types.Position(line=3, character=7),
),
),
types.Diagnostic(
message="Missing answer",
severity=types.DiagnosticSeverity.Warning,
range=types.Range(
start=types.Position(line=6, character=0),
end=types.Position(line=6, character=7),
),
),
]

assert expected == client.diagnostics[test_uri]
Loading

0 comments on commit 3c1064f

Please sign in to comment.