diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 20bf1849..89a2c3df 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -128,13 +128,15 @@ impl FilterContext {
         embedding_type: EmbeddingType,
         prompt_target_name: String,
     ) {
-        let prompt_target = self.prompt_targets.get(&prompt_target_name).expect(
-            format!(
-                "Received embeddings response for unknown prompt target name={}",
-                prompt_target_name
-            )
-            .as_str(),
-        );
+        let prompt_target = self
+            .prompt_targets
+            .get(&prompt_target_name)
+            .unwrap_or_else(|| {
+                panic!(
+                    "Received embeddings response for unknown prompt target name={}",
+                    prompt_target_name
+                )
+            });
 
         let body = self
             .get_http_call_response_body(0, body_size)
@@ -153,7 +155,7 @@ impl FilterContext {
         };
 
         let embeddings = embedding_response.data.remove(0).embedding;
-        log::info!(
+        debug!(
             "Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
             prompt_target.name,
             prompt_target.description,
@@ -164,13 +166,13 @@ impl FilterContext {
         match entry {
             Entry::Occupied(_) => {
                 entry.and_modify(|e| {
-                    if e.contains_key(&embedding_type) {
+                    if let Entry::Vacant(e) = e.entry(embedding_type) {
+                        e.insert(embeddings);
+                    } else {
                         panic!(
                             "Duplicate {:?} for prompt target with name=\"{}\"",
                             &embedding_type, prompt_target.name
                         )
-                    } else {
-                        e.insert(embedding_type, embeddings);
                     }
                 });
             }
@@ -246,7 +248,6 @@ impl RootContext for FilterContext {
             prompt_targets.insert(pt.name.clone(), pt.clone());
         }
         self.prompt_targets = Rc::new(prompt_targets);
-        debug!("Setting prompt target config value");
 
         ratelimit::ratelimits(config.ratelimits);
 
@@ -269,9 +270,7 @@ impl RootContext for FilterContext {
         );
 
         // No StreamContext can be created until the Embedding Store is fully initialized.
-        if self.embeddings_store.is_none() {
-            return None;
-        }
+        self.embeddings_store.as_ref()?;
 
         Some(Box::new(StreamContext::new(
             context_id,
@@ -285,8 +284,7 @@ impl RootContext for FilterContext {
                     .expect("LLM Providers must exist when Streams are being created"),
             ),
             Rc::clone(
-                &self
-                    .embeddings_store
+                self.embeddings_store
                     .as_ref()
                     .expect("Embeddings Store must exist when StreamContext is being constructed"),
             ),
diff --git a/arch/src/http.rs b/arch/src/http.rs
index c5ed179a..eeae74d3 100644
--- a/arch/src/http.rs
+++ b/arch/src/http.rs
@@ -1,6 +1,7 @@
 use crate::stats::{Gauge, IncrementingMetric};
+use log::debug;
 use proxy_wasm::traits::Context;
-use std::{cell::RefCell, collections::HashMap, time::Duration};
+use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
 
 #[derive(Debug)]
 pub struct CallArgs<'a> {
@@ -30,9 +31,13 @@ impl<'a> CallArgs<'a> {
 }
 
 pub trait Client: Context {
-    type CallContext;
+    type CallContext: Debug;
 
     fn http_call(&self, call_args: CallArgs, call_context: Self::CallContext) {
+        debug!(
+            "dispatching http call with args={:?} context={:?}",
+            call_args, call_context
+        );
         let id = self
             .dispatch_http_call(
                 call_args.upstream,
@@ -54,7 +59,7 @@ pub trait Client: Context {
 
     fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
         let callouts = self.callouts();
-        if let Some(_) = callouts.borrow_mut().insert(id, call_context) {
+        if callouts.borrow_mut().insert(id, call_context).is_some() {
             panic!("Duplicate http call with id={}", id);
         }
         self.active_http_calls().increment(1);
diff --git a/arch/src/llm_providers.rs b/arch/src/llm_providers.rs
index 75d57817..65cd0d04 100644
--- a/arch/src/llm_providers.rs
+++ b/arch/src/llm_providers.rs
@@ -18,7 +18,7 @@ impl LlmProviders {
     }
 
     pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
-        self.providers.get(name).map(|rc| rc.clone())
+        self.providers.get(name).cloned()
     }
 }
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index ca467734..6a1fc3be 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -6,9 +6,9 @@ use proxy_wasm_test_framework::types::{
 use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
 use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
 use public_types::common_types::PromptGuardResponse;
-use public_types::embeddings::embedding::Object;
 use public_types::embeddings::{
-    create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
+    create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
+    Embedding,
 };
 use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
 use serde_yaml::Value;
@@ -156,7 +156,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         data: vec![Embedding {
             index: 0,
             embedding: vec![],
-            object: Object::default(),
+            object: embedding::Object::default(),
         }],
         model: String::from("test"),
         object: create_embedding_response::Object::default(),
@@ -239,8 +239,130 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .unwrap();
 }
 
-fn default_config() -> Configuration {
-    let config: &str = r#"
+fn setup_filter(module: &mut Tester, config: &str) -> i32 {
+    let filter_context = 1;
+
+    module
+        .call_proxy_on_context_create(filter_context, 0)
+        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
+        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    module
+        .call_proxy_on_configure(filter_context, config.len() as i32)
+        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
+        .returning(Some(&config))
+        .execute_and_expect(ReturnType::Bool(true))
+        .unwrap();
+
+    module
+        .call_proxy_on_tick(filter_context)
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/embeddings"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(101))
+        .expect_metric_increment("active_http_calls", 1)
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/embeddings"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(102))
+        .expect_metric_increment("active_http_calls", 1)
+        .expect_set_tick_period_millis(Some(0))
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    let embedding_response = CreateEmbeddingResponse {
+        data: vec![Embedding {
+            embedding: vec![],
+            index: 0,
+            object: embedding::Object::default(),
+        }],
+        model: String::from("test"),
+        object: create_embedding_response::Object::default(),
+        usage: Box::new(CreateEmbeddingResponseUsage {
+            prompt_tokens: 0,
+            total_tokens: 0,
+        }),
+    };
+    let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
+    module
+        .call_proxy_on_http_call_response(
+            filter_context,
+            101,
+            0,
+            embedding_response_str.len() as i32,
+            0,
+        )
+        .expect_log(
+            Some(LogLevel::Debug),
+            Some(
+                format!(
+                    "filter_context: on_http_call_response called with token_id: {:?}",
+                    101
+                )
+                .as_str(),
+            ),
+        )
+        .expect_metric_increment("active_http_calls", -1)
+        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
+        .returning(Some(&embedding_response_str))
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    module
+        .call_proxy_on_http_call_response(
+            filter_context,
+            102,
+            0,
+            embedding_response_str.len() as i32,
+            0,
+        )
+        .expect_log(
+            Some(LogLevel::Debug),
+            Some(
+                format!(
+                    "filter_context: on_http_call_response called with token_id: {:?}",
+                    102
+                )
+                .as_str(),
+            ),
+        )
+        .expect_metric_increment("active_http_calls", -1)
+        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
+        .returning(Some(&embedding_response_str))
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    filter_context
+}
+
+fn default_config() -> &'static str {
+    r#"
 version: "0.1-beta"
 
 listener:
@@ -293,24 +415,6 @@ prompt_targets:
       - Use farenheight for temperature
       - Use miles per hour for wind speed
 
-  - name: insurance_claim_details
-    type: function_resolver
-    description: This function resolver provides insurance claim details for a given policy number.
-    parameters:
-      - name: policy_number
-        required: true
-        description: The policy number for which the insurance claim details are requested.
-        type: string
-      - name: include_expired
-        description: whether to include expired insurance claims in the response.
-        type: bool
-        required: true
-    endpoint:
-      name: api_server
-      path: /insurance_claim_details
-    system_prompt: |
-      You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
-      - Use policy number to retrieve insurance claim details
 ratelimits:
   - model: gpt-4
     selector:
@@ -319,8 +423,7 @@ ratelimits:
     limit:
       tokens: 1
       unit: minute
-"#;
-    serde_yaml::from_str(config).unwrap()
+"#
 }
 
 #[test]
@@ -339,22 +442,7 @@ fn successful_request_to_open_ai_chat_completions() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -415,22 +503,7 @@ fn bad_request_to_open_ai_chat_completions() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -492,21 +565,7 @@ fn request_ratelimited() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -615,24 +674,11 @@ fn request_not_ratelimited() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-
-    let mut config = default_config();
+    let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
     config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
     let config_str = serde_json::to_string(&config).unwrap();
 
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-    module
-        .call_proxy_on_configure(filter_context, config_str.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config_str))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, &config_str);
 
     // Setup HTTP Stream
     let http_context = 2;
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 9b3e3968..5b6bd794 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -7,7 +7,7 @@ pub struct EmbeddingRequest {
     pub prompt_target: PromptTarget,
 }
 
-#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
 pub enum EmbeddingType {
     Name,
     Description,
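
Note (reviewer sketch, not part of the patch): the `Copy` derive added above pairs with the `Entry`-based insertion introduced in `arch/src/filter_context.rs`, because `HashMap::entry` takes its key by value. A minimal, self-contained illustration of that insert-or-panic pattern, using a stand-in value type (`Vec<f64>`) rather than the gateway's actual embeddings store:

```rust
use std::collections::{hash_map::Entry, HashMap};

// Mirrors the enum changed in public_types/src/common_types.rs.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum EmbeddingType {
    Name,
    Description,
}

// Hypothetical helper: store an embedding once per type, panicking on duplicates,
// the same behavior the patch expresses through Entry::Vacant.
fn insert_embedding(
    store: &mut HashMap<EmbeddingType, Vec<f64>>,
    embedding_type: EmbeddingType,
    embeddings: Vec<f64>,
) {
    // `entry` consumes the key; that is cheap because EmbeddingType is Copy,
    // and Copy also keeps `embedding_type` usable in the panic message below.
    if let Entry::Vacant(slot) = store.entry(embedding_type) {
        slot.insert(embeddings);
    } else {
        panic!("Duplicate {:?} embedding", embedding_type);
    }
}

fn main() {
    let mut store = HashMap::new();
    insert_embedding(&mut store, EmbeddingType::Name, vec![0.1, 0.2]);
    assert!(store.contains_key(&EmbeddingType::Name));
}
```

In the patch itself this map is nested inside an outer `entry.and_modify` keyed per prompt target, so the sketch isolates only the inner duplicate check.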