diff --git a/requirements.txt b/requirements.txt index 1449371c5..6f6bac0be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ dash-bootstrap-templates >= 1.1.1 aws-cdk-lib >= 2.76.0 termcolor tabulate -rich>=13.3.5 +prompt_toolkit diff --git a/src/sageworks/cli/repl.py b/src/sageworks/cli/repl.py index 2d2256da6..5b7e21841 100644 --- a/src/sageworks/cli/repl.py +++ b/src/sageworks/cli/repl.py @@ -1,6 +1,7 @@ from prompt_toolkit import PromptSession from prompt_toolkit.history import InMemoryHistory from prompt_toolkit.styles import Style +from prompt_toolkit.completion import Completer, Completion from pprint import pprint # SageWorks Imports @@ -18,6 +19,18 @@ history = InMemoryHistory() +class SageWorksCompleter(Completer): + def __init__(self, globals): + self.globals = globals + + def get_completions(self, document, complete_event): + word_before_cursor = document.get_word_before_cursor() + + for name in self.globals: + if name.startswith(word_before_cursor): + yield Completion(name, start_position=-len(word_before_cursor)) + + # Command handler class CommandHandler: def __init__(self): @@ -81,11 +94,11 @@ def handle_command(self, raw_text): # For example, assignment (`foo = 5`) is a statement. try: exec(raw_text, self.session_globals, self.session_globals) - if raw_text.startswith('print'): + if raw_text.startswith("print"): pass # Avoid printing None if it was a print statement - elif '=' in raw_text: + elif "=" in raw_text: # If there was an assignment, print the assigned variable - left_hand_side = raw_text.split('=')[0].strip() + left_hand_side = raw_text.split("=")[0].strip() # Evaluate the left hand side to print its new value print(eval(left_hand_side, self.session_globals, self.session_globals)) except Exception as e: @@ -98,8 +111,9 @@ def handle_command(self, raw_text): # REPL loop def repl(): - session = PromptSession(history=history) handler = CommandHandler() + completer = SageWorksCompleter(handler.session_globals) # Use the updated globals for the completer + session = PromptSession(completer=completer, history=history) while True: try: