Skip to content

Commit

Permalink
adding a completer that sort of works :)
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed Nov 9, 2023
1 parent 19ed397 commit f55718b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ dash-bootstrap-templates >= 1.1.1
aws-cdk-lib >= 2.76.0
termcolor
tabulate
rich>=13.3.5
prompt_toolkit
22 changes: 18 additions & 4 deletions src/sageworks/cli/repl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from prompt_toolkit import PromptSession
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.styles import Style
from prompt_toolkit.completion import Completer, Completion
from pprint import pprint

# SageWorks Imports
Expand All @@ -18,6 +19,18 @@
history = InMemoryHistory()


class SageWorksCompleter(Completer):
def __init__(self, globals):
self.globals = globals

def get_completions(self, document, complete_event):
word_before_cursor = document.get_word_before_cursor()

for name in self.globals:
if name.startswith(word_before_cursor):
yield Completion(name, start_position=-len(word_before_cursor))


# Command handler
class CommandHandler:
def __init__(self):
Expand Down Expand Up @@ -81,11 +94,11 @@ def handle_command(self, raw_text):
# For example, assignment (`foo = 5`) is a statement.
try:
exec(raw_text, self.session_globals, self.session_globals)
if raw_text.startswith('print'):
if raw_text.startswith("print"):
pass # Avoid printing None if it was a print statement
elif '=' in raw_text:
elif "=" in raw_text:
# If there was an assignment, print the assigned variable
left_hand_side = raw_text.split('=')[0].strip()
left_hand_side = raw_text.split("=")[0].strip()
# Evaluate the left hand side to print its new value
print(eval(left_hand_side, self.session_globals, self.session_globals))
except Exception as e:
Expand All @@ -98,8 +111,9 @@ def handle_command(self, raw_text):

# REPL loop
def repl():
session = PromptSession(history=history)
handler = CommandHandler()
completer = SageWorksCompleter(handler.session_globals) # Use the updated globals for the completer
session = PromptSession(completer=completer, history=history)

while True:
try:
Expand Down

0 comments on commit f55718b

Please sign in to comment.