Skip to content

Commit

Permalink
tests
Browse files — browse the repository at this point in the history
Signed-off-by: José Ulises Niño Rivera <[email protected]>
  • Loading branch information
junr03 committed Oct 3, 2024
1 parent a015fad commit 0f641d5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 31 deletions.
28 changes: 15 additions & 13 deletions arch/src/stream_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ pub struct CallContext {
prompt_target_name: Option<String>,
request_body: ChatCompletionsRequest,
similarity_scores: Option<Vec<(String, f64)>>,
up_stream_cluster: Option<String>,
up_stream_cluster_path: Option<String>,
upstream_cluster: Option<String>,
upstream_cluster_path: Option<String>,
}

pub struct StreamContext {
Expand Down Expand Up @@ -615,8 +615,8 @@ impl StreamContext {
}
};

callout_context.up_stream_cluster = Some(endpoint.name);
callout_context.up_stream_cluster_path = Some(path);
callout_context.upstream_cluster = Some(endpoint.name);
callout_context.upstream_cluster_path = Some(path);
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
Expand All @@ -630,8 +630,8 @@ impl StreamContext {
if http_status.1 != StatusCode::OK.as_str() {
let error_msg = format!(
"Error in function call response: cluster: {}, path: {}, status code: {}",
callout_context.up_stream_cluster.unwrap(),
callout_context.up_stream_cluster_path.unwrap(),
callout_context.upstream_cluster.unwrap(),
callout_context.upstream_cluster_path.unwrap(),
http_status.1
);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
Expand Down Expand Up @@ -801,15 +801,17 @@ impl StreamContext {
prompt_target_name: None,
request_body: callout_context.request_body,
similarity_scores: None,
up_stream_cluster: None,
up_stream_cluster_path: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}

self.metrics.active_http_calls.increment(1);
}

fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
Expand Down Expand Up @@ -976,15 +978,15 @@ impl HttpContext for StreamContext {
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
if !prompt_guard_jailbreak_task {
info!("Input guards set but no prompt guards were found");
debug!("Missing input guard. Making inline call to retrieve");
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
up_stream_cluster: None,
up_stream_cluster_path: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
Expand Down Expand Up @@ -1037,8 +1039,8 @@ impl HttpContext for StreamContext {
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
up_stream_cluster: None,
up_stream_cluster_path: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
Expand Down
143 changes: 125 additions & 18 deletions arch/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use proxy_wasm_test_framework::types::{
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::common_types::PromptGuardResponse;
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
Expand Down Expand Up @@ -91,14 +92,66 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_http_call(Some("model_server"), None, None, None, None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/guard"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Info), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();

let prompt_guard_response = PromptGuardResponse {
toxic_prob: None,
toxic_verdict: None,
jailbreak_prob: None,
jailbreak_verdict: None,
};
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
1,
0,
prompt_guard_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();

let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
index: 0,
Expand All @@ -113,7 +166,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_http_call_response(
http_context,
1,
2,
0,
embeddings_response_buffer.len() as i32,
0,
Expand All @@ -123,8 +176,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(2))
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/zeroshot"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
Expand All @@ -140,7 +206,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
module
.call_proxy_on_http_call_response(
http_context,
2,
3,
0,
zeroshot_intent_detection_buffer.len() as i32,
0,
Expand All @@ -151,8 +217,21 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("arch_fc"), None, None, None, None)
.returning(Some(3))
.expect_http_call(
Some("arch_fc"),
Some(vec![
(":method", "POST"),
(":path", "/v1/chat/completions"),
(":authority", "arch_fc"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "120000"),
]),
None,
None,
None,
)
.returning(Some(4))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
Expand Down Expand Up @@ -189,8 +268,13 @@ overrides:
system_prompt: |
You are a helpful assistant.
prompt_targets:
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."
prompt_targets:
- name: weather_forecast
description: This function provides realtime weather forecast information for a given city.
parameters:
Expand Down Expand Up @@ -308,7 +392,6 @@ fn successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
Expand Down Expand Up @@ -459,7 +542,7 @@ fn request_ratelimited() {

let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
Expand All @@ -470,15 +553,27 @@ fn request_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("api_server"), None, None, None, None)
.returning(Some(4))
.expect_http_call(
Some("api_server"),
Some(vec![
(":method", "POST"),
(":path", "/weather"),
(":authority", "api_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();

let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
Expand Down Expand Up @@ -573,7 +668,7 @@ fn request_not_ratelimited() {

let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 3, 0, arch_fc_resp_str.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
Expand All @@ -584,15 +679,27 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("api_server"), None, None, None, None)
.returning(Some(4))
.expect_http_call(
Some("api_server"),
Some(vec![
(":method", "POST"),
(":path", "/weather"),
(":authority", "api_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();

let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
Expand Down

0 comments on commit 0f641d5

Please sign in to comment.