-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split arch wasm filter code into prompt and llm gateway filters (#190)
- Loading branch information
1 parent
8e54ac2
commit 21e7fe2
Showing
13 changed files
with
684 additions
and
2,789 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,13 @@ | ||
use filter_context::FilterContext; | ||
use llm_filter_context::LlmGatewayFilterContext; | ||
use proxy_wasm::traits::*; | ||
use proxy_wasm::types::*; | ||
|
||
mod filter_context; | ||
mod stream_context; | ||
mod llm_filter_context; | ||
mod llm_stream_context; | ||
|
||
// Body of `proxy_wasm::main!`: sets the log level to Trace and registers the
// closure that builds the boxed root context.
proxy_wasm::main! {{ | ||
proxy_wasm::set_log_level(LogLevel::Trace); | ||
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> { | ||
// NOTE(review): the next two lines look like the removed and added sides of a
// diff hunk (FilterContext replaced by LlmGatewayFilterContext) — confirm that
// only the second one is meant to remain.
Box::new(FilterContext::new()) | ||
Box::new(LlmGatewayFilterContext::new()) | ||
}); | ||
}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use crate::llm_stream_context::LlmGatewayStreamContext; | ||
use common::configuration::Configuration; | ||
use common::http::Client; | ||
use common::llm_providers::LlmProviders; | ||
use common::ratelimit; | ||
use common::stats::Counter; | ||
use common::stats::Gauge; | ||
use log::debug; | ||
use proxy_wasm::traits::*; | ||
use proxy_wasm::types::*; | ||
use std::cell::RefCell; | ||
use std::collections::HashMap; | ||
use std::rc::Rc; | ||
|
||
/// Metric handles used by the gateway filter.
///
/// The filter context holds these behind an `Rc` and passes a clone of that
/// `Rc` to each stream context it creates.
#[derive(Copy, Clone, Debug)] | ||
pub struct WasmMetrics { | ||
// Gauge returned by the filter context's `Client::active_http_calls`.
pub active_http_calls: Gauge, | ||
// Presumably counts requests rejected by rate limiting — TODO confirm against its users.
pub ratelimited_rq: Counter, | ||
} | ||
|
||
impl WasmMetrics { | ||
fn new() -> WasmMetrics { | ||
WasmMetrics { | ||
active_http_calls: Gauge::new(String::from("active_http_calls")), | ||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")), | ||
} | ||
} | ||
} | ||
|
||
/// Per-call value stored in the filter context's callout map; it currently
/// carries no fields.
#[derive(Debug)] | ||
pub struct FilterCallContext {} | ||
|
||
/// Root-level state of the LLM gateway filter.
#[derive(Debug)] | ||
pub struct LlmGatewayFilterContext { | ||
// Shared with every stream context created by `create_http_context`.
metrics: Rc<WasmMetrics>, | ||
// callouts stores the token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>, | ||
// `None` until `on_configure` has parsed the plugin configuration.
llm_providers: Option<Rc<LlmProviders>>, | ||
} | ||
|
||
impl LlmGatewayFilterContext { | ||
pub fn new() -> LlmGatewayFilterContext { | ||
LlmGatewayFilterContext { | ||
callouts: RefCell::new(HashMap::new()), | ||
metrics: Rc::new(WasmMetrics::new()), | ||
llm_providers: None, | ||
} | ||
} | ||
} | ||
|
||
impl Client for LlmGatewayFilterContext { | ||
type CallContext = FilterCallContext; | ||
|
||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> { | ||
&self.callouts | ||
} | ||
|
||
fn active_http_calls(&self) -> &Gauge { | ||
&self.metrics.active_http_calls | ||
} | ||
} | ||
|
||
// No `Context` callbacks are overridden here; the trait's provided defaults apply.
impl Context for LlmGatewayFilterContext {} | ||
|
||
// RootContext allows the Rust code to reach into the Envoy Config | ||
impl RootContext for LlmGatewayFilterContext { | ||
fn on_configure(&mut self, _: usize) -> bool { | ||
let config_bytes = self | ||
.get_plugin_configuration() | ||
.expect("Arch config cannot be empty"); | ||
|
||
let config: Configuration = match serde_yaml::from_slice(&config_bytes) { | ||
Ok(config) => config, | ||
Err(err) => panic!("Invalid arch config \"{:?}\"", err), | ||
}; | ||
|
||
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); | ||
|
||
match config.llm_providers.try_into() { | ||
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), | ||
Err(err) => panic!("{err}"), | ||
} | ||
|
||
true | ||
} | ||
|
||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> { | ||
debug!( | ||
"||| create_http_context called with context_id: {:?} |||", | ||
context_id | ||
); | ||
|
||
// No StreamContext can be created until the Embedding Store is fully initialized. | ||
Some(Box::new(LlmGatewayStreamContext::new( | ||
context_id, | ||
Rc::clone(&self.metrics), | ||
Rc::clone( | ||
self.llm_providers | ||
.as_ref() | ||
.expect("LLM Providers must exist when Streams are being created"), | ||
), | ||
))) | ||
} | ||
|
||
fn get_type(&self) -> Option<ContextType> { | ||
Some(ContextType::HttpContext) | ||
} | ||
} |
Oops, something went wrong.