diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index 8853b6a4..d67498b9 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -48,8 +48,8 @@ pub struct CallContext {
     prompt_target_name: Option<String>,
     request_body: ChatCompletionsRequest,
     similarity_scores: Option<Vec<(String, f64)>>,
-    up_stream_cluster: Option<String>,
-    up_stream_cluster_path: Option<String>,
+    upstream_cluster: Option<String>,
+    upstream_cluster_path: Option<String>,
 }
 
 pub struct StreamContext {
@@ -615,8 +615,8 @@ impl StreamContext {
             }
         };
 
-        callout_context.up_stream_cluster = Some(endpoint.name);
-        callout_context.up_stream_cluster_path = Some(path);
+        callout_context.upstream_cluster = Some(endpoint.name);
+        callout_context.upstream_cluster_path = Some(path);
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
         if self.callouts.insert(token_id, callout_context).is_some() {
             panic!("duplicate token_id")
@@ -630,8 +630,8 @@ impl StreamContext {
         if http_status.1 != StatusCode::OK.as_str() {
             let error_msg = format!(
                 "Error in function call response: cluster: {}, path: {}, status code: {}",
-                callout_context.up_stream_cluster.unwrap(),
-                callout_context.up_stream_cluster_path.unwrap(),
+                callout_context.upstream_cluster.unwrap(),
+                callout_context.upstream_cluster_path.unwrap(),
                 http_status.1
             );
             return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
@@ -801,8 +801,8 @@ impl StreamContext {
             prompt_target_name: None,
             request_body: callout_context.request_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
             panic!(
@@ -810,6 +810,8 @@ impl StreamContext {
                 token_id
             )
         }
+
+        self.metrics.active_http_calls.increment(1);
     }
 
     fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
@@ -976,15 +978,15 @@ impl HttpContext for StreamContext {
             .input_guards
             .contains_key(&public_types::configuration::GuardType::Jailbreak);
         if !prompt_guard_jailbreak_task {
-            info!("Input guards set but no prompt guards were found");
+            debug!("Missing input guard. Making inline call to retrieve");
             let callout_context = CallContext {
                 response_handler_type: ResponseHandlerType::ArchGuard,
                 user_message: Some(user_message),
                 prompt_target_name: None,
                 request_body: deserialized_body,
                 similarity_scores: None,
-                up_stream_cluster: None,
-                up_stream_cluster_path: None,
+                upstream_cluster: None,
+                upstream_cluster_path: None,
             };
             self.get_embeddings(callout_context);
             return Action::Pause;
@@ -1037,8 +1039,8 @@ impl HttpContext for StreamContext {
             prompt_target_name: None,
             request_body: deserialized_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
            panic!(
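The net effect of the stream_context.rs hunks is twofold: the `up_stream_cluster*` fields are renamed to the conventional `upstream_cluster*`, and the handler at `@@ -810,6 +810,8 @@` now bumps `active_http_calls` when it registers a callout, so the gauge no longer under-counts relative to the decrement its response handler performs. A minimal, self-contained sketch of that pairing; `Gauge` and `CalloutTable` are illustrative stand-ins, not the crate's actual types:

```rust
use std::collections::HashMap;

// Stand-in for the proxy-wasm gauge behind `metrics.active_http_calls`.
#[derive(Default)]
struct Gauge(i64);

impl Gauge {
    fn increment(&mut self, delta: i64) {
        self.0 += delta;
    }
}

// Stand-in for the filter's callout bookkeeping.
#[derive(Default)]
struct CalloutTable {
    active_http_calls: Gauge,
    callouts: HashMap<u32, String>, // token_id -> pending context (a tag here)
}

impl CalloutTable {
    // Dispatch side: register the context, then bump the gauge (the line the
    // diff adds). A duplicate token_id is treated as fatal, as in the filter.
    fn dispatch(&mut self, token_id: u32, tag: &str) {
        if self.callouts.insert(token_id, tag.to_string()).is_some() {
            panic!("duplicate token_id");
        }
        self.active_http_calls.increment(1);
    }

    // Response side: each resolved callout decrements the gauge, so a missed
    // increment on dispatch would drive the metric negative over time.
    fn on_response(&mut self, token_id: u32) -> Option<String> {
        let ctx = self.callouts.remove(&token_id)?;
        self.active_http_calls.increment(-1);
        Some(ctx)
    }
}

fn main() {
    let mut table = CalloutTable::default();
    table.dispatch(1, "prompt_guard");
    assert_eq!(table.active_http_calls.0, 1);
    table.on_response(1);
    assert_eq!(table.active_http_calls.0, 0);
}
```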
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 1c7f6166..ca467734 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -5,6 +5,7 @@ use proxy_wasm_test_framework::types::{
 };
 use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
 use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
+use public_types::common_types::PromptGuardResponse;
 use public_types::embeddings::embedding::Object;
 use public_types::embeddings::{
     create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
@@ -91,14 +92,66 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(chat_completions_request_body))
         // The actual call is not important in this test, we just need to grab the token_id
-        .expect_http_call(Some("model_server"), None, None, None, None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/guard"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
         .returning(Some(1))
         .expect_log(Some(LogLevel::Debug), None)
         .expect_metric_increment("active_http_calls", 1)
-        .expect_log(Some(LogLevel::Info), None)
         .execute_and_expect(ReturnType::Action(Action::Pause))
         .unwrap();
 
+    let prompt_guard_response = PromptGuardResponse {
+        toxic_prob: None,
+        toxic_verdict: None,
+        jailbreak_prob: None,
+        jailbreak_verdict: None,
+    };
+    let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
+    module
+        .call_proxy_on_http_call_response(
+            http_context,
+            1,
+            0,
+            prompt_guard_response_buffer.len() as i32,
+            0,
+        )
+        .expect_metric_increment("active_http_calls", -1)
+        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
+        .returning(Some(&prompt_guard_response_buffer))
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/embeddings"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(2))
+        .expect_metric_increment("active_http_calls", 1)
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
     let embedding_response = CreateEmbeddingResponse {
         data: vec![Embedding {
             index: 0,
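The new test step feeds the filter an all-`None` `PromptGuardResponse`, i.e. no toxicity or jailbreak verdict, which is what lets the flow proceed to the embeddings callout (token_id 2). A sketch of that serde round trip; the field types (`Option<f64>`/`Option<bool>`) are assumptions here, the authoritative definition lives in `public_types::common_types`:

```rust
use serde::{Deserialize, Serialize};

// Field names follow the test; the concrete types are assumed.
#[derive(Debug, Default, Serialize, Deserialize)]
struct PromptGuardResponse {
    toxic_prob: Option<f64>,
    toxic_verdict: Option<bool>,
    jailbreak_prob: Option<f64>,
    jailbreak_verdict: Option<bool>,
}

fn main() {
    // All-None, as constructed in normal_flow: no guard tripped.
    let resp = PromptGuardResponse::default();
    let body = serde_json::to_string(&resp).unwrap();
    let parsed: PromptGuardResponse = serde_json::from_str(&body).unwrap();
    assert!(parsed.jailbreak_verdict.is_none() && parsed.toxic_verdict.is_none());
}
```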
@@ -113,7 +166,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
     module
         .call_proxy_on_http_call_response(
             http_context,
-            1,
+            2,
             0,
             embeddings_response_buffer.len() as i32,
             0,
         )
@@ -123,8 +176,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&embeddings_response_buffer))
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("model_server"), None, None, None, None)
-        .returning(Some(2))
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/zeroshot"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(3))
         .expect_metric_increment("active_http_calls", 1)
         .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::None)
         .unwrap();
@@ -140,7 +206,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
     module
         .call_proxy_on_http_call_response(
             http_context,
-            2,
+            3,
             0,
             zeroshot_intent_detection_buffer.len() as i32,
             0,
@@ -151,8 +217,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Info), None)
-        .expect_http_call(Some("arch_fc"), None, None, None, None)
-        .returning(Some(3))
+        .expect_http_call(
+            Some("arch_fc"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/v1/chat/completions"),
+                (":authority", "arch_fc"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+                ("x-envoy-upstream-rq-timeout-ms", "120000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(4))
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_metric_increment("active_http_calls", 1)
@@ -189,8 +268,13 @@ overrides:
 system_prompt: |
   You are a helpful assistant.
 
-prompt_targets:
+prompt_guards:
+  input_guards:
+    jailbreak:
+      on_exception:
+        message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
+prompt_targets:
 - name: weather_forecast
   description: This function provides realtime weather forecast information for a given city.
   parameters:
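The `prompt_guards` block added to the test config is what makes the jailbreak check in `stream_context.rs` take the `/guard` path instead of skipping straight to embeddings. A hedged sketch of how such a block could deserialize; the struct names below are illustrative stand-ins, not the actual `public_types::configuration` types:

```rust
use std::collections::HashMap;

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct PromptGuards {
    input_guards: HashMap<String, InputGuard>, // keyed by guard type, e.g. "jailbreak"
}

#[derive(Debug, Deserialize)]
struct InputGuard {
    on_exception: OnException,
}

#[derive(Debug, Deserialize)]
struct OnException {
    message: String, // returned to the caller when the guard trips
}

fn main() {
    let yaml = r#"
input_guards:
  jailbreak:
    on_exception:
      message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
"#;
    let guards: PromptGuards = serde_yaml::from_str(yaml).unwrap();
    // Mirrors the filter's contains_key(&GuardType::Jailbreak) check.
    assert!(guards.input_guards.contains_key("jailbreak"));
}
```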
@@ -308,7 +392,6 @@ fn successful_request_to_open_ai_chat_completions() {
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(chat_completions_request_body))
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Info), None)
         .expect_http_call(Some("model_server"), None, None, None, None)
         .returning(Some(4))
         .expect_metric_increment("active_http_calls", 1)
@@ -459,7 +542,7 @@ fn request_ratelimited() {
     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
 
     module
-        .call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&arch_fc_resp_str))
@@ -470,15 +553,27 @@ fn request_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("api_server"), None, None, None, None)
-        .returning(Some(4))
+        .expect_http_call(
+            Some("api_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/weather"),
+                (":authority", "api_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(5))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
     let body_text = String::from("test body");
     module
-        .call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))
@@ -573,7 +668,7 @@ fn request_not_ratelimited() {
     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
 
     module
-        .call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&arch_fc_resp_str))
@@ -584,15 +679,27 @@ fn request_not_ratelimited() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(Some("api_server"), None, None, None, None)
-        .returning(Some(4))
+        .expect_http_call(
+            Some("api_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/weather"),
+                (":authority", "api_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(5))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::None)
         .unwrap();
 
     let body_text = String::from("test body");
     module
-        .call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
+        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
         .expect_metric_increment("active_http_calls", -1)
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&body_text))
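Because the guard callout is now dispatched first, every subsequent token_id in these tests shifts up by one. Read off the updated expectations, the sequence across `normal_flow` and the rate-limit tests is: 1 model_server `/guard`, 2 model_server `/embeddings`, 3 model_server `/zeroshot`, 4 arch_fc `/v1/chat/completions`, 5 api_server `/weather`. An illustrative mapping; the enum below is not part of the codebase:

```rust
// Hypothetical enumeration of the callout order the updated tests assert.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum Callout {
    PromptGuard,     // token_id 1: POST model_server /guard
    Embeddings,      // token_id 2: POST model_server /embeddings
    ZeroShotIntent,  // token_id 3: POST model_server /zeroshot
    FunctionCalling, // token_id 4: POST arch_fc /v1/chat/completions
    ApiCall,         // token_id 5: POST api_server /weather
}

fn token_id(c: Callout) -> u32 {
    match c {
        Callout::PromptGuard => 1,
        Callout::Embeddings => 2,
        Callout::ZeroShotIntent => 3,
        Callout::FunctionCalling => 4,
        Callout::ApiCall => 5,
    }
}

fn main() {
    // The api_server call moved from token_id 4 to 5 in request_ratelimited.
    assert_eq!(token_id(Callout::ApiCall), 5);
}
```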