diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..cd11e70a7 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -137,7 +137,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item diff --git a/nemoguardrails/cli/__init__.py b/nemoguardrails/cli/__init__.py index 72ad4fc07..d170c83f7 100644 --- a/nemoguardrails/cli/__init__.py +++ b/nemoguardrails/cli/__init__.py @@ -16,7 +16,8 @@ import logging import os -from typing import List, Optional +from enum import Enum +from typing import Any, List, Literal, Optional import typer import uvicorn @@ -27,13 +28,24 @@ from nemoguardrails.cli.chat import run_chat from nemoguardrails.cli.migration import migrate from nemoguardrails.cli.providers import _list_providers, select_provider_with_type -from nemoguardrails.eval import cli +from nemoguardrails.eval import cli as eval_cli from nemoguardrails.logging.verbose import set_verbose from nemoguardrails.utils import init_random_seed + +class ColangVersions(str, Enum): + one = "1.0" + two_alpha = "2.0-alpha" + + +_COLANG_VERSIONS = [version.value for version in ColangVersions] + + app = typer.Typer() -app.add_typer(cli.app, name="eval", short_help="Evaluation a guardrail configuration.") +app.add_typer( + eval_cli.app, name="eval", short_help="Evaluation a guardrail configuration." +) app.pretty_exceptions_enable = False logging.getLogger().setLevel(logging.WARNING) @@ -44,7 +56,8 @@ def chat( config: List[str] = typer.Option( default=["config"], exists=True, - help="Path to a directory containing configuration files to use. Can also point to a single configuration file.", + help="Path to a directory containing configuration files to use. " + "Can also point to a single configuration file.", ), verbose: bool = typer.Option( default=False, @@ -60,7 +73,8 @@ def chat( ), debug_level: List[str] = typer.Option( default=[], - help="Enable debug mode which prints rich information about the flows execution. Available levels: WARNING, INFO, DEBUG", + help="Enable debug mode which prints rich information about the flows execution. " + "Available levels: WARNING, INFO, DEBUG", ), streaming: bool = typer.Option( default=False, @@ -77,7 +91,7 @@ def chat( ): """Start an interactive chat session.""" if len(config) > 1: - typer.secho(f"Multiple configurations are not supported.", fg=typer.colors.RED) + typer.secho("Multiple configurations are not supported.", fg=typer.colors.RED) typer.echo("Please provide a single folder.") raise typer.Exit(1) @@ -143,23 +157,27 @@ def server( if config: # We make sure there is no trailing separator, as that might break things in # single config mode. - api.app.rails_config_path = os.path.expanduser(config[0].rstrip(os.path.sep)) + setattr( + api.app, + "rails_config_path", + os.path.expanduser(config[0].rstrip(os.path.sep)), + ) else: # If we don't have a config, we try to see if there is a local config folder local_path = os.getcwd() local_configs_path = os.path.join(local_path, "config") if os.path.exists(local_configs_path): - api.app.rails_config_path = local_configs_path + setattr(api.app, "rails_config_path", local_configs_path) if verbose: logging.getLogger().setLevel(logging.INFO) if disable_chat_ui: - api.app.disable_chat_ui = True + setattr(api.app, "disable_chat_ui", True) if auto_reload: - api.app.auto_reload = True + setattr(api.app, "auto_reload", True) if prefix: server_app = FastAPI() @@ -173,17 +191,14 @@ def server( uvicorn.run(server_app, port=port, log_level="info", host="0.0.0.0") -_AVAILABLE_OPTIONS = ["1.0", "2.0-alpha"] - - @app.command() def convert( path: str = typer.Argument( ..., help="The path to the file or directory to migrate." ), - from_version: str = typer.Option( - default="1.0", - help=f"The version of the colang files to migrate from. Available options: {_AVAILABLE_OPTIONS}.", + from_version: ColangVersions = typer.Option( + default=ColangVersions.one, + help=f"The version of the colang files to migrate from. Available options: {_COLANG_VERSIONS}.", ), verbose: bool = typer.Option( default=False, @@ -209,11 +224,14 @@ def convert( absolute_path = os.path.abspath(path) + # Typer CLI args have to use an enum, not literal. Convert to Literal here + from_version_literal: Literal["1.0", "2.0-alpha"] = from_version.value + migrate( path=absolute_path, include_main_flow=include_main_flow, use_active_decorator=use_active_decorator, - from_version=from_version, + from_version=from_version_literal, validate=validate, ) diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index 97521c2a8..98ba83e0d 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -15,8 +15,8 @@ import asyncio import json import os -from dataclasses import dataclass, field -from typing import Dict, List, Optional, cast +from dataclasses import asdict, dataclass, field +from typing import Dict, List, Optional, Tuple, Union, cast import aiohttp from prompt_toolkit import HTML, PromptSession @@ -30,7 +30,11 @@ from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x from nemoguardrails.logging import verbose from nemoguardrails.logging.verbose import console -from nemoguardrails.streaming import StreamingHandler +from nemoguardrails.rails.llm.options import ( + GenerationLog, + GenerationOptions, + GenerationResponse, +) from nemoguardrails.utils import get_or_create_event_loop, new_event_dict, new_uuid os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -61,6 +65,8 @@ async def _run_chat_v1_0( ) if not server_url: + if config_path is None: + raise RuntimeError("config_path cannot be None when server_url is None") rails_config = RailsConfig.from_path(config_path) rails_app = LLMRails(rails_config, verbose=verbose) if streaming and not rails_config.streaming_supported: @@ -82,7 +88,12 @@ async def _run_chat_v1_0( if not server_url: # If we have streaming from a locally loaded config, we initialize the handler. - if streaming and not server_url and rails_app.main_llm_supports_streaming: + if ( + streaming + and not server_url + and rails_app + and rails_app.main_llm_supports_streaming + ): bot_message_list = [] async for chunk in rails_app.stream_async(messages=history): if '{"event": "ABORT"' in chunk: @@ -101,11 +112,40 @@ async def _run_chat_v1_0( bot_message = {"role": "assistant", "content": bot_message_text} else: - bot_message = await rails_app.generate_async(messages=history) + if rails_app is None: + raise RuntimeError("Rails App is None") + response: Union[ + str, Dict, GenerationResponse, Tuple[Dict, Dict] + ] = await rails_app.generate_async(messages=history) + + # Handle different return types from generate_async + if isinstance(response, tuple) and len(response) == 2: + bot_message = ( + response[0] + if response + else {"role": "assistant", "content": ""} + ) + elif isinstance(response, GenerationResponse): + # GenerationResponse case + response_attr = getattr(response, "response", None) + if isinstance(response_attr, list) and len(response_attr) > 0: + bot_message = response_attr[0] + else: + bot_message = { + "role": "assistant", + "content": str(response_attr), + } + elif isinstance(response, dict): + # Direct dict case + bot_message = response + else: + # String or other fallback case + bot_message = {"role": "assistant", "content": str(response)} if not streaming or not rails_app.main_llm_supports_streaming: # We print bot messages in green. - console.print("[green]" + f"{bot_message['content']}" + "[/]") + content = bot_message.get("content", str(bot_message)) + console.print("[green]" + f"{content}" + "[/]") else: data = { "config_id": config_id, @@ -116,19 +156,19 @@ async def _run_chat_v1_0( async with session.post( f"{server_url}/v1/chat/completions", json=data, - ) as response: + ) as http_response: # If the response is streaming, we show each chunk as it comes - if response.headers.get("Transfer-Encoding") == "chunked": + if http_response.headers.get("Transfer-Encoding") == "chunked": bot_message_text = "" - async for chunk in response.content.iter_any(): - chunk = chunk.decode("utf-8") + async for chunk_bytes in http_response.content.iter_any(): + chunk = chunk_bytes.decode("utf-8") console.print("[green]" + f"{chunk}" + "[/]", end="") bot_message_text += chunk console.print("") bot_message = {"role": "assistant", "content": bot_message_text} else: - result = await response.json() + result = await http_response.json() bot_message = result["messages"][0] # We print bot messages in green. @@ -297,7 +337,8 @@ def _process_output(): else: console.print( "[black on magenta]" - + f"scene information (start): (title={event['title']}, action_uid={event['action_uid']}, content={event['content']})" + + f"scene information (start): (title={event['title']}, " + + f"action_uid={event['action_uid']}, content={event['content']})" + "[/]" ) @@ -333,7 +374,8 @@ def _process_output(): else: console.print( "[black on magenta]" - + f"scene form (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, inputs={event['inputs']})" + + f"scene form (start): (prompt={event['prompt']}, " + + f"action_uid={event['action_uid']}, inputs={event['inputs']})" + "[/]" ) chat_state.input_events.append( @@ -370,7 +412,8 @@ def _process_output(): else: console.print( "[black on magenta]" - + f"scene choice (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, options={event['options']})" + + f"scene choice (start): (prompt={event['prompt']}, " + + f"action_uid={event['action_uid']}, options={event['options']})" + "[/]" ) chat_state.input_events.append( @@ -452,12 +495,16 @@ async def _check_local_async_actions(): # We need to copy input events to prevent race condition input_events_copy = chat_state.input_events.copy() chat_state.input_events = [] - ( - chat_state.output_events, - chat_state.output_state, - ) = await rails_app.process_events_async( - input_events_copy, chat_state.state + + output_events, output_state = await rails_app.process_events_async( + input_events_copy, + asdict(chat_state.state) if chat_state.state else None, ) + chat_state.output_events = output_events + + # process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object + if output_state: + chat_state.output_state = cast(State, State(**output_state)) # Process output_events and potentially generate new input_events _process_output() @@ -470,7 +517,8 @@ async def _check_local_async_actions(): # If there are no pending actions, we stop check_task.cancel() check_task = None - debugger.set_output_state(chat_state.output_state) + if chat_state.output_state is not None: + debugger.set_output_state(chat_state.output_state) chat_state.status.stop() enable_input.set() return @@ -485,13 +533,16 @@ async def _process_input_events(): # We need to copy input events to prevent race condition input_events_copy = chat_state.input_events.copy() chat_state.input_events = [] - ( - chat_state.output_events, - chat_state.output_state, - ) = await rails_app.process_events_async( - input_events_copy, chat_state.state + output_events, output_state = await rails_app.process_events_async( + input_events_copy, + asdict(chat_state.state) if chat_state.state else None, ) - debugger.set_output_state(chat_state.output_state) + chat_state.output_events = output_events + if output_state: + # process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object + output_state_typed: State = cast(State, State(**output_state)) + chat_state.output_state = output_state_typed + debugger.set_output_state(output_state_typed) _process_output() # If we don't have a check task, we start it @@ -653,6 +704,8 @@ def run_chat( server_url (Optional[str]): The URL of the chat server. Defaults to None. config_id (Optional[str]): The configuration ID. Defaults to None. """ + if config_path is None: + raise RuntimeError("config_path cannot be None") rails_config = RailsConfig.from_path(config_path) if verbose and verbose_llm_calls: diff --git a/nemoguardrails/cli/debugger.py b/nemoguardrails/cli/debugger.py index b36428526..420260d9b 100644 --- a/nemoguardrails/cli/debugger.py +++ b/nemoguardrails/cli/debugger.py @@ -14,13 +14,13 @@ # limitations under the License. import shlex -from typing import Optional +from typing import TYPE_CHECKING, Dict, Optional, cast import typer from rich.table import Table from rich.tree import Tree -from nemoguardrails.colang.v2_x.lang.colang_ast import SpecOp, SpecType +from nemoguardrails.colang.v2_x.lang.colang_ast import Spec, SpecOp, SpecType from nemoguardrails.colang.v2_x.runtime.flows import ( FlowConfig, FlowState, @@ -31,8 +31,12 @@ from nemoguardrails.colang.v2_x.runtime.statemachine import is_active_flow from nemoguardrails.utils import console +if TYPE_CHECKING: + from nemoguardrails.cli.chat import ChatState + runtime: Optional[RuntimeV2_x] = None state: Optional[State] = None +chat_state: Optional["ChatState"] = None app = typer.Typer(name="!!!", no_args_is_help=True, add_completion=False) @@ -58,21 +62,24 @@ def set_output_state(_state: State): @app.command() def restart(): """Restart the current Colang script.""" - chat_state.state = None - chat_state.input_events = [] - chat_state.first_time = True + if chat_state is not None: + chat_state.state = None + chat_state.input_events = [] + chat_state.first_time = True @app.command() def pause(): """Pause current interaction.""" - chat_state.paused = True + if chat_state is not None: + chat_state.paused = True @app.command() def resume(): """Pause current interaction.""" - chat_state.paused = False + if chat_state is not None: + chat_state.paused = False @app.command() @@ -82,19 +89,24 @@ def flow( """Shows all details about a flow or flow instance.""" assert state - if flow_name in state.flow_configs: - flow_config = state.flow_configs[flow_name] - console.print(flow_config) - else: - matches = [ - (uid, item) for uid, item in state.flow_states.items() if flow_name in uid - ] - if matches: - flow_instance = matches[0][1] - console.print(flow_instance.__dict__) + if state is not None: + if flow_name in state.flow_configs: + flow_config = state.flow_configs[flow_name] + console.print(flow_config) else: - console.print(f"Flow '{flow_name}' does not exist.") - return + matches = [ + (uid, item) + for uid, item in state.flow_states.items() + if flow_name in uid + ] + if matches: + flow_instance = matches[0][1] + console.print(flow_instance.__dict__) + else: + console.print(f"Flow '{flow_name}' does not exist.") + return + else: + console.print("No state available.") @app.command() @@ -108,9 +120,8 @@ def flows( ), ): """Shows a table with all (active) flows ordered in terms of there interaction loop priority and name.""" - assert state - - """List the flows from the current state.""" + if state is None: + raise RuntimeError("No state available") table = Table(header_style="bold magenta") @@ -170,7 +181,8 @@ def get_loop_info(flow_config: FlowConfig) -> str: if order_by_name: rows.sort(key=lambda x: x[0]) else: - rows.sort(key=lambda x: (-state.flow_configs[x[0]].loop_priority, x[0])) + flow_configs: Dict[str, FlowConfig] = state.flow_configs + rows.sort(key=lambda x: (-flow_configs[x[0]].loop_priority, x[0])) for i, row in enumerate(rows): table.add_row(f"{i+1}", *row) @@ -186,6 +198,9 @@ def tree( ) ): """Lists the tree of all active flows.""" + if state is None or "main" not in state.flow_id_states: + raise RuntimeError("No main flow available.") + main_flow = state.flow_id_states["main"][0] root = Tree("main") @@ -231,13 +246,25 @@ def tree( # We also want to figure out if the flow is actually waiting on this child waiting_on = False - for head_id, head in flow_state.active_heads.items(): + for _, head in flow_state.active_heads.items(): head_element = elements[head.position] if isinstance(head_element, SpecOp): - if head_element.op == "match": - if head_element.spec.spec_type == SpecType.REFERENCE: - var_name = head_element.spec.var_name + head_element_spec_op = cast(SpecOp, head_element) + if head_element_spec_op.op == "match": + # Convert Spec to Spec object if it's a Dict + spec: Spec = ( + head_element_spec_op.spec + if isinstance(head_element_spec_op.spec, Spec) + else Spec(**cast(Dict, head_element_spec_op.spec)) + ) + + if ( + spec.spec_type + and spec.spec_type == SpecType.REFERENCE + and spec.var_name + ): + var_name = spec.var_name var = flow_state.context.get(var_name) if var == child_flow_state: @@ -271,6 +298,6 @@ def run_command(command: str): command = "--help" app(shlex.split(command)) - except SystemExit as e: + except SystemExit: # Prevent stopping the app pass diff --git a/nemoguardrails/cli/migration.py b/nemoguardrails/cli/migration.py index 25db24661..b58909d7f 100644 --- a/nemoguardrails/cli/migration.py +++ b/nemoguardrails/cli/migration.py @@ -177,9 +177,13 @@ def convert_colang_2alpha_syntax(lines: List[str]) -> List[str]: r"#\s*meta:\s*loop_id=(.*)", r'@loop("\1")', line.lstrip() ) else: + + def replace_meta(m): + return "@meta(" + m.group(1).replace(" ", "_") + "=True)" + meta_decorator = re.sub( r"#\s*meta:\s*(.*)", - lambda m: "@meta(" + m.group(1).replace(" ", "_") + "=True)", + replace_meta, line.lstrip(), ) meta_decorators.append(meta_decorator) @@ -216,7 +220,8 @@ def convert_colang_1_syntax(lines: List[str]) -> List[str]: # Check if the line matches the pattern $variable = ... # use of ellipsis in Colang 1.0 - # Based on https://github.com/NVIDIA/NeMo-Guardrails/blob/ff17a88efe70ed61580a36aaae5739f5aac6dccc/nemoguardrails/colang/v1_0/lang/coyml_parser.py#L610C1-L617C84 + # Based on https://github.com/NVIDIA/NeMo-Guardrails/blob/ff17a88efe70ed61580a36aaae5739f5aac6dccc/ + # nemoguardrails/colang/v1_0/lang/coyml_parser.py#L610C1-L617C84 if i > 0 and re.match(r"\s*\$\s*.*\s*=\s*\.\.\.", line): # Extract the variable name @@ -224,9 +229,14 @@ def convert_colang_1_syntax(lines: List[str]) -> List[str]: comment_match = re.search(r"# (.*)", lines[i - 1]) if variable_match and comment_match: variable = variable_match.group(1) - comment = comment_match.group(1) + comment = comment_match.group(1) or "" # Extract the leading whitespace - leading_whitespace = re.match(r"(\s*)", line).group(1) + leading_whitespace_match = re.match(r"(\s*)", line) + leading_whitespace = ( + leading_whitespace_match.group(1) + if leading_whitespace_match + else "" + ) # Replace the line, preserving the leading whitespace line = f'{leading_whitespace}${variable} = ... "{comment}"' @@ -256,7 +266,7 @@ def convert_colang_1_syntax(lines: List[str]) -> List[str]: if _is_anonymous_flow(line): # warnings.warn("Using anonymous flow is deprecated in Colang 2.0.") - line = _revise_anonymous_flow(line, next_line) + "\n" + line = _revise_anonymous_flow(line, next_line or "") + "\n" # We convert "define bot" to "flow bot" and set the flag if "define bot" in line: @@ -572,7 +582,7 @@ def _add_active_decorator(new_lines: List) -> List: def _get_raw_config(config_path: str): """read the yaml file and get rails key""" - + raw_config = None if config_path.endswith(".yaml") or config_path.endswith(".yml"): with open(config_path) as f: raw_config = yaml.safe_load(f.read()) @@ -1067,9 +1077,9 @@ def _process_sample_conversation_in_config(file_path: str): stripped_sample_lines = [line[sample_conv_indent:] for line in sample_lines] new_sample_lines = convert_sample_conversation_syntax(stripped_sample_lines) - # revert the indentation + # revert the indentation indented_new_sample_lines = [ - " " * sample_conv_indent + line for line in new_sample_lines + " " * (sample_conv_indent or 0) + line for line in new_sample_lines ] lines[sample_conv_line_idx + 1 : sample_conv_end_idx] = indented_new_sample_lines # Write back the modified lines