diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..44520b67c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2023, Small Magellanic Cloud AI Ltd.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/refact_lsp/__init__.py b/refact_lsp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/refact_lsp/__main__.py b/refact_lsp/__main__.py
new file mode 100644
index 000000000..40fa77b48
--- /dev/null
+++ b/refact_lsp/__main__.py
@@ -0,0 +1,114 @@
+from refact_lsp import refact_lsp_server
+from refact_lsp import refact_client
+import asyncio
+import aiohttp
+import os
+import termcolor
+from typing import Dict, Optional
+
+
+async def regular_code_completion(
+    files: Dict[str, str],
+    cursor_file: str,
+    cursor: int,
+    max_tokens: int,
+    multiline: bool,
+    temperature: Optional[float] = None,
+):
+    if "SMALLCLOUD_API_KEY" not in os.environ:
+        raise ValueError("Please either set SMALLCLOUD_API_KEY environment variable or create requests session manually.")
+    sess = aiohttp.ClientSession(headers={
+        "Authorization": "Bearer %s" % os.environ["SMALLCLOUD_API_KEY"],
+    })
+    for fn, txt in files.items():
+        if not txt.endswith("\n"):
+            # server side will add it anyway, add here for comparison to work correctly later in this function
+            files[fn] += "\n"
+    try:
+        ans = await refact_client.nlp_model_call(
+            "contrast",
+            "CONTRASTcode",
+            req_session=sess,
+            sources=files,
+            intent="Infill",
+            function="infill",
+            cursor_file=cursor_file,
+            cursor0=cursor,
+            cursor1=cursor,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            stop=(["\n\n"] if multiline else ["\n"]),
+            verbose=2,
+        )
+    finally:
+        await sess.close()
+    # print(ans)
+    # print(ans["choices"][0]["files"][cursor_file])
+
+    # Find an \n after any different char, when looking from the end. The goal is to find a line that's different, but a complete line.
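+    # Scanning backwards with negative indexes: stop_at tracks the position just past the
+    # most recently seen newline (a non-positive offset from the end), so the slice returned
+    # below always ends on a line boundary; every file ends with "\n", so stop_at is always set.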
+    stop_at = None
+    i = -1
+    whole_file = files[cursor_file]
+    modif_file = ans["choices"][0]["files"][cursor_file]
+    length = min(len(whole_file), len(modif_file))
+    any_different = False
+    while i > -length:
+        if whole_file[i] == "\n":
+            stop_at = i + 1
+        if whole_file[i] != modif_file[i]:
+            any_different = True
+            break
+        i -= 1
+    fail = cursor >= len(modif_file) + stop_at
+    if fail or not any_different:
+        return None
+    # import pudb; pudb.set_trace()
+    return modif_file[cursor : len(modif_file) + stop_at]
+
+
+async def test_multiline(no_newline_in_the_end: bool):
+    hello_world_py = "# This print hello world and does not do anything else\ndef hello_world():"
+    if not no_newline_in_the_end:
+        hello_world_py += "\n"
+    files = {
+        "hello_world.py": hello_world_py,
+    }
+    completion = await regular_code_completion(
+        files,
+        "hello_world.py",
+        len(hello_world_py),
+        50,
+        multiline=True,
+    )
+    print(termcolor.colored(hello_world_py, "yellow") + termcolor.colored(str(completion), "green"))
+    print("checking if correct \"%s\"" % str(completion).replace("\n", "\\n"))
+    assert completion.strip().lower().replace("!", "") in [
+        "print('hello world')",
+        "print(\"hello world\")",
+    ]
+
+
+async def test_everything():
+    await test_multiline(False)
+    await test_multiline(True)
+
+
+def main():
+    # print("listening on 127.0.0.1:1337")
+    # refact_lsp_server.server.start_tcp("127.0.0.1", 1337)
+
+    loop = asyncio.new_event_loop()
+    try:
+        loop.run_until_complete(test_everything())
+    finally:
+        loop.close()
+
+
+# TODO:
+# * allow empty model
+# * allow no temperature
+# * /contrast should return mime type json
+
diff --git a/refact_lsp/refact_client.py b/refact_lsp/refact_client.py
new file mode 100644
index 000000000..d1643b268
--- /dev/null
+++ b/refact_lsp/refact_client.py
@@ -0,0 +1,73 @@
+import aiohttp
+import time
+import os
+import json
+from typing import Any, Dict, Optional
+
+
+base_url = "https://inference.smallcloud.ai/v1/"
+
+
+class APIConnectionError(Exception):
+    pass
+
+
+async def nlp_model_call(
+    endpoint: str,
+    model: str,
+    *,
+    req_session: Optional[aiohttp.ClientSession]=None,
+    max_tokens: int,
+    temperature: Optional[float]=None,
+    top_p: Optional[float]=None,
+    top_n: Optional[int]=None,
+    verbose: int=0,
+    **pass_args
+) -> Dict[str, Any]:
+    """
+    A simplified version without streaming; returns the parsed JSON response.
+    """
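+    # Rough usage sketch (mirrors the call in __main__.regular_code_completion); the values
+    # below are only illustrative, and endpoint-specific fields such as sources, intent,
+    # function, cursor_file, cursor0/cursor1 and stop are forwarded unchanged via **pass_args:
+    #   j = await nlp_model_call("contrast", "CONTRASTcode", req_session=sess,
+    #                            sources={"hello.py": code}, intent="Infill", function="infill",
+    #                            cursor_file="hello.py", cursor0=cur, cursor1=cur, max_tokens=50)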
+    req_session = req_session or aiohttp.ClientSession()
+    assert isinstance(req_session, aiohttp.ClientSession)
+    url = base_url + endpoint
+    data = {
+        "model": model,
+        "max_tokens": max_tokens,
+        "stream": False,
+        **pass_args,
+    }
+    if top_p is not None:
+        data["top_p"] = top_p
+    if top_n is not None:
+        data["top_n"] = top_n
+    if temperature is not None:
+        data["temperature"] = temperature
+    if verbose > 1:
+        print("POST %s" % (data,))
+    resp = None
+    txt = ""
+    try:
+        t0 = time.time()
+        resp = await req_session.post(url, json=data)
+        t1 = time.time()
+        if verbose > 0:
+            print("%0.1fms %s" % (1000*(t1 - t0), url))
+        txt = await resp.text()
+    except Exception as e:
+        raise APIConnectionError("completions() failed: %s" % str(e))
+
+    if resp.status != 200:
+        raise APIConnectionError("status=%i, server returned:\n%s" % (resp.status, txt))
+
+    try:
+        j = json.loads(txt)
+    except Exception as e:
+        raise APIConnectionError("completions() json parse failed: %s\n%s" % (str(e), txt))
+
+    return j
diff --git a/refact_lsp/refact_lsp_server.py b/refact_lsp/refact_lsp_server.py
new file mode 100644
index 000000000..d45458342
--- /dev/null
+++ b/refact_lsp/refact_lsp_server.py
@@ -0,0 +1,76 @@
+import asyncio
+from pygls.server import LanguageServer
+from refact_lsp import refact_client
+
+
+from lsprotocol.types import (
+    TEXT_DOCUMENT_COMPLETION,
+    CompletionItem,
+    CompletionList,
+    CompletionParams,
+    CompletionItemKind,
+    TextEdit,
+    Range,
+    Position,
+)
+
+
+server = LanguageServer("refact-lsp", "v0.1")
+
+
+# def run_diff_call(func, src_py, src_txt, cursor, intent):
+#     j = inf.nlp_model_call(
+#         "contrast",
+#         MODEL,
+#         sources={src_py: src_txt},
+#         intent=intent,
+#         function=func,
+#         cursor_file=src_py,
+#         cursor0=cursor,
+#         cursor1=cursor,
+#         temperature=TEMPERATURE,
+#         max_tokens=MAX_TOKENS,
+#         top_p=TOP_P,
+#         max_edits=1,
+#         verbose=1,
+#     )
+#     if "status" not in j or j["status"] != "completed":
+#         log(str(j))
+#         quit(1)
+#     return j["choices"][0]
+
+
+@server.feature(TEXT_DOCUMENT_COMPLETION)
+async def completions(params: CompletionParams):
+    items = []
+    document = server.workspace.get_document(params.text_document.uri)
+    current_line = document.lines[params.position.line].strip()
+    print("\"%s\"" % current_line)
+    if current_line.endswith("hello."):
+        items = [
+            CompletionItem(label="world"),
+            CompletionItem(label="friend"),
+        ]
+    if current_line.endswith("trigger_text"):
+        items = [
+            CompletionItem(
+                label="trigger_text.line1\nline2\nline3",
+                text_edit=TextEdit(
+                    range=Range(
+                        start=Position(line=params.position.line, character=params.position.character - len("trigger_text")),
+                        end=params.position
+                    ),
+                    new_text="trigger_text.line1\nline2\nline3"
+                )
+            )
+        ]
+    await asyncio.sleep(1)
+
+    return CompletionList(
+        is_incomplete=False,
+        items=items,
+    )
+
+
+if __name__ == '__main__':
+    server.start_tcp("127.0.0.1", 1337)
diff --git a/setup.py b/setup.py
new file mode 100644
index 000000000..368c5b727
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,29 @@
+from setuptools import setup
+
+setup(
+    name="refact-lsp",
+    packages=["refact_lsp"],
+    version="0.0.1",
+    url="https://github.com/smallcloudai/refact_lsp",
+    description="LSP server for Refact, suitable for Sublime Text and other editors",
+    long_description="Install, run refact_lsp, enter your custom server URL, or just an API Key",
+    license="BSD 3-Clause License",
+    install_requires=[
+        "aiohttp",
+        "pygls",
+        "termcolor",
+    ],
+    author="Small Magellanic Cloud AI Ltd.",
+    author_email="info@smallcloud.tech",
+    entry_points={
+        "console_scripts": ["refact_lsp = refact_lsp.__main__:main"],
+    },
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: BSD License",
+        "Environment :: Console",
+        "Operating System :: OS Independent",
+    ]
+)