ArchFC endpoint integration #94

Merged 14 commits on Oct 1, 2024
arch/config_generator.py (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
from jsonschema import validate

ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
-ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
+ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/root/arch_config.yaml')
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')

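The default above now points at /root/arch_config.yaml, matching the volume mount added in the demo's docker-compose.yaml and the path read by model_server/app/main.py. As a minimal sketch of how such environment-driven defaults are typically consumed (the helper below is illustrative, not the repo's actual API):

```python
import os

import yaml
from jsonschema import validate

# Defaults mirror the ones above; each can be overridden via the environment.
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/root/arch_config.yaml")
ARCH_CONFIG_SCHEMA_FILE = os.getenv("ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml")


def load_and_validate_config() -> dict:
    """Hypothetical helper: load the user config and check it against the schema."""
    with open(ARCH_CONFIG_FILE) as f:
        config = yaml.safe_load(f)
    with open(ARCH_CONFIG_SCHEMA_FILE) as f:
        schema = yaml.safe_load(f)
    validate(instance=config, schema=schema)  # raises jsonschema.ValidationError on mismatch
    return config
```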
arch/src/stream_context.rs (3 additions, 7 deletions)
@@ -473,7 +473,9 @@ impl StreamContext {

        let model_resp = &arch_fc_response.choices[0];

-        if model_resp.message.tool_calls.is_none() {
+        if model_resp.message.tool_calls.is_none()
+            || model_resp.message.tool_calls.as_ref().unwrap().is_empty()
+        {
            // This means that Arch FC did not have enough information to resolve the function call
            // Arch FC probably responded with a message asking for more information.
            // Let's send the response back to the user to initalize lightweight dialog for parameter collection
@@ -488,12 +490,6 @@ impl StreamContext {
        }

        let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
-        if tool_calls.is_empty() {
-            return self.send_server_error(
-                "No tool calls found in function resolver response".to_string(),
-                Some(StatusCode::BAD_REQUEST),
-            );
-        }

        debug!("tool_call_details: {:?}", tool_calls);
        // extract all tool names
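The net effect of this Rust change: an empty tool_calls array is now treated the same as a missing one, so the gateway forwards Arch-FC's clarifying question to the user instead of returning a 400. A rough Python sketch of the equivalent control flow (illustrative names only; this is not the gateway's actual code):

```python
def route_arch_fc_response(arch_fc_response: dict) -> dict:
    """Sketch: decide what to do with Arch-FC's reply."""
    message = arch_fc_response["choices"][0]["message"]
    tool_calls = message.get("tool_calls")

    if not tool_calls:  # covers both None and []
        # Arch-FC lacked enough information to resolve a function call; its text
        # reply is a follow-up question, so send it back to the user as-is.
        return {"action": "reply_to_user", "content": message.get("content")}

    # Otherwise dispatch the resolved calls to the upstream API target.
    return {"action": "dispatch_tools", "tool_calls": tool_calls}
```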
demos/function_calling/docker-compose.yaml (12 additions, 2 deletions)
@@ -1,3 +1,9 @@

+x-variables: &common-vars
+  environment:
+    - MODE=${MODE:-cloud} # Set the default mode to 'cloud', others values are local-gpu, local-cpu


services:

arch:
@@ -11,7 +17,10 @@ services:
- ./generated/envoy.yaml:/etc/envoy/envoy.yaml
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
- ./arch_log:/var/log/
+- ./arch_config.yaml:/root/arch_config.yaml
depends_on:
# config_generator:
# condition: service_completed_successfully
model_server:
condition: service_healthy
environment:
@@ -30,14 +39,15 @@ services:
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
- ./arch_config.yaml:/root/arch_config.yaml
+<< : *common-vars
environment:
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
+- FC_URL=${FC_URL:-empty}
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
# use ollama endpoint that is hosted by host machine (no virtualization)
-- MODE=${MODE:-cloud}
# uncomment following line to use ollama endpoint that is hosted by docker
# - OLLAMA_ENDPOINT=ollama
# - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M

api_server:
build:
context: api_server
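The new x-variables: &common-vars block is a plain YAML anchor, and << : *common-vars merges its keys into the model_server service. A small, self-contained sketch of how that merge behaves (the service contents below are illustrative; PyYAML resolves the << merge key, while ${MODE:-cloud} interpolation is done later by Docker Compose, not by the YAML parser):

```python
import yaml

doc = """
x-variables: &common-vars
  environment:
    - MODE=${MODE:-cloud}

services:
  model_server:
    <<: *common-vars
    image: model_server
"""

data = yaml.safe_load(doc)
print(data["services"]["model_server"]["environment"])
# ['MODE=${MODE:-cloud}'] -- merged in from the &common-vars anchor.
# Note: a key written directly on a service (such as an explicit environment:
# block) takes precedence over the value merged in via <<.
```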
model_server/app/arch_fc/arch_fc.py (49 additions, 22 deletions)
@@ -5,30 +5,52 @@
from app.arch_fc.bolt_handler import BoltHandler
from app.arch_fc.common import ChatMessage
import logging
+import yaml
from openai import OpenAI
import os


+with open("openai_params.yaml") as f:
+    params = yaml.safe_load(f)

ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
-logger = logging.getLogger('uvicorn.error')
+fc_url = os.getenv("FC_URL", ollama_endpoint)
+mode = os.getenv("MODE", "cloud")
+if mode not in ["cloud", "local-gpu", "local-cpu"]:
+    raise ValueError(f"Invalid mode: {mode}")
+arch_api_key = os.getenv("ARCH_API_KEY", "")
+logger = logging.getLogger("uvicorn.error")

handler = None
if ollama_model.startswith("Arch"):
    handler = ArchHandler()
else:
    handler = BoltHandler()

-logger.info(f"using model: {ollama_model}")
-logger.info(f"using ollama endpoint: {ollama_endpoint}")

# app = FastAPI()

-client = OpenAI(
-    base_url='http://{}:11434/v1/'.format(ollama_endpoint),
-    # required but ignored
-    api_key='ollama',
-)
+if mode == "cloud":
+    client = OpenAI(
+        base_url=fc_url,
+        api_key="EMPTY",
+    )
+    models = client.models.list()
+    model = models.data[0].id
+    chosen_model = model
+    endpoint = fc_url
+else:
+    client = OpenAI(
+        base_url="http://{}:11434/v1/".format(ollama_endpoint),
+        api_key="ollama",
+    )
+    chosen_model = ollama_model
+    endpoint = ollama_endpoint
+logger.info(f"serving mode: {mode}")
+logger.info(f"using model: {chosen_model}")
+logger.info(f"using endpoint: {endpoint}")

async def chat_completion(req: ChatMessage, res: Response):
@@ -38,23 +60,28 @@ async def chat_completion(req: ChatMessage, res: Response):
    messages = [{"role": "system", "content": tools_encoded}]
    for message in req.messages:
        messages.append({"role": message.role, "content": message.content})
-    logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
-    resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
+    logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
+    completions_params = params["params"]
+    resp = client.chat.completions.create(
+        messages=messages,
+        model=chosen_model,
+        stream=False,
+        extra_body=completions_params,
+    )
    tools = handler.extract_tools(resp.choices[0].message.content)
    tool_calls = []
    for tool in tools:
        for tool_name, tool_args in tool.items():
-            tool_calls.append({
-                "id": f"call_{random.randint(1000, 10000)}",
-                "type": "function",
-                "function": {
-                    "name": tool_name,
-                    "arguments": tool_args
-                }
-            })
+            tool_calls.append(
+                {
+                    "id": f"call_{random.randint(1000, 10000)}",
+                    "type": "function",
+                    "function": {"name": tool_name, "arguments": tool_args},
+                }
+            )
    if tools:
        resp.choices[0].message.tool_calls = tool_calls
        resp.choices[0].message.content = None
    logger.info(f"response (tools): {json.dumps(tools)}")
    logger.info(f"response: {json.dumps(resp.to_dict())}")
    return resp
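Condensed, the new module-level logic picks the upstream endpoint from MODE: "cloud" talks to the hosted Arch-FC endpoint at FC_URL, while the local modes talk to an ollama server. A hedged sketch of that selection (here FC_URL is treated as required rather than defaulted, a simplification of the code above):

```python
import os
from openai import OpenAI


def make_client() -> tuple:
    """Sketch of the cloud vs. local-* client selection performed at import time above."""
    mode = os.getenv("MODE", "cloud")
    if mode not in ("cloud", "local-gpu", "local-cpu"):
        raise ValueError(f"Invalid mode: {mode}")

    if mode == "cloud":
        # Hosted Arch-FC endpoint; FC_URL must point at an OpenAI-compatible server.
        client = OpenAI(base_url=os.environ["FC_URL"], api_key="EMPTY")
        model = client.models.list().data[0].id  # use whatever single model is served
    else:
        # Local ollama endpoint on its default port.
        host = os.getenv("OLLAMA_ENDPOINT", "localhost")
        client = OpenAI(base_url=f"http://{host}:11434/v1/", api_key="ollama")
        model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
    return client, model
```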
model_server/app/arch_fc/bolt_handler.py (3 additions, 1 deletion)
@@ -89,7 +89,9 @@ def extract_tools(self, content, executable=False):
        if isinstance(tool_call, dict):
            try:
                if not executable:
-                    extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
+                    extracted_tools.append(
+                        {tool_call["name"]: tool_call["arguments"]}
+                    )
                else:
                    name, arguments = (
                        tool_call.get("name", ""),
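For orientation, extract_tools turns the model's text output into a list of {name: arguments} dicts, the shape arch_fc.py iterates over above. A rough sketch under the assumption that the model emits a JSON array of {"name": ..., "arguments": {...}} objects; the real handlers parse a model-specific prompt format, and the "executable" rendering below is an assumed behavior:

```python
import json


def extract_tools_sketch(content: str, executable: bool = False) -> list:
    """Illustrative stand-in for the handler's extract_tools()."""
    extracted_tools = []
    for tool_call in json.loads(content):
        if not isinstance(tool_call, dict):
            continue
        if not executable:
            # Same shape the real handler appends: {tool_name: tool_arguments}
            extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
        else:
            # Assumed rendering: a call string such as get_weather(city='Seattle')
            name, arguments = tool_call.get("name", ""), tool_call.get("arguments", {})
            args = ", ".join(f"{k}={v!r}" for k, v in arguments.items())
            extracted_tools.append(f"{name}({args})")
    return extracted_tools
```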
model_server/app/arch_fc/common.py (2 additions, 0 deletions)
@@ -1,10 +1,12 @@
from typing import Any, Dict, List
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class ChatMessage(BaseModel):
    messages: list[Message]
    tools: List[Dict[str, Any]]
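ChatMessage is the request body the chat_completion handler in arch_fc.py parses. A small usage sketch; the tool schema shown is an assumed OpenAI-style definition, since tools is typed as a free-form List[Dict[str, Any]]:

```python
from app.arch_fc.common import ChatMessage, Message

req = ChatMessage(
    messages=[Message(role="user", content="how is the weather in seattle?")],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "weather_forecast",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
print(req.model_dump())  # pydantic v2; use .dict() on pydantic v1
```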
model_server/app/install.py (0 additions, 2 deletions)
@@ -9,7 +9,5 @@
load_transformers()
print("installing ner models")
load_ner_models()
print("installing toxic models")
load_toxic_model()
print("installing jailbreak models")
load_jailbreak_model()
model_server/app/main.py (21 additions, 7 deletions)
@@ -6,6 +6,7 @@
load_guard_model,
load_zero_shot_models,
)
+import os
from app.utils import GuardHandler, split_text_into_chunks
import torch
import yaml
@@ -25,14 +26,27 @@

with open("guard_model_config.yaml") as f:
    guard_model_config = yaml.safe_load(f)
+with open('/root/arch_config.yaml') as f:
+    config = yaml.safe_load(f)
+mode = os.getenv("MODE", "cloud")
+logger.info(f"Serving model mode: {mode}")
+if mode not in ['cloud', 'local-gpu', 'local-cpu']:
+    raise ValueError(f"Invalid mode: {mode}")
+if mode == 'local-cpu':
+    hardware = 'cpu'
+else:
+    hardware = "gpu" if torch.cuda.is_available() else "cpu"

+if "prompt_guards" in config.keys():
+    task = list(config["prompt_guards"]["input_guards"].keys())[0]

+    hardware = "gpu" if torch.cuda.is_available() else "cpu"
+    jailbreak_model = load_guard_model(
+        guard_model_config["jailbreak"][hardware], hardware
+    )
+    toxic_model = None

-task = "both"
-hardware = "gpu" if torch.cuda.is_available() else "cpu"
-jailbreak_model = load_guard_model(
-    guard_model_config["jailbreak"][hardware], hardware
-)

-guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
+guard_handler = GuardHandler(toxic_model=toxic_model, jailbreak_model=jailbreak_model)

app = FastAPI()

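The MODE check above now decides whether the guard models load on CPU or GPU. A compact sketch of that mapping (an illustrative helper, not the module's actual structure):

```python
import os
from typing import Optional

import torch


def select_hardware(mode: Optional[str] = None) -> str:
    """Mirror the mode -> device mapping used in main.py above."""
    mode = mode or os.getenv("MODE", "cloud")
    if mode not in ("cloud", "local-gpu", "local-cpu"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode == "local-cpu":
        return "cpu"  # force CPU even when a GPU is present
    return "gpu" if torch.cuda.is_available() else "cpu"
```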
model_server/openai_params.yaml (8 additions, 0 deletions)
@@ -0,0 +1,8 @@
+params:
+  temperature: 0.0001
+  top_p : 0.5
+  repetition_penalty: 1.0
+  top_k: 50
+  max_tokens: 128
+  stop: ["<|im_start|>", "<|im_end|>"]
+  stop_token_ids: [151645, 151643]
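These sampling parameters are what arch_fc.py loads and forwards via extra_body (see above). A short sketch of that round trip; the endpoint and model name are placeholders, and whether the backend honors non-standard fields such as repetition_penalty or stop_token_ids depends on the serving stack:

```python
import yaml
from openai import OpenAI

with open("openai_params.yaml") as f:
    params = yaml.safe_load(f)["params"]

client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")
resp = client.chat.completions.create(
    model="Arch-Function-Calling-1.5B-Q4_K_M",
    messages=[{"role": "user", "content": "ping"}],
    stream=False,
    # Fields the OpenAI SDK does not model are passed through verbatim
    # in the request body via extra_body.
    extra_body=params,
)
print(resp.choices[0].message.content)
```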