ArchFC endpoint integration #94

Merged
merged 14 commits into from
Oct 1, 2024
Changes from 7 commits
10 changes: 3 additions & 7 deletions arch/src/stream_context.rs
@@ -479,7 +479,9 @@ impl StreamContext {

let model_resp = &arch_fc_response.choices[0];

if model_resp.message.tool_calls.is_none() {
if model_resp.message.tool_calls.is_none()
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
// This means that Arch FC did not have enough information to resolve the function call
// Arch FC probably responded with a message asking for more information.
// Let's send the response back to the user to initialize a lightweight dialog for parameter collection
@@ -494,12 +496,6 @@ impl StreamContext {
}

let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
if tool_calls.is_empty() {
return self.send_server_error(
"No tool calls found in function resolver response".to_string(),
Some(StatusCode::BAD_REQUEST),
);
}

debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names
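
The new condition folds the previous two-step handling into one guard: a missing tool_calls field and an empty tool_calls list are now treated identically, and both fall through to the parameter-collection path. A minimal sketch of the equivalent check in Python over the parsed Arch-FC JSON (the helper name is illustrative, not from the repo):

# Minimal sketch (not the project's Rust code): the same guard over a parsed
# Arch-FC response dict. Absent, null, or empty tool_calls all mean the model
# is asking for more information instead of resolving a function call.
def needs_parameter_collection(arch_fc_response: dict) -> bool:
    message = arch_fc_response["choices"][0]["message"]
    return not message.get("tool_calls")  # absent, None, or [] -> True

# Example: a clarifying response carries no tool calls.
resp = {"choices": [{"message": {"content": "Which city do you mean?", "tool_calls": []}}]}
assert needs_parameter_collection(resp)
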
9 changes: 8 additions & 1 deletion demos/function_calling/docker-compose.yaml
@@ -1,3 +1,7 @@
x-variables: &common-vars
environment:
- MODE=${MODE:-cloud} # Set the default mode to 'cloud'; other values are local-gpu, local-cpu

services:

config_generator:
@@ -34,6 +38,7 @@ services:
dockerfile: Dockerfile
ports:
- "18081:80"
<<: *common-vars
healthcheck:
test: ["CMD", "curl" ,"http://localhost/healthz"]
interval: 5s
@@ -48,6 +53,7 @@
dockerfile: Dockerfile
ports:
- "18082:80"
<<: *common-vars
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
@@ -57,11 +63,12 @@
environment:
# use ollama endpoint that is hosted by host machine (no virtualization)
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
- FC_URL=${FC_URL:}
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
- MODE=${MODE:-cloud}
# uncomment following line to use ollama endpoint that is hosted by docker
# - OLLAMA_ENDPOINT=ollama
# - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M

api_server:
build:
context: api_server
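
The x-variables block above defines the shared MODE environment entry once under the &common-vars anchor, and each service pulls it in with the <<: merge key. A minimal sketch of how that merge resolves (the service name is illustrative; PyYAML's safe_load handles merge keys, and docker compose resolves anchors the same way before variable substitution):

# Minimal sketch of YAML anchor + merge-key resolution, mirroring the compose change above.
import yaml

snippet = """
x-variables: &common-vars
  environment:
    - MODE=${MODE:-cloud}

services:
  some_service:        # illustrative name, not from the repo
    ports:
      - "18081:80"
    <<: *common-vars
"""

config = yaml.safe_load(snippet)
print(config["services"]["some_service"]["environment"])
# -> ['MODE=${MODE:-cloud}']  (compose substitutes the default 'cloud' at runtime)
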
4 changes: 3 additions & 1 deletion function_resolver/app/bolt_handler.py
@@ -89,7 +89,9 @@ def extract_tools(self, content, executable=False):
if isinstance(tool_call, dict):
try:
if not executable:
extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
extracted_tools.append(
{tool_call["name"]: tool_call["arguments"]}
)
else:
name, arguments = (
tool_call.get("name", ""),
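
For reference, the non-executable branch reformatted above returns a list of single-key dicts, each mapping a tool name to its argument dict; main.py later flattens that list into OpenAI-style tool_calls. A small sketch of the shape (the tool name and values are illustrative only):

# Illustrative shape of extract_tools(content, executable=False) output;
# only the structure matches the handler, the values are made up.
extracted_tools = [
    {"get_weather": {"city": "Seattle", "unit": "celsius"}},
]

# main.py flattens each entry into a (name, arguments) pair:
for tool in extracted_tools:
    for tool_name, tool_args in tool.items():
        print(tool_name, tool_args)
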
2 changes: 2 additions & 0 deletions function_resolver/app/common.py
@@ -1,10 +1,12 @@
from typing import Any, Dict, List
from pydantic import BaseModel


class Message(BaseModel):
role: str
content: str


class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
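
These two pydantic models define the resolver's request body. A small construction sketch (the tool definition is illustrative, not from the repo):

# Minimal sketch: building a request that matches the ChatMessage schema above.
from common import ChatMessage, Message  # assumes app/ is on the import path

payload = ChatMessage(
    messages=[Message(role="user", content="What's the weather in Seattle?")],
    tools=[{  # illustrative tool definition
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
)
print(payload.model_dump())  # pydantic v2; on pydantic v1 use payload.dict()
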
75 changes: 50 additions & 25 deletions function_resolver/app/main.py
@@ -5,36 +5,56 @@
from bolt_handler import BoltHandler
from common import ChatMessage
import logging
import yaml
from openai import OpenAI
import os


with open("openai_params.yaml") as f:
params = yaml.safe_load(f)

ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
logger = logging.getLogger('uvicorn.error')
fc_url = os.getenv("FC_URL", ollama_endpoint)
mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid mode: {mode}")
arch_api_key = os.getenv("ARCH_API_KEY", "")
logger = logging.getLogger("uvicorn.error")

handler = None
if ollama_model.startswith("Arch"):
handler = ArchHandler()
handler = ArchHandler()
else:
handler = BoltHandler()

logger.info(f"using model: {ollama_model}")
logger.info(f"using ollama endpoint: {ollama_endpoint}")

app = FastAPI()

client = OpenAI(
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
if mode == "cloud":
client = OpenAI(
Contributor: if it's cloud, ensure that api_key is set
Contributor (Author): for now, api_key is not required
base_url=fc_url,
api_key=arch_api_key,
)
models = client.models.list()
model = models.data[0].id
chosen_model = model
endpoint = fc_url
else:
client = OpenAI(
base_url="http://{}:11434/v1/".format(ollama_endpoint),
api_key="ollama",
)
chosen_model = ollama_model
endpoint = ollama_endpoint
logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}")

# required but ignored
api_key='ollama',
)

@app.get("/healthz")
async def healthz():
return {
"status": "ok"
}
return {"status": "ok"}


@app.post("/v1/chat/completions")
@@ -45,23 +65,28 @@ async def chat_completion(req: ChatMessage, res: Response):
messages = [{"role": "system", "content": tools_encoded}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
completions_params = params["params"]
resp = client.chat.completions.create(
messages=messages,
model=chosen_model,
stream=False,
extra_body=completions_params,
)
tools = handler.extract_tools(resp.choices[0].message.content)
tool_calls = []
for tool in tools:
for tool_name, tool_args in tool.items():
tool_calls.append({
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args
}
})
for tool_name, tool_args in tool.items():
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {"name": tool_name, "arguments": tool_args},
}
)
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
logger.info(f"response (tools): {json.dumps(tools)}")
logger.info(f"response: {json.dumps(resp.to_dict())}")
return resp
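
A usage sketch for the endpoint defined above; the host, port, tool definition, and prompt are assumptions for illustration (the demo compose file maps the container's port 80 to a host port such as 18081):

# Minimal sketch: calling the resolver's /v1/chat/completions endpoint.
import requests

body = {
    "messages": [{"role": "user", "content": "What's the weather in Seattle in celsius?"}],
    "tools": [
        {
            "name": "get_weather",  # hypothetical tool
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ],
}

resp = requests.post("http://localhost:18081/v1/chat/completions", json=body, timeout=60)
message = resp.json()["choices"][0]["message"]
# When a tool is resolved, content is set to None and tool_calls carries the call(s).
print(message.get("tool_calls") or message.get("content"))
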
8 changes: 8 additions & 0 deletions function_resolver/app/openai_params.yaml
Contributor: it should be per model, e.g. openai_params_arch_guard_1.5b or something similar
Contributor (Author): the model name is the file path in the VM docker instance
@@ -0,0 +1,8 @@
params:
temperature: 0.0001
top_p : 0.5
repetition_penalty: 1.0
top_k: 50
max_tokens: 128
stop: ["<|im_start|>", "<|im_end|>"]
stop_token_ids: [151645, 151643]
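
Several of the knobs above (repetition_penalty, top_k, stop_token_ids) are backend-specific sampling parameters that the standard OpenAI chat-completions schema does not define, which is why main.py loads this file once and forwards the nested params mapping through extra_body. A minimal sketch of that flow (endpoint and model name are illustrative):

# Minimal sketch: forwarding openai_params.yaml via extra_body, as main.py does.
import yaml
from openai import OpenAI

with open("openai_params.yaml") as f:
    completions_params = yaml.safe_load(f)["params"]

client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")  # illustrative endpoint
resp = client.chat.completions.create(
    model="Arch-Function-Calling-1.5B-Q4_K_M",
    messages=[{"role": "user", "content": "ping"}],
    stream=False,
    extra_body=completions_params,  # non-standard fields travel in the request body
)
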
2 changes: 0 additions & 2 deletions model_server/app/install.py
@@ -9,7 +9,5 @@
load_transformers()
print("installing ner models")
load_ner_models()
print("installing toxic models")
load_toxic_model()
print("installing jailbreak models")
load_jailbreak_model()
39 changes: 15 additions & 24 deletions model_server/app/main.py
@@ -6,6 +6,7 @@
load_guard_model,
load_zero_shot_models,
)
import os
from utils import GuardHandler, split_text_into_chunks
import torch
import yaml
@@ -26,33 +27,23 @@
config = yaml.safe_load(file)
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
mode = os.getenv("MODE", "cloud")
logger.info(f"Serving model mode: {mode}")
if mode not in ['cloud', 'local-gpu', 'local-cpu']:
raise ValueError(f"Invalid mode: {mode}")
if mode == 'local-cpu':
hardware = 'cpu'
else:
hardware = "gpu" if torch.cuda.is_available() else "cpu"

if "prompt_guards" in config.keys():
if len(config["prompt_guards"]["input_guards"]) == 2:
task = "both"
jailbreak_hardware = "gpu" if torch.cuda.is_available() else "cpu"
toxic_hardware = "gpu" if torch.cuda.is_available() else "cpu"
toxic_model = load_guard_model(
guard_model_config["toxic"][jailbreak_hardware], toxic_hardware
)
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][toxic_hardware], jailbreak_hardware
)
task = list(config["prompt_guards"]["input_guards"].keys())[0]

else:
task = list(config["prompt_guards"]["input_guards"].keys())[0]

hardware = "gpu" if torch.cuda.is_available() else "cpu"
if task == "toxic":
toxic_model = load_guard_model(
guard_model_config["toxic"][hardware], hardware
)
jailbreak_model = None
elif task == "jailbreak":
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
toxic_model = None
hardware = "gpu" if torch.cuda.is_available() else "cpu"
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
toxic_model = None


guard_handler = GuardHandler(toxic_model, jailbreak_model)
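
The model_server change reduces to a single hardware rule: MODE=local-cpu pins the guard models to CPU, while cloud and local-gpu use the GPU whenever torch reports one. A minimal sketch of that rule as a standalone helper (the function name is illustrative):

# Minimal sketch of the MODE -> hardware rule from model_server/app/main.py.
import os
from typing import Optional

import torch

def resolve_hardware(mode: Optional[str] = None) -> str:
    mode = mode or os.getenv("MODE", "cloud")
    if mode not in ("cloud", "local-gpu", "local-cpu"):
        raise ValueError(f"Invalid mode: {mode}")
    if mode == "local-cpu":
        return "cpu"
    return "gpu" if torch.cuda.is_available() else "cpu"

print(resolve_hardware("local-cpu"))  # -> cpu
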