Skip to content

Commit

Permalink
Merge branch 'shuguang/main' of https://github.com/katanemo/arch into…
Browse files Browse the repository at this point in the history
… shuguang/main
  • Loading branch information
cotran2 committed Dec 11, 2024
2 parents 188b0a8 + 60dfb18 commit 2405fb3
Show file tree
Hide file tree
Showing 17 changed files with 312 additions and 1,244 deletions.
5 changes: 2 additions & 3 deletions arch/envoy.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ static_resources:
domains:
- "*"
routes:

{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
- match:
prefix: "/"
headers:
Expand Down Expand Up @@ -449,7 +448,7 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.mistral.ai
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination", "model_server"] %}
- name: {{ internal_clustrer }}
connect_timeout: 5s
type: STRICT_DNS
Expand Down
1 change: 1 addition & 0 deletions arch/tools/cli/config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def validate_and_render_schema():

rendered = template.render(data)
print(ENVOY_CONFIG_FILE_RENDERED)
print(rendered)
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered)

Expand Down
10 changes: 5 additions & 5 deletions crates/common/src/api/open_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest {
pub metadata: Option<HashMap<String, String>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
Expand Down Expand Up @@ -165,8 +165,8 @@ pub struct Message {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub finish_reason: Option<String>,
pub index: Option<usize>,
pub message: Message,
}

Expand Down Expand Up @@ -217,8 +217,8 @@ impl ChatCompletionsResponse {
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
index: Some(0),
finish_reason: Some("done".to_string()),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
Expand Down
113 changes: 112 additions & 1 deletion crates/common/src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;

use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
Expand Down Expand Up @@ -231,11 +235,46 @@ pub struct PromptTarget {
pub auto_llm_dispatch_on_response: Option<bool>,
}

// convert PromptTarget to ChatCompletionTool
impl From<&PromptTarget> for ChatCompletionTool {
fn from(val: &PromptTarget) -> Self {
let properties: HashMap<String, FunctionParameter> = match val.parameters {
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};

ChatCompletionTool {
tool_type: crate::api::open_ai::ToolType::Function,
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters { properties },
},
}
}
}

#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use std::fs;

use crate::configuration::GuardType;
use crate::{api::open_ai::ToolType, configuration::GuardType};

#[test]
fn test_deserialize_configuration() {
Expand Down Expand Up @@ -307,4 +346,76 @@ mod test {
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}

#[test]
fn test_tool_conversion() {
let ref_config = fs::read_to_string(
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
)
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
let prompt_targets = &config.prompt_targets;
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
assert_eq!(
chat_completion_tool.function.description,
"Reboot a specific network device"
);
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.contains_key("device_id"),
true
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::String
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.description,
"Identifier of the network device to reboot.".to_string()
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.required,
Some(true)
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("confirmation")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::Bool
);
}
}
2 changes: 1 addition & 1 deletion crates/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;
pub mod path;
pub mod pii;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;
13 changes: 7 additions & 6 deletions crates/common/src/path.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::collections::HashMap;

pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
pub fn replace_params_in_path(
path: &str,
params: &HashMap<String, String>,
) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
Expand All @@ -17,12 +20,10 @@ pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> R
return Err(format!("Missing value for parameter `{}`", param_name));
}
current_param.clear();
} else if in_param {
current_param.push(c);
} else {
if in_param {
current_param.push(c);
} else {
result.push(c);
}
result.push(c);
}
}

Expand Down
59 changes: 0 additions & 59 deletions crates/prompt_gateway/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,70 +19,11 @@ impl Context for StreamContext {
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);

/*
state transition
graph LR
on_http_request_body --> prompt received
prompt received --> get embeddings & arch guard
arch guard --> get embeddings
get embeddings --> zeroshot intent
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
│ │ │ │ │ │ │ │
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
│ │ │ │ │ │ │ │
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
│ ▲
│ │
│ │
│ ┌────────┴───────┐
│ │ │
└───────────►│ arch guard │
│ │
└────────────────┘
continue from zeroshot intent
graph LR
zeroshot intent --> arch_fc
zeroshot intent --> default prompt target
arch_fc --> developer api call & hallucination check
hallucination check --> parameter gathering & developer api call
developer api call --> resume request to llm
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
│ │ │ │ │ │ │ │
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
│ │ │ │ │ │ │ │
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
│ │ ▲
│ └─────────────┐ │
│ │ │
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
│ │ │ │ │ │ │ │
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
│ │ │ │ │ │
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
using https://mermaid-ascii.art/
*/

if let Some(body) = self.get_http_call_response_body(0, body_size) {
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
}
} else {
self.send_server_error(
Expand Down
5 changes: 0 additions & 5 deletions crates/prompt_gateway/src/embeddings.rs

This file was deleted.

Loading

0 comments on commit 2405fb3

Please sign in to comment.