diff --git a/.gitignore b/.gitignore index 6f16c3d3..f35e8bca 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ demos/employee_details_copilot/ollama/models/ demos/employee_details_copilot_arch/ollama/models/ demos/network_copilot/ollama/models/ arch_log/ +arch/tools/*.egg-info +arch/tools/config diff --git a/arch/Dockerfile b/arch/Dockerfile index 0af5c62d..60526054 100644 --- a/arch/Dockerfile +++ b/arch/Dockerfile @@ -19,7 +19,7 @@ COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy WORKDIR /config COPY arch/requirements.txt . RUN pip install -r requirements.txt -COPY arch/config_generator.py . +COPY arch/tools/config_generator.py . COPY arch/envoy.template.yaml . COPY arch/arch_config_schema.yaml . diff --git a/arch/config_generator.py b/arch/config_generator.py deleted file mode 100644 index c3282f31..00000000 --- a/arch/config_generator.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -from jinja2 import Environment, FileSystemLoader -import yaml -from jsonschema import validate - -ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml') -ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml') -ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml') -ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml') -OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False) -MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False) - -def add_secret_key_to_llm_providers(config_yaml) : - llm_providers = [] - for llm_provider in config_yaml.get("llm_providers", []): - if llm_provider['access_key'] == "$MISTRAL_ACCESS_KEY": - llm_provider['access_key'] = MISTRAL_API_KEY - elif llm_provider['access_key'] == "$OPENAI_ACCESS_KEY": - llm_provider['access_key'] = OPENAI_API_KEY - else: - llm_provider.pop('access_key') - llm_providers.append(llm_provider) - config_yaml["llm_providers"] = llm_providers - return config_yaml - -env = Environment(loader=FileSystemLoader('./')) -template = env.get_template('envoy.template.yaml') - -with open(ARCH_CONFIG_FILE, 'r') as file: - arch_config_string = file.read() - -with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file: - arch_config_schema = file.read() - -config_yaml = yaml.safe_load(arch_config_string) -config_schema_yaml = yaml.safe_load(arch_config_schema) - -try: - validate(config_yaml, config_schema_yaml) -except Exception as e: - print(f"Error validating arch_config file: {ARCH_CONFIG_FILE}, error: {e.message}") - exit(1) - -inferred_clusters = {} - -for prompt_target in config_yaml["prompt_targets"]: - name = prompt_target.get("endpoint", {}).get("name", "") - if name not in inferred_clusters: - inferred_clusters[name] = { - "name": name, - "port": 80, # default port - } - -print(inferred_clusters) - -endpoints = config_yaml.get("endpoints", {}) - -# override the inferred clusters with the ones defined in the config -for name, endpoint_details in endpoints.items(): - if name in inferred_clusters: - print("updating cluster", endpoint_details) - inferred_clusters[name].update(endpoint_details) - endpoint = inferred_clusters[name]['endpoint'] - if len(endpoint.split(':')) > 1: - inferred_clusters[name]['endpoint'] = endpoint.split(':')[0] - inferred_clusters[name]['port'] = int(endpoint.split(':')[1]) - else: - inferred_clusters[name] = endpoint_details - -print("updated clusters", inferred_clusters) - -config_yaml = add_secret_key_to_llm_providers(config_yaml) -arch_llm_providers = config_yaml["llm_providers"] -arch_config_string = 
yaml.dump(config_yaml)
-
-print("llm_providers:", arch_llm_providers)
-
-data = {
-    'arch_config': arch_config_string,
-    'arch_clusters': inferred_clusters,
-    'arch_llm_providers': arch_llm_providers
-}
-
-rendered = template.render(data)
-print(rendered)
-print(ENVOY_CONFIG_FILE_RENDERED)
-with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file:
-    file.write(rendered)
diff --git a/arch/docker-compose.yaml b/arch/docker-compose.yaml
index 45417674..31d6db56 100644
--- a/arch/docker-compose.yaml
+++ b/arch/docker-compose.yaml
@@ -1,13 +1,11 @@
 services:
   archgw:
-    build:
-      context: ../
-      dockerfile: arch/Dockerfile
+    image: archgw:latest
     ports:
       - "10000:10000"
-      - "18080:9901"
+      - "19901:9901"
     volumes:
-      - ${ARCH_CONFIG_FILE}:/config/arch_config.yaml
+      - ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
       - /etc/ssl/cert.pem:/etc/ssl/cert.pem
       - ./arch_log:/var/log/
     depends_on:
@@ -15,9 +13,7 @@ services:
       condition: service_healthy

   model_server:
-    build:
-      context: ../model_server
-      dockerfile: Dockerfile
+    image: model_server:latest
     ports:
       - "18081:80"
     healthcheck:
@@ -26,3 +22,8 @@ services:
       retries: 20
     volumes:
       - ~/.cache/huggingface:/root/.cache/huggingface
+    environment:
+      - OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
+      - OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
+      - MODE=${MODE:-cloud}
+      - FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1}
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index eb83f328..8aab7c6e 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -170,7 +170,8 @@ static_resources:
       {% else -%}
       connect_timeout: 5s
       {% endif -%}
-      type: STRICT_DNS
+      type: LOGICAL_DNS
+      dns_lookup_family: V4_ONLY
       lb_policy: ROUND_ROBIN
       load_assignment:
         cluster_name: {{ cluster.name }}
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index d67498b9..4ef74e55 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -328,7 +328,7 @@ impl StreamContext {
         if messages.len() >= 2 {
             let latest_assistant_message = &messages[messages.len() - 2];
             if let Some(model) = latest_assistant_message.model.as_ref() {
-                if model.starts_with("Arch") {
+                if model.contains("Arch") {
                     arch_assistant = true;
                 }
             }
diff --git a/arch/tools/MANIFEST.in b/arch/tools/MANIFEST.in
new file mode 100644
index 00000000..f2af40b7
--- /dev/null
+++ b/arch/tools/MANIFEST.in
@@ -0,0 +1,2 @@
+include config/docker-compose.yaml
+include config/arch_config_schema.yaml
diff --git a/arch/tools/README.md b/arch/tools/README.md
new file mode 100644
index 00000000..0a05a0a1
--- /dev/null
+++ b/arch/tools/README.md
@@ -0,0 +1,28 @@
+## Setup Instructions: archgw CLI
+
+This guide walks you through the steps to set up the archgw CLI on your local machine.
+
+### Step 1: Create a Python virtual environment
+
+In the tools directory, create a Python virtual environment by running:
+
+```bash
+python -m venv venv
+```
+
+### Step 2: Activate the virtual environment
+* On Linux/MacOS:
+
+```bash
+source venv/bin/activate
+```
+
+### Step 3: Run the build script
+```bash
+sh build-cli.sh
+```
+
+## Uninstall Instructions: archgw CLI
+```bash
+pip uninstall archgw
+```
diff --git a/arch/tools/__init__.py b/arch/tools/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/arch/tools/build-cli.sh b/arch/tools/build-cli.sh
new file mode 100644
index 00000000..4463ccf2
--- /dev/null
+++ b/arch/tools/build-cli.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Define paths
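+# (Note: the ../ paths below assume this script is run from arch/tools,
+# as described in the README above.)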
+source_schema="../arch_config_schema.yaml" +source_compose="../docker-compose.yaml" +destination_dir="config" + +# Ensure the destination directory exists only if it doesn't already +if [ ! -d "$destination_dir" ]; then + mkdir -p "$destination_dir" + echo "Directory $destination_dir created." +fi + +# Copy the files +cp "$source_schema" "$destination_dir/arch_config_schema.yaml" +cp "$source_compose" "$destination_dir/docker-compose.yaml" + +# Print success message +echo "Files copied successfully!" + +echo "Building the cli" +pip install -e . diff --git a/arch/tools/cli.py b/arch/tools/cli.py new file mode 100644 index 00000000..ec3afeff --- /dev/null +++ b/arch/tools/cli.py @@ -0,0 +1,117 @@ +import click +from core import start_arch, stop_arch +import targets +import os +import config_generator +import pkg_resources +import sys +import subprocess + +logo = r""" + _ _ + / \ _ __ ___ | |__ + / _ \ | '__|/ __|| '_ \ + / ___ \ | | | (__ | | | | + /_/ \_\|_| \___||_| |_| + +""" +@click.group(invoke_without_command=True) +@click.pass_context +def main(ctx): + if ctx.invoked_subcommand is None: + click.echo( """Arch (The Intelligent Prompt Gateway) CLI""") + click.echo(logo) + click.echo(ctx.get_help()) + +# Command to build archgw and model_server Docker images +ARCHGW_DOCKERFILE = "./arch/Dockerfile" +MODEL_SERVER_DOCKERFILE = "./model_server/Dockerfile" + +@click.command() +def build(): + """Build Arch from source. Must be in root of cloned repo.""" + # Check if /arch/Dockerfile exists + if os.path.exists(ARCHGW_DOCKERFILE): + click.echo("Building archgw image...") + try: + subprocess.run(["docker", "build", "-f", ARCHGW_DOCKERFILE, "-t", "archgw:latest", "."], check=True) + click.echo("archgw image built successfully.") + except subprocess.CalledProcessError as e: + click.echo(f"Error building archgw image: {e}") + sys.exit(1) + else: + click.echo("Error: Dockerfile not found in /arch") + sys.exit(1) + + # Check if /model_server/Dockerfile exists + if os.path.exists(MODEL_SERVER_DOCKERFILE): + click.echo("Building model_server image...") + try: + subprocess.run(["docker", "build", "-f", MODEL_SERVER_DOCKERFILE, "-t", "model_server:latest", "./model_server"], check=True) + click.echo("model_server image built successfully.") + except subprocess.CalledProcessError as e: + click.echo(f"Error building model_server image: {e}") + sys.exit(1) + else: + click.echo("Error: Dockerfile not found in /model_server") + sys.exit(1) + + click.echo("All images built successfully.") + +@click.command() +@click.argument('file', required=False) # Optional file argument +@click.option('-path', default='.', help='Path to the directory containing arch_config.yml') +def up(file, path): + """Starts Arch.""" + if file: + # If a file is provided, process that file + arch_config_file = os.path.abspath(file) + else: + # If no file is provided, use the path and look for arch_config.yml + arch_config_file = os.path.abspath(os.path.join(path, "arch_config.yml")) + + # Check if the file exists + if not os.path.exists(arch_config_file): + print(f"Error: {arch_config_file} does not exist.") + return + + print(f"Processing config file: {arch_config_file}") + arch_schema_config = pkg_resources.resource_filename(__name__, "config/arch_config_schema.yaml") + + print(f"Validating {arch_config_file}") + + try: + config_generator.validate_prompt_config(arch_config_file=arch_config_file, arch_config_schema_file=arch_schema_config) + except Exception as e: + print("Exiting archgw up") + sys.exit(1) + + print("Starting Arch 
gateway and Arch model server services via docker.")
+    start_arch(arch_config_file)
+
+@click.command()
+def down():
+    """Stops Arch."""
+    stop_arch()
+
+@click.command()
+@click.option('-f', '--file', type=click.Path(exists=True), required=True, help="Path to the Python file")
+def generate_prompt_targets(file):
+    """Generates prompt_targets from Python methods.
+    Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
+    If you have a complex pydantic data type, you will have to flatten those manually until we add support for it."""
+
+    print(f"Processing file: {file}")
+    if not file.endswith(".py"):
+        print("Error: Input file must be a .py file")
+        sys.exit(1)
+
+    targets.generate_prompt_targets(file)
+
+main.add_command(up)
+main.add_command(down)
+main.add_command(build)
+main.add_command(generate_prompt_targets)
+
+if __name__ == '__main__':
+    main()
diff --git a/arch/tools/config_generator.py b/arch/tools/config_generator.py
new file mode 100644
index 00000000..c7759e71
--- /dev/null
+++ b/arch/tools/config_generator.py
@@ -0,0 +1,108 @@
+import os
+from jinja2 import Environment, FileSystemLoader
+import yaml
+from jsonschema import validate
+
+ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
+ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
+ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
+ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
+
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False)
+MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False)
+
+def add_secret_key_to_llm_providers(config_yaml):
+    llm_providers = []
+    for llm_provider in config_yaml.get("llm_providers", []):
+        if llm_provider['access_key'] == "$MISTRAL_API_KEY":
+            llm_provider['access_key'] = MISTRAL_API_KEY
+        elif llm_provider['access_key'] == "$OPENAI_API_KEY":
+            llm_provider['access_key'] = OPENAI_API_KEY
+        else:
+            llm_provider.pop('access_key')
+        llm_providers.append(llm_provider)
+    config_yaml["llm_providers"] = llm_providers
+    return config_yaml
+
+def validate_and_render_schema():
+    env = Environment(loader=FileSystemLoader('./'))
+    template = env.get_template('envoy.template.yaml')
+
+    try:
+        validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
+    except Exception as e:
+        print(e)
+        exit(1)  # validate_prompt_config failed; exit
+
+    with open(ARCH_CONFIG_FILE, 'r') as file:
+        arch_config = file.read()
+
+    with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
+        arch_config_schema = file.read()
+
+    config_yaml = yaml.safe_load(arch_config)
+    config_schema_yaml = yaml.safe_load(arch_config_schema)
+    inferred_clusters = {}
+
+    for prompt_target in config_yaml["prompt_targets"]:
+        name = prompt_target.get("endpoint", {}).get("name", "")
+        if name not in inferred_clusters:
+            inferred_clusters[name] = {
+                "name": name,
+                "port": 80,  # default port
+            }
+
+    print(inferred_clusters)
+    endpoints = config_yaml.get("endpoints", {})
+
+    # override the inferred clusters with the ones defined in the config
+    for name, endpoint_details in endpoints.items():
+        if name in inferred_clusters:
+            print("updating cluster", endpoint_details)
+            inferred_clusters[name].update(endpoint_details)
+            endpoint = inferred_clusters[name]['endpoint']
+            if len(endpoint.split(':')) > 1:
+                inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
+                inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
+        else:
+            inferred_clusters[name] = endpoint_details
+
+    print("updated clusters", inferred_clusters)
+
+    config_yaml = add_secret_key_to_llm_providers(config_yaml)
+    arch_llm_providers = config_yaml["llm_providers"]
+    arch_config_string = yaml.dump(config_yaml)
+
+    print("llm_providers:", arch_llm_providers)
+
+    data = {
+        'arch_config': arch_config_string,
+        'arch_clusters': inferred_clusters,
+        'arch_llm_providers': arch_llm_providers
+    }
+
+    rendered = template.render(data)
+    print(rendered)
+    print(ENVOY_CONFIG_FILE_RENDERED)
+    with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file:
+        file.write(rendered)
+
+def validate_prompt_config(arch_config_file, arch_config_schema_file):
+    with open(arch_config_file, 'r') as file:
+        arch_config = file.read()
+
+    with open(arch_config_schema_file, 'r') as file:
+        arch_config_schema = file.read()
+
+    config_yaml = yaml.safe_load(arch_config)
+    config_schema_yaml = yaml.safe_load(arch_config_schema)
+
+    try:
+        validate(config_yaml, config_schema_yaml)
+    except Exception as e:
+        print(f"Error validating arch_config file: {arch_config_file}, error: {e.message}")
+        raise e
+
+if __name__ == '__main__':
+    validate_and_render_schema()
diff --git a/arch/tools/core.py b/arch/tools/core.py
new file mode 100644
index 00000000..9d970f30
--- /dev/null
+++ b/arch/tools/core.py
@@ -0,0 +1,101 @@
+import subprocess
+import os
+import time
+import pkg_resources
+from utils import run_docker_compose_ps, print_service_status, check_services_state
+
+def start_arch(arch_config_file, log_timeout=120, check_interval=1):
+    """
+    Start Docker Compose in detached mode and stream logs until services are healthy.
+
+    Args:
+        arch_config_file (str): Path to the arch_config.yaml file to mount into the gateway.
+        log_timeout (int): Time in seconds to show logs before checking for healthy state.
+        check_interval (int): Time in seconds between health status checks.
+    """
+    # Set the ARCH_CONFIG_FILE environment variable
+    env = os.environ.copy()
+    env['ARCH_CONFIG_FILE'] = arch_config_file
+
+    compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
+
+    try:
+        # Run the Docker Compose command in detached mode (-d)
+        subprocess.run(
+            ["docker-compose", "up", "-d"],
+            cwd=os.path.dirname(compose_file),  # Ensure the Docker command runs in the correct path
+            env=env,  # Pass the modified environment
+            check=True  # Raise an exception if the command fails
+        )
+        print("Arch docker-compose started in detached mode.")
+        print("Monitoring `docker-compose ps` logs...")
+
+        start_time = time.time()
+        services_status = {}
+
+        while True:
+            current_time = time.time()
+            elapsed_time = current_time - start_time
+
+            # Check if timeout is reached
+            if elapsed_time > log_timeout:
+                print(f"Stopping log monitoring after {log_timeout} seconds.")
+                break
+
+            current_services_status = run_docker_compose_ps(compose_file=compose_file, env=env)
+            if not current_services_status:
+                print("Status for the services could not be detected. Something went wrong. Please run `docker logs`")
+                break
+
+            if not services_status:
+                services_status = current_services_status  # set the first time around
+                print_service_status(services_status)  # print the services' status and proceed
+
+            # Check if any service is in a failed or exited state; if so, print and break out
+            unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"]
+            running_states = ["running", "up"]
+
+            if check_services_state(current_services_status, running_states):
+                print("Arch is up and running!")
+                break
+
+            if check_services_state(current_services_status, unhealthy_states):
+                print("One or more Arch services are unhealthy. Please run `docker logs` for more information")
+                print_service_status(current_services_status)  # print the services' status and proceed
+                break
+
+            # Check whether the status of any service has changed since the last poll.
+            # Print and keep looping until the services settle, or error out.
+            for service_name in services_status:
+                if services_status[service_name]['status'] != current_services_status[service_name]['status']:
+                    print("One or more Arch services have changed state. Printing current state")
+                    print_service_status(current_services_status)
+                    break
+
+            services_status = current_services_status
+            time.sleep(check_interval)  # wait before polling service status again
+
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to start Arch: {str(e)}")
+
+
+def stop_arch():
+    """
+    Shutdown all Docker Compose services by running `docker-compose down`.
+    """
+    compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
+
+    try:
+        # Run `docker-compose down` to shut down all services
+        subprocess.run(
+            ["docker-compose", "down"],
+            cwd=os.path.dirname(compose_file),
+            check=True,
+        )
+        print("Successfully shut down all services.")
+
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to shut down services: {str(e)}")
diff --git a/arch/tools/setup.py b/arch/tools/setup.py
new file mode 100644
index 00000000..ca74ce36
--- /dev/null
+++ b/arch/tools/setup.py
@@ -0,0 +1,20 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="archgw",
+    version="0.1.0",
+    description="Python-based CLI tool to manage Arch and generate targets.",
+    author="Katanemo Labs, Inc.",
+    packages=find_packages(),
+    py_modules=['cli', 'core', 'targets', 'utils', 'config_generator'],
+    include_package_data=True,
+    package_data={
+        '': ['config/docker-compose.yaml', 'config/arch_config_schema.yaml']  # Include the copied config files in the package
+    },
+    install_requires=['pyyaml', 'pydantic', 'click', 'jinja2', 'jsonschema'],
+    entry_points={
+        'console_scripts': [
+            'archgw=cli:main',
+        ],
+    },
+)
diff --git a/arch/tools/targets.py b/arch/tools/targets.py
new file mode 100644
index 00000000..82cc770a
--- /dev/null
+++ b/arch/tools/targets.py
@@ -0,0 +1,297 @@
+import ast
+import sys
+import yaml
+from typing import Any
+from pydantic import BaseModel
+
+FLASK_ROUTE_DECORATORS = ["route", "get", "post", "put", "delete", "patch"]
+FASTAPI_ROUTE_DECORATORS = ["get", "post", "put", "delete", "patch"]
+
+
+def detect_framework(tree: Any) -> str:
+    """Detect whether the file is using Flask or FastAPI based on imports."""
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ImportFrom):
+            if node.module == "flask":
+                return "flask"
+            elif node.module == "fastapi":
+                return "fastapi"
+    return "unknown"
+
+def get_route_decorators(node: Any, framework: str) -> list:
+    """Extract route decorators based on the framework."""
+    decorators = []
+    for decorator in node.decorator_list:
+        if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute):
+            if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
+                decorators.append(decorator.func.attr)
+            elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS:
+                decorators.append(decorator.func.attr)
+    return decorators
+
+
+def get_route_path(node: Any, framework: str) -> str:
+    """Extract route path based on the framework."""
+    for decorator in node.decorator_list:
+        if isinstance(decorator, ast.Call) and decorator.args:
+            return decorator.args[0].s  # Assuming it's a string literal
+
+def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
+    """Check if a given type annotation is a Pydantic model."""
+    # We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
+    if isinstance(annotation, ast.Name):
+        for node in ast.walk(tree):
+            if isinstance(node, ast.ClassDef) and node.name == annotation.id:
+                for base in node.bases:
+                    if isinstance(base, ast.Name) and base.id == "BaseModel":
+                        return True
+    return False
+
+def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
+    """Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
+    fields = []
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef) and node.name == model_name:
+            for stmt 
in node.body: + if isinstance(stmt, ast.AnnAssign): + # Initialize the default field description + field_type = "Unknown: Please Fix This!" + description = "Field, description not present. Please fix." + default_value = None + required = True # Assume the field is required initially + + # Check if the field uses Field() with required status and description + if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field': + # Extract the description argument inside the Field call + for keyword in stmt.value.keywords: + if keyword.arg == 'description' and isinstance(keyword.value, ast.Str): + description = keyword.value.s + if keyword.arg == 'default': + default_value = keyword.value + # If Ellipsis (...) is used, it means the field is required + if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis: + required = True + else: + required = False + + # Handle direct default values (e.g., name: str = "John Doe") + elif stmt.value is not None: + if isinstance(stmt.value, ast.Constant): + # Set the default value from the assignment (e.g., name: str = "John Doe") + default_value = stmt.value.value + required = False # Not required since it has a default value + + # Always extract the field type, even if there's a default value + if isinstance(stmt.annotation, ast.Subscript): + # Get the base type (list, tuple, set, dict) + base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown" + + # Handle only list, tuple, set, dict and ignore the inner types + if base_type.lower() in ['list', 'tuple', 'set', 'dict']: + field_type = base_type.lower() + + # Handle the ellipsis '...' for required fields if no Field() call + elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis: + required = True + + # Handle simple types like str, int, etc. + if isinstance(stmt.annotation, ast.Name): + field_type = stmt.annotation.id + + field_info = { + "name": stmt.target.id, + "type": field_type, # Always set the field type + "description": description, + "default": default_value, # Handle direct default values + "required": required + } + fields.append(field_info) + + return fields + +def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: + """Extract the parameters and their types from the function definition.""" + parameters = [] + + # Extract docstring to find descriptions + docstring = ast.get_docstring(node) + arg_descriptions = extract_arg_descriptions_from_docstring(docstring) + + # Extract default values + defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args + for arg, default in zip(node.args.args, defaults): + if arg.arg != "self": # Skip 'self' or 'cls' in class methods + param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")} + + # Handle Pydantic model types + if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree): + # Extract and flatten Pydantic model fields + pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree) + parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list + continue # Skip adding the current param_info for the model since we expand the fields + + # Handle standard Python types (int, float, str, etc.) 
+ elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name): + if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']: + param_info["type"] = arg.annotation.id + else: + param_info["type"] = "[UNKNOWN - PLEASE FIX]" + + # Handle generic subscript types (e.g., Optional, List[Type], etc.) + elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript): + if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']: + param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. + else: + param_info["type"] = "[UNKNOWN - PLEASE FIX]" + + # Default for unknown types + else: + param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type + + # Handle default values + if default is not None: + if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant): + param_info["default"] = default.value # Use the default value directly + else: + param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type + param_info["required"] = False # Optional since it has a default value + else: + param_info["default"] = None + param_info["required"] = True # Required if no default value + + parameters.append(param_info) + + return parameters + +def get_function_docstring(node: Any) -> str: + """Extract the function's docstring description if present.""" + # Check if the first node is a docstring + if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str): + # Get the entire docstring + full_docstring = node.body[0].value.s.strip() + + # Split the docstring by double newlines (to separate description from fields like Args) + description = full_docstring.split("\n\n")[0].strip() + + return description + + return "No description provided." 
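+# NOTE: the docstring helpers above and below assume Google-style docstrings
+# (a summary paragraph, then an indented "Args:" section), e.g.:
+#
+#     """Reboot a device.
+#
+#     Args:
+#         device_id: The device to reboot.
+#     """
+#
+# Other styles (reST/NumPy) fall back to the default descriptions.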
+
+def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
+    """Extract descriptions for function parameters from the 'Args' section of the docstring."""
+    descriptions = {}
+    if not docstring:
+        return descriptions
+
+    in_args_section = False
+    current_param = None
+    for line in docstring.splitlines():
+        line = line.strip()
+
+        # Detect the start of the 'Args' section
+        if line.startswith("Args:"):
+            in_args_section = True
+            continue  # Proceed to the next line after 'Args:'
+
+        # End of 'Args' section if no indentation and no colon
+        if in_args_section and not line.startswith(" ") and ':' not in line:
+            break  # Stop processing if we reach a new section
+
+        # Process lines in the 'Args' section
+        if in_args_section:
+            if ':' in line:
+                # Extract parameter name and description
+                param_name, description = line.split(':', 1)
+                descriptions[param_name.strip()] = description.strip()
+                current_param = param_name.strip()
+            elif current_param and line.startswith(" "):
+                # Handle multiline descriptions (indented lines)
+                descriptions[current_param] += f" {line.strip()}"
+
+    return descriptions
+
+
+def generate_prompt_targets(input_file_path: str) -> None:
+    """Introspect routes and generate YAML for either Flask or FastAPI."""
+    with open(input_file_path, "r") as source:
+        tree = ast.parse(source.read())
+
+    # Detect the framework (Flask or FastAPI)
+    framework = detect_framework(tree)
+    if framework == "unknown":
+        print("Could not detect Flask or FastAPI in the file.")
+        return
+
+    # Extract routes
+    routes = []
+    for node in ast.walk(tree):
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            route_decorators = get_route_decorators(node, framework)
+            if route_decorators:
+                route_path = get_route_path(node, framework)
+                function_params = get_function_parameters(node, tree)  # Get parameters for the route
+                function_docstring = get_function_docstring(node)  # Extract docstring
+                routes.append({
+                    'name': node.name,
+                    'path': route_path,
+                    'methods': route_decorators,
+                    'parameters': function_params,  # Add parameters to the route
+                    'description': function_docstring  # Add the docstring as the description
+                })
+
+    # Generate YAML structure
+    output_structure = {
+        "prompt_targets": []
+    }
+
+    for route in routes:
+        target = {
+            "name": route['name'],
+            "endpoint": [
+                {
+                    "name": "app_server",
+                    "path": route['path'],
+                }
+            ],
+            "description": route['description'],  # Use extracted docstring
+            "parameters": [
+                {
+                    "name": param['name'],
+                    "type": param['type'],
+                    "description": f"{param['description']}",
+                    **({"default": param['default']} if "default" in param and param['default'] is not None else {}),  # Only add default if it's set
+                    "required": param['required']
+                } for param in route['parameters']
+            ]
+        }
+
+        if route['name'] == "default":
+            # Special case for `information_extraction` based on your YAML format
+            target["type"] = "default"
+            target["auto-llm-dispatch-on-response"] = True
+
+        output_structure["prompt_targets"].append(target)
+
+    # Output as YAML
+    print(yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3))
+
+if __name__ == '__main__':
+    if len(sys.argv) != 2:
+        print("Usage: python targets.py <input_file.py>")
+        sys.exit(1)
+
+    input_file = sys.argv[1]
+
+    if not input_file.endswith(".py"):
+        print("Error: Input file must be a .py file")
+        sys.exit(1)
+
+    # generate_prompt_targets prints the generated YAML to stdout
+    generate_prompt_targets(input_file)
+
+# Example usage:
+# python targets.py api.py
diff --git a/arch/tools/test/fastapi_test.py b/arch/tools/test/fastapi_test.py
new file mode 100644
index 00000000..1f25a0e1
--- /dev/null
+++ b/arch/tools/test/fastapi_test.py
@@ -0,0 +1,33 @@
+from fastapi import FastAPI
+from pydantic import BaseModel, Field
+from typing import List, Dict, Set
+
+app = FastAPI()
+
+class User(BaseModel):
+    name: str = Field("John Doe", description="The name of the user.")  # Default value and description for name
+    location: int = None
+    age: int = Field(30, description="The age of the user.")  # Default value and description for age
+    tags: Set[str] = Field(default_factory=set, description="A set of tags associated with the user.")  # Default empty set and description for tags
+    metadata: Dict[str, int] = Field(default_factory=dict, description="A dictionary storing metadata about the user, with string keys and integer values.")  # Default empty dict and description for metadata
+
+@app.get("/agent/default")
+async def default(request: User):
+    """
+    This endpoint handles information extraction queries.
+    It can summarize, extract details, and perform various other information-related tasks.
+    """
+    return {"info": f"Query: {request.name}, Count: {request.age}"}
+
+@app.post("/agent/action")
+async def reboot_network_device(device_id: str, confirmation: str):
+    """
+    This endpoint reboots a network device based on the device ID.
+    Confirmation is required to proceed with the reboot.
+
+    Args:
+        device_id: The device_id that you want to reboot.
+        confirmation: The confirmation that the user wants to reboot.
+        metadata: Ignore this parameter
+    """
+    return {"status": "Device rebooted", "device_id": device_id}
diff --git a/arch/tools/test/fastapi_test_prompt_targets.yml b/arch/tools/test/fastapi_test_prompt_targets.yml
new file mode 100644
index 00000000..7fb9d118
--- /dev/null
+++ b/arch/tools/test/fastapi_test_prompt_targets.yml
@@ -0,0 +1,33 @@
+prompt_targets:
+- name: default
+  path: /agent/default
+  description: "This endpoint handles information extraction queries.\n    It can\
+    \ summarize, extract details, and perform various other information-related tasks."
+  parameters:
+  - name: query
+    type: str
+    description: Field from Pydantic model DefaultRequest
+    default_value: null
+    required: false
+  - name: count
+    type: int
+    description: Field from Pydantic model DefaultRequest
+    default_value: null
+    required: false
+  type: default
+  auto-llm-dispatch-on-response: true
+- name: reboot_network_device
+  path: /agent/action
+  description: "This endpoint reboots a network device based on the device ID.\n \
+    \ Confirmation is required to proceed with the reboot."
+  parameters:
+  - name: device_id
+    type: str
+    description: Description for device_id
+    default_value: ''
+    required: true
+  - name: confirmation
+    type: int
+    description: Description for confirmation
+    default_value: ''
+    required: true
diff --git a/arch/tools/utils.py b/arch/tools/utils.py
new file mode 100644
index 00000000..627279be
--- /dev/null
+++ b/arch/tools/utils.py
@@ -0,0 +1,79 @@
+import subprocess
+import os
+import shlex
+
+def run_docker_compose_ps(compose_file, env):
+    """
+    Check whether all Docker Compose services are in a healthy state.
+
+    Args:
+        compose_file (str): Path to the docker-compose.yaml file.
+        env (dict): Environment to pass to `docker-compose`.
+    """
+    try:
+        # Run `docker-compose ps` to get the health status of each service
+        ps_process = subprocess.Popen(
+            ["docker-compose", "ps"],
+            cwd=os.path.dirname(compose_file),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            env=env
+        )
+        # Capture the output of `docker-compose ps`
+        services_status, error_output = ps_process.communicate()
+
+        # Check if there is any error output
+        if error_output:
+            print(f"Error while checking service status:\n{error_output}", file=os.sys.stderr)
+            return {}
+
+        lines = services_status.strip().splitlines()
+        services = {}
+
+        # Skip the header row and parse each service
+        for line in lines[1:]:
+            parts = shlex.split(line)
+            if len(parts) >= 5:
+                service_name = parts[0]  # Service name
+                status_index = 3  # Status is typically at index 3, but may span multiple words
+
+                # Check if the status has multiple words (e.g., "running (healthy)")
+                if '(' in parts[status_index + 1]:
+                    # Combine the status field if it's split over two parts
+                    status = f"{parts[status_index]} {parts[status_index + 1]}"
+                    ports = parts[status_index + 2]
+                else:
+                    status = parts[status_index]
+                    ports = parts[status_index + 1]
+
+                # Store both status and ports in a dictionary for each service
+                services[service_name] = {
+                    'status': status,
+                    'ports': ports
+                }
+
+        return services
+
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to check service status. Error:\n{e.stderr}")
+        return {}
+
+# Helper method to print service status
+def print_service_status(services):
+    print(f"{'Service Name':<25} {'Status':<20} {'Ports'}")
+    print("=" * 72)
+    for service_name, info in services.items():
+        status = info['status']
+        ports = info['ports']
+        print(f"{service_name:<25} {status:<20} {ports}")
+
+# Check the services against the states passed in
+def check_services_state(services, states):
+    for service_name, service_info in services.items():
+        status = service_info['status'].lower()  # Convert status to lowercase for easier comparison
+        if any(state in status for state in states):
+            return True
+
+    return False
diff --git a/demos/function_calling/README.md b/demos/function_calling/README.md
index 86005388..9dcf3d67 100644
--- a/demos/function_calling/README.md
+++ b/demos/function_calling/README.md
@@ -28,7 +28,7 @@ This demo shows how you can use intelligent prompt gateway to do function callin
 - On this dashboard you can see reuqest latency and number of requests

 # Observability
-Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from envoy and we are using grafan to visalize the stats in dashboard. To see grafana dashboard follow instructions below,
+Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from arch and we are using grafana to visualize the stats in a dashboard. To see the grafana dashboard, follow the instructions below:
 1. 
Start grafana and prometheus using following command ```yaml diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml index 28ada761..ca550fbf 100644 --- a/demos/function_calling/arch_config.yaml +++ b/demos/function_calling/arch_config.yaml @@ -8,7 +8,7 @@ listener: endpoints: api_server: - endpoint: api_server:80 + endpoint: host.docker.internal:18083 connect_timeout: 0.005s overrides: @@ -17,16 +17,17 @@ overrides: llm_providers: - name: open-ai-gpt-4 - access_key: $OPENAI_ACCESS_KEY + access_key: $OPENAI_API_KEY provider: openai model: gpt-4 default: true - name: mistral-large-latest - access_key: $MISTRAL_ACCESS_KEY + access_key: $MISTRAL_API_KEY provider: mistral model: large-latest -system_prompt: You are a helpful assistant. +system_prompt: | + You are a helpful assistant. prompt_targets: - name: weather_forecast diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml index 3cc689e5..8cbe0da8 100644 --- a/demos/function_calling/docker-compose.yaml +++ b/demos/function_calling/docker-compose.yaml @@ -1,54 +1,4 @@ - -x-variables: &common-vars - environment: - - MODE=${MODE:-cloud} # Set the default mode to 'cloud', others values are local-gpu, local-cpu - - services: - - arch: - build: - context: ../../ - dockerfile: arch/Dockerfile - ports: - - "10000:10000" - - "19901:9901" - volumes: - - /etc/ssl/cert.pem:/etc/ssl/cert.pem - - ./arch_log:/var/log/ - - ./arch_config.yaml:/config/arch_config.yaml - depends_on: - # config_generator: - # condition: service_completed_successfully - model_server: - condition: service_healthy - environment: - - LOG_LEVEL=debug - - OPENAI_API_KEY=${OPENAI_API_KEY:?error} - - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error} - - model_server: - build: - context: ../../model_server - dockerfile: Dockerfile - ports: - - "18081:80" - healthcheck: - test: ["CMD", "curl" ,"http://localhost/healthz"] - interval: 5s - retries: 20 - volumes: - - ~/.cache/huggingface:/root/.cache/huggingface - - ./arch_config.yaml:/root/arch_config.yaml - << : *common-vars - environment: - - OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal} - - FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1} - - OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M - - MODE=${MODE:-cloud} - # uncomment following line to use ollama endpoint that is hosted by docker - # - OLLAMA_ENDPOINT=ollama - # - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M api_server: build: context: api_server @@ -60,45 +10,16 @@ services: interval: 5s retries: 20 - ollama: - image: ollama/ollama - container_name: ollama - volumes: - - ./ollama:/root/.ollama - restart: unless-stopped - ports: - - '11434:11434' - profiles: - - manual - - open_webui: - image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main} - container_name: open-webui - volumes: - - ./open-webui:/app/backend/data - # depends_on: - # - ollama - ports: - - 18090:8080 - environment: - - OLLAMA_BASE_URL=http://${OLLAMA_ENDPOINT:-host.docker.internal}:11434 - - WEBUI_AUTH=false - extra_hosts: - - host.docker.internal:host-gateway - restart: unless-stopped - profiles: - - monitoring - chatbot_ui: build: context: ../../chatbot_ui dockerfile: Dockerfile ports: - - "18080:8080" + - "18090:8080" environment: - OPENAI_API_KEY=${OPENAI_API_KEY:?error} - MISTRAL_API_KEY=${MISTRAL_API_KEY:?error} - - CHAT_COMPLETION_ENDPOINT=http://arch:10000/v1 + - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1 prometheus: image: prom/prometheus diff --git 
a/demos/insurance_agent/README.md b/demos/insurance_agent/README.md
new file mode 100644
index 00000000..8ee550a8
--- /dev/null
+++ b/demos/insurance_agent/README.md
@@ -0,0 +1 @@
+The following demo shows how to build an insurance agent with Arch: the agent answers general policy questions and handles coverage lookups, policy initiation, claim updates, and deductible changes.
diff --git a/demos/insurance_agent/arch_confirg.yaml b/demos/insurance_agent/arch_confirg.yaml
new file mode 100644
index 00000000..48267d0b
--- /dev/null
+++ b/demos/insurance_agent/arch_confirg.yaml
@@ -0,0 +1,89 @@
+version: "0.1-beta"
+listener:
+  address: 127.0.0.1
+  port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
+  message_format: huggingface
+
+system_prompt: |
+  You are an insurance assistant that just offers guidance related to car, boat, rental and home insurance only.
+
+llm_providers:
+  - name: "OpenAI"
+    access_key: $OPENAI_API_KEY
+    model: gpt-4o
+    default: true
+
+# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
+endpoints:
+  app_server:
+    # value could be ip address or a hostname with port
+    # this could also be a list of endpoints for load balancing
+    # for example endpoint: [ ip1:port, ip2:port ]
+    endpoint: "127.0.0.1:80"
+    # max time to wait for a connection to be established
+    connect_timeout: 500ms
+
+prompt_targets:
+  - name: policy_qa
+    endpoint:
+      name: app_server
+      path: /policy/qa
+    description: "This method handles Q/A related to general issues in insurance. It forwards the conversation to the OpenAI client via a local proxy and returns the response."
+    default: true
+
+  - name: get_policy_coverage
+    description: Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
+    endpoint:
+      name: app_server
+      path: /policy/coverage
+    parameters:
+      - name: policy_type
+        type: str
+        description: The type of policy to retrieve coverage for (car, boat, house, motorcycle).
+        default: 'car'
+        required: true
+
+  - name: initiate_policy
+    endpoint:
+      name: app_server
+      path: /policy/initiate
+    description: Initiate policy coverage for a car, boat, house, or motorcycle.
+    parameters:
+      - name: policy_type
+        type: str
+        description: Field definition from Pydantic model. Requires fixes PolicyRequest
+        required: true
+      - name: details
+        type: Unknown
+        description: Field definition from Pydantic model. Requires fixes PolicyRequest
+        required: false
+
+  - name: update_claim
+    endpoint:
+      name: app_server
+      path: /policy/claim
+    description: Update the status or details of a claim.
+    parameters:
+      - name: claim_id
+        type: int
+        description: Field definition from Pydantic model. Requires fixes ClaimUpdate
+        required: true
+      - name: update
+        type: str
+        description: Field definition from Pydantic model. Requires fixes ClaimUpdate
+        required: false
+
+  - name: update_deductible
+    endpoint:
+      name: app_server
+      path: /policy/deductible
+    description: Update the deductible amount for a specific policy.
+    parameters:
+      - name: policy_id
+        type: int
+        description: Field definition from Pydantic model. Requires fixes DeductibleUpdate
+        required: true
+      - name: new_deductible
+        type: float
+        description: Field definition from Pydantic model. Requires fixes DeductibleUpdate
+        required: false
diff --git a/demos/insurance_agent/insurance_agent_main.py b/demos/insurance_agent/insurance_agent_main.py
new file mode 100644
index 00000000..65068adb
--- /dev/null
+++ b/demos/insurance_agent/insurance_agent_main.py
@@ -0,0 +1,122 @@
+import openai
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, Field
+from typing import List, Optional
+
+app = FastAPI()
+openai.api_base = "http://127.0.0.1:10000/v1"  # Local proxy
+
+# Data models
+
+class PolicyCoverageRequest(BaseModel):
+    policy_type: str = Field(..., description="The type of policy held by the customer (e.g., car, boat, house, motorcycle)")
+
+class PolicyRequest(BaseModel):
+    policy_type: str = Field(..., description="The type of policy held by the customer (e.g., car, boat, house, motorcycle)")
+    details: str  # Additional details like model, year, etc.
+
+class ClaimUpdate(BaseModel):
+    policy_id: int
+    claim_id: int
+    update: str  # Status or details of the claim
+
+class DeductibleUpdate(BaseModel):
+    policy_id: int
+    new_deductible: float
+
+class CoverageResponse(BaseModel):
+    policy_type: str
+    coverage: str  # Description of coverage
+    premium: float  # The premium cost
+
+# Chat models used by the /policy/qa endpoint (inferred from its usage below)
+class Message(BaseModel):
+    role: str
+    content: str
+
+class Conversation(BaseModel):
+    messages: List[Message]
+
+# Get information about policy coverage
+@app.post("/policy/coverage", response_model=CoverageResponse)
+async def get_policy_coverage(req: PolicyCoverageRequest):
+    """
+    Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
+    """
+    policy_coverage = {
+        "car": {"coverage": "Full car coverage with collision, liability", "premium": 500.0},
+        "boat": {"coverage": "Full boat coverage including theft and storm damage", "premium": 700.0},
+        "house": {"coverage": "Full house coverage including fire, theft, flood", "premium": 1000.0},
+        "motorcycle": {"coverage": "Full motorcycle coverage with liability", "premium": 400.0},
+    }
+
+    if req.policy_type not in policy_coverage:
+        raise HTTPException(status_code=404, detail="Policy type not found")
+
+    return CoverageResponse(
+        policy_type=req.policy_type,
+        coverage=policy_coverage[req.policy_type]["coverage"],
+        premium=policy_coverage[req.policy_type]["premium"]
+    )
+
+# Initiate policy coverage
+@app.post("/policy/initiate")
+async def initiate_policy(policy_request: PolicyRequest):
+    """
+    Initiate policy coverage for a car, boat, house, or motorcycle.
+    """
+    if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]:
+        raise HTTPException(status_code=400, detail="Invalid policy type")
+
+    return {"message": f"Policy initiated for {policy_request.policy_type}", "details": policy_request.details}
+
+# Update claim details
+@app.post("/policy/claim")
+async def update_claim(req: ClaimUpdate):
+    """
+    Update the status or details of a claim.
+    """
+    # For simplicity, this is a mock update response
+    return {"message": f"Claim {req.claim_id} for policy {req.policy_id} has been updated",
+            "update": req.update}
+
+# Update deductible amount
+@app.post("/policy/deductible")
+async def update_deductible(deductible_update: DeductibleUpdate):
+    """
+    Update the deductible amount for a specific policy.
+    """
+    # For simplicity, this is a mock update response
+    return {"message": f"Deductible for policy {deductible_update.policy_id} has been updated",
+            "new_deductible": deductible_update.new_deductible}
+
+# Post method for policy Q/A
+@app.post("/policy/qa")
+async def policy_qa(conversation: Conversation):
+    """
+    This method handles Q/A related to general issues in insurance.
+    It forwards the conversation to the OpenAI client via a local proxy and returns the response.
+    """
+    try:
+        # Get the latest user message from the conversation
+        user_message = conversation.messages[-1].content  # Assuming the last message is from the user
+
+        # Call the OpenAI API through the Python client
+        response = openai.Completion.create(
+            model="gpt-4o",  # Replace with the model you want to use
+            prompt=user_message,
+            max_tokens=150
+        )
+
+        # Extract the response text from OpenAI
+        completion = response.choices[0].text.strip()
+
+        # Build the assistant's response message
+        assistant_message = Message(role="assistant", content=completion)
+
+        # Append the assistant's response to the conversation and return it
+        updated_conversation = Conversation(
+            messages=conversation.messages + [assistant_message]
+        )
+
+        return updated_conversation
+
+    except openai.error.OpenAIError as e:
+        raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
+
+# Run the app using:
+# uvicorn insurance_agent_main:app --reload
diff --git a/gateway.code-workspace b/gateway.code-workspace
index 889d306a..e864caad 100644
--- a/gateway.code-workspace
+++ b/gateway.code-workspace
@@ -20,6 +20,10 @@
         {
             "name": "demos/function_calling",
             "path": "./demos/function_calling",
         },
+        {
+            "name": "demos/insurance_agent",
+            "path": "./demos/insurance_agent",
+        },
         {
             "name": "demos/function_calling/api_server",
             "path": "./demos/function_calling/api_server",
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index ec505eeb..a0216294 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -36,7 +36,8 @@
         base_url=fc_url,
         api_key="EMPTY",
     )
-    chosen_model = "fc-cloud"
+    models = client.models.list()
+    chosen_model = models.data[0].id
     endpoint = fc_url
 else:
     client = OpenAI(
@@ -50,7 +51,6 @@
 logger.info(f"using model: {chosen_model}")
 logger.info(f"using endpoint: {endpoint}")
-
 async def chat_completion(req: ChatMessage, res: Response):
     logger.info("starting request")
     tools_encoded = handler._format_system(req.tools)
diff --git a/model_server/app/main.py b/model_server/app/main.py
index b9dda9d8..3c529f7a 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -26,8 +26,7 @@
 with open("guard_model_config.yaml") as f:
     guard_model_config = yaml.safe_load(f)

-with open('/root/arch_config.yaml') as f:
-    config = yaml.safe_load(f)
+
 mode = os.getenv("MODE", "cloud")
 logger.info(f"Serving model mode: {mode}")
 if mode not in ['cloud', 'local-gpu', 'local-cpu']:
@@ -37,20 +36,11 @@
 else:
     hardware = "gpu" if torch.cuda.is_available() else "cpu"

-if "prompt_guards" in config.keys():
-    task = list(config["prompt_guards"]["input_guards"].keys())[0]
-
-    hardware = "gpu" if torch.cuda.is_available() else "cpu"
-    jailbreak_model = load_guard_model(
-        guard_model_config["jailbreak"][hardware], hardware
-    )
-    toxic_model = None
-
-    guard_handler = GuardHandler(toxic_model=toxic_model, jailbreak_model=jailbreak_model)
+jailbreak_model = load_guard_model(guard_model_config["jailbreak"][hardware], hardware)
+guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)

 app = FastAPI()

-
 class EmbeddingRequest(BaseModel):
     input: str
     model: str
diff --git a/model_server/guard_model_config.yaml b/model_server/guard_model_config.yaml
index 5c0d7802..590fafaa 100644
--- a/model_server/guard_model_config.yaml
+++ b/model_server/guard_model_config.yaml
@@ -1,3 
+1,3 @@ jailbreak: - cpu: "katanemolabs/jailbreak_ovn_4bit" - gpu: "katanemolabs/Bolt-Guard-EEtq" + cpu: "katanemolabs/Arch-Guard-cpu" + gpu: "katanemolabs/Arch-Guard-gpu" diff --git a/model_server/openai_params.yaml b/model_server/openai_params.yaml index 342c3f41..6a5f8b2f 100644 --- a/model_server/openai_params.yaml +++ b/model_server/openai_params.yaml @@ -1,7 +1,6 @@ params: temperature: 0.01 top_p : 0.5 - repetition_penalty: 1.0 top_k: 50 max_tokens: 512 stop_token_ids: [151645, 151643] diff --git a/model_server/requirements.txt b/model_server/requirements.txt index 1320d843..79ec8e71 100644 --- a/model_server/requirements.txt +++ b/model_server/requirements.txt @@ -16,3 +16,4 @@ dateparser openai pandas tf-keras +onnx