Skip to content

Commit

Permalink
Add the ability to use LLM Providers from the Arch config (#112)
Browse files Browse the repository at this point in the history
Signed-off-by: José Ulises Niño Rivera <[email protected]>
  • Loading branch information
junr03 authored Oct 3, 2024
1 parent 1b57a49 commit 8ea917a
Show file tree
Hide file tree
Showing 16 changed files with 293 additions and 208 deletions.
11 changes: 6 additions & 5 deletions arch/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion arch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ governor = { version = "0.6.3", default-features = false, features = ["no_std"]}
tiktoken-rs = "0.5.9"
acap = "0.3.0"
rand = "0.8.5"
thiserror = "1.0.64"

[dev-dependencies]
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "main" }
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
serial_test = "3.1.1"
7 changes: 5 additions & 2 deletions arch/arch_config_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ properties:
properties:
name:
type: string
provider:
type: string
access_key:
type: string
model:
Expand All @@ -47,6 +49,7 @@ properties:
additionalProperties: false
required:
- name
- provider
- access_key
- model
overrides:
Expand Down Expand Up @@ -112,7 +115,7 @@ properties:
items:
type: object
properties:
provider:
model:
type: string
selector:
type: object
Expand All @@ -138,7 +141,7 @@ properties:
- unit
additionalProperties: false
required:
- provider
- model
- selector
- limit
additionalProperties: false
Expand Down
30 changes: 26 additions & 4 deletions arch/config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,32 @@
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', False)
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY', False)

def add_secret_key_to_llm_providers(config_yaml) :
llm_providers = []
for llm_provider in config_yaml.get("llm_providers", []):
if llm_provider['access_key'] == "$MISTRAL_ACCESS_KEY":
llm_provider['access_key'] = MISTRAL_API_KEY
elif llm_provider['access_key'] == "$OPENAI_ACCESS_KEY":
llm_provider['access_key'] = OPENAI_API_KEY
else:
llm_provider.pop('access_key')
llm_providers.append(llm_provider)
config_yaml["llm_providers"] = llm_providers
return config_yaml

env = Environment(loader=FileSystemLoader('./'))
template = env.get_template('envoy.template.yaml')

with open(ARCH_CONFIG_FILE, 'r') as file:
katanemo_config = file.read()
arch_config_string = file.read()

with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
arch_config_schema = file.read()

config_yaml = yaml.safe_load(katanemo_config)
config_yaml = yaml.safe_load(arch_config_string)
config_schema_yaml = yaml.safe_load(arch_config_schema)

try:
Expand Down Expand Up @@ -54,9 +69,16 @@

print("updated clusters", inferred_clusters)

config_yaml = add_secret_key_to_llm_providers(config_yaml)
arch_llm_providers = config_yaml["llm_providers"]
arch_config_string = yaml.dump(config_yaml)

print("llm_providers:", arch_llm_providers)

data = {
'katanemo_config': katanemo_config,
'arch_clusters': inferred_clusters
'arch_config': arch_config_string,
'arch_clusters': inferred_clusters,
'arch_llm_providers': arch_llm_providers
}

rendered = template.render(data)
Expand Down
23 changes: 6 additions & 17 deletions arch/envoy.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,18 @@ static_resources:
auto_host_rewrite: true
cluster: mistral_7b_instruct
timeout: 60s
{% for provider in arch_llm_providers %}
- match:
prefix: "/v1/chat/completions"
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: openai
exact: {{ provider.name }}
route:
auto_host_rewrite: true
cluster: openai
timeout: 60s
- match:
prefix: "/v1/chat/completions"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: mistral
route:
auto_host_rewrite: true
cluster: mistral
cluster: {{ provider.provider }}
timeout: 60s
{% endfor %}
http_filters:
- name: envoy.filters.http.wasm
typed_config:
Expand All @@ -65,7 +57,7 @@ static_resources:
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ katanemo_config | indent(30) }}
{{ arch_config | indent(30) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
Expand All @@ -75,9 +67,6 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
# LLM Host
# Embedding Providers
# External LLM Providers
- name: openai
connect_timeout: 5s
dns_lookup_family: V4_ONLY
Expand Down
1 change: 1 addition & 0 deletions arch/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
64 changes: 32 additions & 32 deletions arch/src/filter_context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
use crate::llm_providers::LlmProviders;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
Expand Down Expand Up @@ -44,10 +45,11 @@ pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
// This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
prompt_guards: Rc<Option<PromptGuards>>,
llm_providers: Option<Rc<LlmProviders>>,
}

pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
Expand All @@ -62,11 +64,11 @@ impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
prompt_guards: Rc::new(Some(PromptGuards::default())),
llm_providers: None,
}
}

Expand Down Expand Up @@ -219,42 +221,35 @@ impl Context for FilterContext {
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
let config_bytes = self
.get_plugin_configuration()
.expect("Arch config cannot be empty");

if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}
let config: Configuration = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(err) => panic!("Invalid arch config \"{:?}\"", err),
};

for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
self.overrides = Rc::new(config.overrides);

debug!("set configuration object");
for pt in config.prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}

if let Some(ratelimits_config) = self
.config
.as_mut()
.and_then(|config| config.ratelimits.as_mut())
{
ratelimit::ratelimits(Some(std::mem::take(ratelimits_config)));
}
ratelimit::ratelimits(config.ratelimits);

if let Some(prompt_guards) = self
.config
.as_mut()
.and_then(|config| config.prompt_guards.as_mut())
{
self.prompt_guards = Rc::new(Some(std::mem::take(prompt_guards)));
}
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(Some(prompt_guards))
}

match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}

true
}

Expand All @@ -269,6 +264,11 @@ impl RootContext for FilterContext {
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
Rc::clone(
self.llm_providers
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
)))
}

Expand Down
Loading

0 comments on commit 8ea917a

Please sign in to comment.