Hallucination integration with rust (#122)
cotran2 authored Oct 8, 2024
1 parent 43dc2a0 commit b1fa127
Showing 9 changed files with 278 additions and 56 deletions.
21 changes: 1 addition & 20 deletions arch/docker-compose.dev.yaml
@@ -5,28 +5,9 @@ services:
- "10000:10000"
- "19901:9901"
volumes:
- ${ARCH_CONFIG_FILE:-./demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
- ${ARCH_CONFIG_FILE:-../demos/function_calling/arch_config.yaml}:/config/arch_config.yaml
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
- ./envoy.template.dev.yaml:/config/envoy.template.yaml
- ./target/wasm32-wasi/release/intelligent_prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
depends_on:
model_server:
condition: service_healthy
env_file:
- stage.env

model_server:
image: model_server:latest
ports:
- "18081:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost/healthz"]
interval: 5s
retries: 20
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
environment:
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
- MODE=${MODE:-cloud}
- FC_URL=${FC_URL:-https://arch-fc-free-trial-4mzywewe.uc.gateway.dev/v1}
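
Note: the ${VAR:-default} entries above are shell-style defaults that Docker Compose resolves at startup, not YAML syntax. For readers mapping this onto the Rust side, a minimal equivalent fallback would look like the following (hypothetical helper, not part of this commit):

    use std::env;

    fn main() {
        // Mirrors ${OLLAMA_ENDPOINT:-host.docker.internal} from the compose file.
        let ollama_endpoint = env::var("OLLAMA_ENDPOINT")
            .unwrap_or_else(|_| "host.docker.internal".to_string());
        println!("resolving Ollama at {ollama_endpoint}");
    }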
11 changes: 3 additions & 8 deletions arch/envoy.template.dev.yaml
@@ -13,7 +13,7 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: arch_ingress_http
codec_type: HTTP1
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
access_log:
@@ -72,11 +72,6 @@ static_resources:
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
# typed_extension_protocol_options:
# envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
# "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
# explicit_http_config:
# http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
@@ -129,7 +124,7 @@ static_resources:
address:
socket_address:
address: host.docker.internal
port_value: 8000
port_value: 51000
hostname: "model_server"
- name: mistral_7b_instruct
connect_timeout: 5s
@@ -159,7 +154,7 @@ static_resources:
address:
socket_address:
address: host.docker.internal
port_value: 8000
port_value: 51000
hostname: "arch_fc"
{% for _, cluster in arch_clusters.items() %}
- name: {{ cluster.name }}
2 changes: 2 additions & 0 deletions arch/src/consts.rs
@@ -1,6 +1,7 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
@@ -13,3 +14,4 @@ pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
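
Note: the new DEFAULT_HALLUCINATED_THRESHOLD drives the per-parameter check added to stream_context.rs below: any extracted parameter whose score falls below 0.1 is treated as likely hallucinated. A self-contained sketch of that rule, with hypothetical scores:

    use std::collections::HashMap;

    const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.1;

    fn main() {
        // Hypothetical per-parameter scores, as the /hallucination endpoint might return.
        let params_scores = HashMap::from([("city", 0.03), ("days", 0.92)]);
        let low_score_keys: Vec<&str> = params_scores
            .iter()
            .filter(|(_, score)| **score < DEFAULT_HALLUCINATED_THRESHOLD)
            .map(|(key, _)| *key)
            .collect();
        assert_eq!(low_score_keys, ["city"]); // only "city" trips the threshold
    }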
182 changes: 169 additions & 13 deletions arch/src/stream_context.rs
@@ -1,7 +1,8 @@
use crate::consts::{
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER,
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH,
DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@@ -17,12 +18,13 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use public_types::configuration::LlmProvider;
use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
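
Note: HallucinationClassificationRequest and HallucinationClassificationResponse live in public_types, one of the nine changed files whose diff is not expanded on this page. Judging from how the handlers below construct and consume them, their shape is plausibly the following (an inference, not the committed definition):

    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    #[derive(Debug, Serialize)]
    pub struct HallucinationClassificationRequest {
        pub prompt: String,
        pub model: String,
        pub parameters: HashMap<String, String>,
    }

    #[derive(Debug, Deserialize)]
    pub struct HallucinationClassificationResponse {
        // One score per extracted parameter; a low score marks a likely hallucinated value.
        pub params_scores: HashMap<String, f64>,
    }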
@@ -37,22 +39,24 @@ use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;

#[derive(Debug)]
#[derive(Debug, Clone)]
enum ResponseHandlerType {
GetEmbeddings,
FunctionResolver,
FunctionCall,
ZeroShotIntent,
HallucinationDetect,
ArchGuard,
DefaultTarget,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StreamCallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target_name: Option<String>,
request_body: ChatCompletionsRequest,
tool_calls: Option<Vec<ToolCall>>,
similarity_scores: Option<Vec<(String, f64)>>,
upstream_cluster: Option<String>,
upstream_cluster_path: Option<String>,
@@ -310,6 +314,69 @@ impl StreamContext {
}
}

fn hallucination_classification_resp_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
) {
let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_slice(&body) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let mut keys_with_low_score: Vec<String> = Vec::new();
for (key, value) in &hallucination_response.params_scores {
if *value < DEFAULT_HALLUCINATED_THRESHOLD {
debug!(
"hallucination detected: score for {} : {} is less than threshold {}",
key, value, DEFAULT_HALLUCINATED_THRESHOLD
);
keys_with_low_score.push(key.clone().to_string());
}
}

if !keys_with_low_score.is_empty() {
let response =
"It seems I’m missing some information. Could you provide the following details: "
.to_string()
+ &keys_with_low_score.join(", ")
+ " ?";
let message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
};

let chat_completion_response = ChatCompletionsResponse {
choices: vec![Choice {
message,
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
};

debug!("hallucination response: {:?}", chat_completion_response);
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(
serde_json::to_string(&chat_completion_response)
.unwrap()
.as_bytes(),
),
);
} else {
// not a hallucination, resume the flow
self.schedule_api_call_request(callout_context);
}
}
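
Note: when any key scores below the threshold, the handler above short-circuits: instead of calling the prompt target's API, it answers the client directly with a clarifying question. Assuming the obvious serde field names on the response structs, the payload a client sees would resemble this (illustrative values only):

    // Illustrative only: the early-exit body for a low-scoring "city" parameter.
    let clarification = serde_json::json!({
        "model": "Arch-Function-1.5B",
        "choices": [{
            "index": 0,
            "finish_reason": "done",
            "message": {
                "role": "system",
                "model": "Arch-Function-1.5B",
                "content": "It seems I’m missing some information. Could you provide the following details: city ?"
            }
        }]
    });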

fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
@@ -565,6 +632,9 @@ impl StreamContext {

let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();

// TODO CO: pass nli check
// If hallucination, pass chat template to check parameters

// extract all tool names
let tool_names: Vec<String> = tool_calls
.iter()
@@ -581,17 +651,93 @@
String::from(ARCH_MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);

let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();

let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
callout_context.tool_calls = Some(tool_calls.clone());

debug!(
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name, tool_names
);
debug!("tool_params: {}", tool_params_json_str);

if model_resp.message.tool_calls.is_some()
&& !model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
use serde_json::Value;
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
let tool_params_dict: HashMap<String, String> = match v.as_object() {
Some(obj) => obj
.iter()
.filter_map(|(key, value)| {
value
.as_str()
.map(|str_value| (key.clone(), str_value.to_string()))
})
.collect(),
None => HashMap::new(), // Return an empty HashMap if v is not an object
};

let hallucination_classification_request = HallucinationClassificationRequest {
prompt: callout_context.user_message.as_ref().unwrap().clone(),
model: String::from(DEFAULT_INTENT_MODEL),
parameters: tool_params_dict,
};

let json_data: String =
match serde_json::to_string(&hallucination_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let call_args = CallArgs::new(
MODEL_SERVER_NAME,
"/hallucination",
vec![
(":method", "POST"),
(":path", "/hallucination"),
(":authority", MODEL_SERVER_NAME),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect;

if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
} else {
self.schedule_api_call_request(callout_context);
}
}

fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0]
.function
.name
.clone();

let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();

//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0]
.function
.arguments
.clone();
tool_params.insert(
String::from(ARCH_MESSAGES_KEY),
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
);

let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();

let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
let call_args = CallArgs::new(
Expand All @@ -612,8 +758,6 @@ impl StreamContext {
callout_context.upstream_cluster_path = Some(path.clone());
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;

self.tool_calls = Some(tool_calls.clone());

if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
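
Note: to make the callout above concrete, suppose the user asked "how is the weather in Seattle for the next 10 days?" and the function resolver extracted string params city and days. The body POSTed to /hallucination on the model server would then be built roughly like this (hypothetical values, using the request shape sketched earlier):

    use std::collections::HashMap;

    let request = HallucinationClassificationRequest {
        prompt: "how is the weather in Seattle for the next 10 days?".to_string(),
        model: DEFAULT_INTENT_MODEL.to_string(), // "tasksource/deberta-base-long-nli"
        parameters: HashMap::from([
            ("city".to_string(), "Seattle".to_string()),
            ("days".to_string(), "10".to_string()),
        ]),
    };
    // serde_json::to_string(&request) is then dispatched with a 5s timeout and
    // up to 3 Envoy retries, per the CallArgs above.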
@@ -806,6 +950,7 @@ impl StreamContext {
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};

if let Err(e) = self.http_call(call_args, call_context) {
@@ -1009,6 +1154,7 @@ impl HttpContext for StreamContext {
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
@@ -1057,6 +1203,7 @@
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
tool_calls: None,
};

if let Err(e) = self.http_call(call_args, call_context) {
@@ -1144,7 +1291,13 @@
}
};

self.response_tokens += chat_completions_response.usage.completion_tokens;
if chat_completions_response.usage.is_some() {
self.response_tokens += chat_completions_response
.usage
.as_ref()
.unwrap()
.completion_tokens;
}
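
Note: since usage is now Option-valued, the is_some()/unwrap() pair above can be collapsed into a single if let with identical behavior (equivalent sketch, not the committed code):

    if let Some(usage) = chat_completions_response.usage.as_ref() {
        self.response_tokens += usage.completion_tokens;
    }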

if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
@@ -1239,6 +1392,9 @@ impl Context for StreamContext {
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
ResponseHandlerType::HallucinationDetect => {
self.hallucination_classification_resp_handler(body, callout_context)
}
ResponseHandlerType::FunctionResolver => {
self.function_resolver_handler(body, callout_context)
}