From af018e5fd8f4dbe1f4a37c360c4da6985d90b6e4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Ulises=20Ni=C3=B1o=20Rivera?=
Date: Thu, 3 Oct 2024 12:21:35 -0700
Subject: [PATCH] Remove optional PromptGuards from Stream Context (#113)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
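
Arch now falls back to PromptGuards::default() when the configuration omits
the prompt_guards section, so FilterContext and StreamContext hold a plain
Rc<PromptGuards> instead of Rc<Option<PromptGuards>>. Call sites drop the
Option plumbing; the jailbreak handler in stream_context.rs is
representative (both versions taken verbatim from the hunks below):

    // before: every access had to unwrap the Option
    let msg = (*self.prompt_guards)
        .as_ref()
        .and_then(|pg| pg.jailbreak_on_exception_message())
        .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");

    // after: guards are always present, defaulting when unconfigured
    let msg = self
        .prompt_guards
        .jailbreak_on_exception_message()
        .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");

Along the way, the CallContext fields up_stream_cluster and
up_stream_cluster_path are renamed to upstream_cluster and
upstream_cluster_path, and the integration tests now drive the prompt-guard
callout explicitly, pinning the exact headers of every expected http call.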

Signed-off-by: José Ulises Niño Rivera
---
 arch/src/filter_context.rs |   7 +-
 arch/src/stream_context.rs |  61 ++++++----------
 arch/tests/integration.rs  | 143 ++++++++++++++++++++++++++++++++-----
 3 files changed, 150 insertions(+), 61 deletions(-)

diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 853f5cdc..b4a21a7c 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -47,8 +47,7 @@ pub struct FilterContext {
     callouts: HashMap<u32, FilterCallContext>,
     overrides: Rc<Option<Overrides>>,
     prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-    // This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc not an Option<Rc>.
-    prompt_guards: Rc<Option<PromptGuards>>,
+    prompt_guards: Rc<PromptGuards>,
     llm_providers: Option<Rc<LlmProviders>>,
 }
 
@@ -67,7 +66,7 @@
             metrics: Rc::new(WasmMetrics::new()),
             prompt_targets: Rc::new(RwLock::new(HashMap::new())),
             overrides: Rc::new(None),
-            prompt_guards: Rc::new(Some(PromptGuards::default())),
+            prompt_guards: Rc::new(PromptGuards::default()),
             llm_providers: None,
         }
     }
@@ -242,7 +241,7 @@
         ratelimit::ratelimits(config.ratelimits);
 
         if let Some(prompt_guards) = config.prompt_guards {
-            self.prompt_guards = Rc::new(Some(prompt_guards))
+            self.prompt_guards = Rc::new(prompt_guards)
         }
 
         match config.llm_providers.try_into() {
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index e339a765..d67498b9 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -48,8 +48,8 @@ pub struct CallContext {
     prompt_target_name: Option<String>,
     request_body: ChatCompletionsRequest,
     similarity_scores: Option<Vec<(String, f64)>>,
-    up_stream_cluster: Option<String>,
-    up_stream_cluster_path: Option<String>,
+    upstream_cluster: Option<String>,
+    upstream_cluster_path: Option<String>,
 }
 
 pub struct StreamContext {
@@ -62,9 +62,9 @@
     streaming_response: bool,
    response_tokens: usize,
     chat_completions_request: bool,
+    prompt_guards: Rc<PromptGuards>,
     llm_providers: Rc<LlmProviders>,
     llm_provider: Option<Rc<LlmProvider>>,
-    prompt_guards: Rc<Option<PromptGuards>>,
 }
 
 impl StreamContext {
@@ -72,7 +72,7 @@
         context_id: u32,
         metrics: Rc<WasmMetrics>,
         prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-        prompt_guards: Rc<Option<PromptGuards>>,
+        prompt_guards: Rc<PromptGuards>,
         overrides: Rc<Option<Overrides>>,
         llm_providers: Rc<LlmProviders>,
     ) -> Self {
@@ -615,8 +615,8 @@ impl StreamContext {
             }
         };
 
-        callout_context.up_stream_cluster = Some(endpoint.name);
-        callout_context.up_stream_cluster_path = Some(path);
+        callout_context.upstream_cluster = Some(endpoint.name);
+        callout_context.upstream_cluster_path = Some(path);
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
         if self.callouts.insert(token_id, callout_context).is_some() {
             panic!("duplicate token_id")
@@ -630,8 +630,8 @@ impl StreamContext {
         if http_status.1 != StatusCode::OK.as_str() {
             let error_msg = format!(
                 "Error in function call response: cluster: {}, path: {}, status code: {}",
-                callout_context.up_stream_cluster.unwrap(),
-                callout_context.up_stream_cluster_path.unwrap(),
+                callout_context.upstream_cluster.unwrap(),
+                callout_context.upstream_cluster_path.unwrap(),
                 http_status.1
             );
             return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
@@ -741,9 +741,9 @@ impl StreamContext {
 
         if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
-            let msg = (*self.prompt_guards)
-                .as_ref()
-                .and_then(|pg| pg.jailbreak_on_exception_message())
+            let msg = self
+                .prompt_guards
+                .jailbreak_on_exception_message()
                 .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
             return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
         }
@@ -801,8 +801,8 @@ impl StreamContext {
             prompt_target_name: None,
             request_body: callout_context.request_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
             panic!(
@@ -810,6 +810,8 @@ impl StreamContext {
                 token_id
             )
         }
+
+        self.metrics.active_http_calls.increment(1);
     }
 
     fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
@@ -971,39 +973,20 @@ impl HttpContext for StreamContext {
             }
         };
 
-        let prompt_guards = match self.prompt_guards.as_ref() {
-            Some(prompt_guards) => {
-                debug!("prompt guards: {:?}", prompt_guards);
-                prompt_guards
-            }
-            None => {
-                let callout_context = CallContext {
-                    response_handler_type: ResponseHandlerType::ArchGuard,
-                    user_message: Some(user_message),
-                    prompt_target_name: None,
-                    request_body: deserialized_body,
-                    similarity_scores: None,
-                    up_stream_cluster: None,
-                    up_stream_cluster_path: None,
-                };
-                self.get_embeddings(callout_context);
-                return Action::Pause;
-            }
-        };
-
-        let prompt_guard_jailbreak_task = prompt_guards
+        let prompt_guard_jailbreak_task = self
+            .prompt_guards
             .input_guards
             .contains_key(&public_types::configuration::GuardType::Jailbreak);
         if !prompt_guard_jailbreak_task {
-            info!("Input guards set but no prompt guards were found");
+            debug!("Missing input guard. Making inline call to retrieve embeddings");
             let callout_context = CallContext {
                 response_handler_type: ResponseHandlerType::ArchGuard,
                 user_message: Some(user_message),
                 prompt_target_name: None,
                 request_body: deserialized_body,
                 similarity_scores: None,
-                up_stream_cluster: None,
-                up_stream_cluster_path: None,
+                upstream_cluster: None,
+                upstream_cluster_path: None,
             };
             self.get_embeddings(callout_context);
             return Action::Pause;
@@ -1056,8 +1039,8 @@ impl HttpContext for StreamContext {
             prompt_target_name: None,
             request_body: deserialized_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
             panic!(
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 1c7f6166..ca467734 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -5,6 +5,7 @@ use proxy_wasm_test_framework::types::{
 };
 use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
 use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
+use public_types::common_types::PromptGuardResponse;
 use public_types::embeddings::embedding::Object;
 use public_types::embeddings::{
     create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
@@ -91,14 +92,66 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(chat_completions_request_body))
         // The actual call is not important in this test, we just need to grab the token_id
-        .expect_http_call(Some("model_server"), None, None, None, None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/guard"),
"/guard"), + (":authority", "model_server"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]), + None, + None, + None, + ) .returning(Some(1)) .expect_log(Some(LogLevel::Debug), None) .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Info), None) .execute_and_expect(ReturnType::Action(Action::Pause)) .unwrap(); + let prompt_guard_response = PromptGuardResponse { + toxic_prob: None, + toxic_verdict: None, + jailbreak_prob: None, + jailbreak_verdict: None, + }; + let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap(); + module + .call_proxy_on_http_call_response( + http_context, + 1, + 0, + prompt_guard_response_buffer.len() as i32, + 0, + ) + .expect_metric_increment("active_http_calls", -1) + .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) + .returning(Some(&prompt_guard_response_buffer)) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_http_call( + Some("model_server"), + Some(vec![ + (":method", "POST"), + (":path", "/embeddings"), + (":authority", "model_server"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]), + None, + None, + None, + ) + .returning(Some(2)) + .expect_metric_increment("active_http_calls", 1) + .expect_log(Some(LogLevel::Debug), None) + .execute_and_expect(ReturnType::None) + .unwrap(); + let embedding_response = CreateEmbeddingResponse { data: vec![Embedding { index: 0, @@ -113,7 +166,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { module .call_proxy_on_http_call_response( http_context, - 1, + 2, 0, embeddings_response_buffer.len() as i32, 0, @@ -123,8 +176,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .returning(Some(&embeddings_response_buffer)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_http_call(Some("model_server"), None, None, None, None) - .returning(Some(2)) + .expect_http_call( + Some("model_server"), + Some(vec![ + (":method", "POST"), + (":path", "/zeroshot"), + (":authority", "model_server"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "60000"), + ]), + None, + None, + None, + ) + .returning(Some(3)) .expect_metric_increment("active_http_calls", 1) .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) @@ -140,7 +206,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { module .call_proxy_on_http_call_response( http_context, - 2, + 3, 0, zeroshot_intent_detection_buffer.len() as i32, 0, @@ -151,8 +217,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), None) - .expect_http_call(Some("arch_fc"), None, None, None, None) - .returning(Some(3)) + .expect_http_call( + Some("arch_fc"), + Some(vec![ + (":method", "POST"), + (":path", "/v1/chat/completions"), + (":authority", "arch_fc"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", "120000"), + ]), + None, + None, + None, + ) + .returning(Some(4)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_metric_increment("active_http_calls", 1) @@ -189,8 +268,13 @@ 
 overrides:
   system_prompt: |
     You are a helpful assistant.
 
-prompt_targets:
+prompt_guards:
+  input_guards:
+    jailbreak:
+      on_exception:
+        message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
+prompt_targets:
 - name: weather_forecast
   description: This function provides realtime weather forecast information for a given city.
   parameters:
@@ -308,7 +392,6 @@ fn successful_request_to_open_ai_chat_completions() {
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(chat_completions_request_body))
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Info), None)
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(4))
         .expect_metric_increment("active_http_calls", 1)
@@ -459,7 +542,7 @@ fn request_ratelimited() {
     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
 
     module
-        .call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&arch_fc_resp_str))
@@ -470,15 +553,27 @@ fn request_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("api_server"), None, None, None, None)
-        .returning(Some(4))
+        .expect_http_call(
+            Some("api_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/weather"),
+                (":authority", "api_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(5))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
     let body_text = String::from("test body");
     module
-        .call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))
@@ -573,7 +668,7 @@ fn request_not_ratelimited() {
     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
 
     module
-        .call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&arch_fc_resp_str))
@@ -584,15 +679,27 @@ fn request_not_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("api_server"), None, None, None, None)
-        .returning(Some(4))
+        .expect_http_call(
+            Some("api_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/weather"),
+                (":authority", "api_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(5))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
     let body_text = String::from("test body");
     module
-        .call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))