Remove optional PromptGuards from Stream Context (#113)
Signed-off-by: José Ulises Niño Rivera <[email protected]>
junr03 authored Oct 3, 2024
1 parent 8ea917a commit af018e5
Showing 3 changed files with 150 additions and 61 deletions.
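In short, the commit drops the Option wrapper around the shared prompt-guard configuration: FilterContext now always holds an Rc<PromptGuards> seeded with PromptGuards::default(), StreamContext::new takes the same non-optional Rc, and the up_stream_cluster / up_stream_cluster_path fields on CallContext are renamed to upstream_cluster / upstream_cluster_path. A minimal sketch of the ownership change (PromptGuards here is a hypothetical stand-in for the real type):

use std::rc::Rc;

// Hypothetical stand-in for the real PromptGuards configuration type.
#[derive(Debug, Default)]
struct PromptGuards {
    jailbreak_message: Option<String>,
}

fn main() {
    // Before: the shared value itself was optional, so every reader unwrapped it first.
    let before: Rc<Option<PromptGuards>> = Rc::new(Some(PromptGuards::default()));
    if let Some(guards) = (*before).as_ref() {
        println!("before: {:?}", guards);
    }

    // After: a default value always exists, so readers use it directly.
    let after: Rc<PromptGuards> = Rc::new(PromptGuards::default());
    println!("after: {:?}", after);
}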
7 changes: 3 additions & 4 deletions arch/src/filter_context.rs
@@ -47,8 +47,7 @@ pub struct FilterContext {
     callouts: HashMap<u32, CallContext>,
     overrides: Rc<Option<Overrides>>,
     prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-    // This should be Option<Rc<PromptGuards>>, because StreamContext::new() should get an Rc<PromptGuards> not Option<Rc<PromptGuards>>.
-    prompt_guards: Rc<Option<PromptGuards>>,
+    prompt_guards: Rc<PromptGuards>,
     llm_providers: Option<Rc<LlmProviders>>,
 }
 

@@ -67,7 +66,7 @@ impl FilterContext {
             metrics: Rc::new(WasmMetrics::new()),
             prompt_targets: Rc::new(RwLock::new(HashMap::new())),
             overrides: Rc::new(None),
-            prompt_guards: Rc::new(Some(PromptGuards::default())),
+            prompt_guards: Rc::new(PromptGuards::default()),
             llm_providers: None,
         }
     }
@@ -242,7 +241,7 @@ impl RootContext for FilterContext {
         ratelimit::ratelimits(config.ratelimits);
 
         if let Some(prompt_guards) = config.prompt_guards {
-            self.prompt_guards = Rc::new(Some(prompt_guards))
+            self.prompt_guards = Rc::new(prompt_guards)
         }
 
         match config.llm_providers.try_into() {
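Note that the configuration handler still only overwrites the prompt guards when the config carries a prompt_guards section; otherwise the PromptGuards::default() set in the constructor stays in place. A sketch of the equivalent fallback, assuming a config type whose prompt_guards field is an Option<PromptGuards>:

use std::rc::Rc;

// Assumed shapes; the real Configuration and PromptGuards types live in the repo.
#[derive(Debug, Default)]
struct PromptGuards;

struct Configuration {
    prompt_guards: Option<PromptGuards>,
}

fn main() {
    let config = Configuration { prompt_guards: None };

    // Same net effect as the diff: a missing section falls back to the default
    // instead of leaving an Option around for every later reader.
    let prompt_guards: Rc<PromptGuards> = Rc::new(config.prompt_guards.unwrap_or_default());
    println!("{:?}", prompt_guards);
}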
61 changes: 22 additions & 39 deletions arch/src/stream_context.rs
@@ -48,8 +48,8 @@ pub struct CallContext {
     prompt_target_name: Option<String>,
     request_body: ChatCompletionsRequest,
     similarity_scores: Option<Vec<(String, f64)>>,
-    up_stream_cluster: Option<String>,
-    up_stream_cluster_path: Option<String>,
+    upstream_cluster: Option<String>,
+    upstream_cluster_path: Option<String>,
 }
 
 pub struct StreamContext {
@@ -62,17 +62,17 @@ pub struct StreamContext {
     streaming_response: bool,
     response_tokens: usize,
     chat_completions_request: bool,
+    prompt_guards: Rc<PromptGuards>,
     llm_providers: Rc<LlmProviders>,
     llm_provider: Option<Rc<LlmProvider>>,
-    prompt_guards: Rc<Option<PromptGuards>>,
 }
 
 impl StreamContext {
     pub fn new(
         context_id: u32,
         metrics: Rc<WasmMetrics>,
         prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
-        prompt_guards: Rc<Option<PromptGuards>>,
+        prompt_guards: Rc<PromptGuards>,
         overrides: Rc<Option<Overrides>>,
         llm_providers: Rc<LlmProviders>,
     ) -> Self {
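With the field no longer optional, the root context can hand every new stream a clone of the same Rc without unwrapping. A minimal sketch of that handoff (structs trimmed to the one field; the real constructor also takes metrics, prompt targets, overrides, and llm providers):

use std::rc::Rc;

#[derive(Debug, Default)]
struct PromptGuards; // hypothetical stand-in for the real type

struct FilterContext {
    prompt_guards: Rc<PromptGuards>,
}

struct StreamContext {
    prompt_guards: Rc<PromptGuards>,
}

impl FilterContext {
    // Illustrative only: not how the repo actually constructs its StreamContext.
    fn new_stream(&self) -> StreamContext {
        StreamContext {
            // Rc::clone copies the pointer and bumps the reference count; the
            // PromptGuards value is shared across streams, not duplicated.
            prompt_guards: Rc::clone(&self.prompt_guards),
        }
    }
}

fn main() {
    let root = FilterContext { prompt_guards: Rc::new(PromptGuards::default()) };
    let stream = root.new_stream();
    println!("{:?}", stream.prompt_guards);
}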
@@ -615,8 +615,8 @@ impl StreamContext {
             }
         };
 
-        callout_context.up_stream_cluster = Some(endpoint.name);
-        callout_context.up_stream_cluster_path = Some(path);
+        callout_context.upstream_cluster = Some(endpoint.name);
+        callout_context.upstream_cluster_path = Some(path);
         callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
         if self.callouts.insert(token_id, callout_context).is_some() {
             panic!("duplicate token_id")
@@ -630,8 +630,8 @@
         if http_status.1 != StatusCode::OK.as_str() {
             let error_msg = format!(
                 "Error in function call response: cluster: {}, path: {}, status code: {}",
-                callout_context.up_stream_cluster.unwrap(),
-                callout_context.up_stream_cluster_path.unwrap(),
+                callout_context.upstream_cluster.unwrap(),
+                callout_context.upstream_cluster_path.unwrap(),
                 http_status.1
             );
             return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
@@ -741,9 +741,9 @@ impl StreamContext {
 
         if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
             //TODO: handle other scenarios like forward to error target
-            let msg = (*self.prompt_guards)
-                .as_ref()
-                .and_then(|pg| pg.jailbreak_on_exception_message())
+            let msg = self
+                .prompt_guards
+                .jailbreak_on_exception_message()
                 .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
             return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
         }
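Dropping the Option removes the (*self.prompt_guards).as_ref().and_then(..) chain: the method is now called on the shared value directly, and the remaining unwrap_or only covers a missing custom message. A toy sketch of the new lookup (the body of jailbreak_on_exception_message is assumed, not copied from the repo):

use std::rc::Rc;

// Toy PromptGuards: just enough to show the lookup.
#[derive(Default)]
struct PromptGuards {
    jailbreak_message: Option<String>,
}

impl PromptGuards {
    // Assumed to mirror the method the diff calls; the real signature may differ.
    fn jailbreak_on_exception_message(&self) -> Option<&str> {
        self.jailbreak_message.as_deref()
    }
}

fn main() {
    let prompt_guards: Rc<PromptGuards> = Rc::new(PromptGuards::default());

    // One Option left to handle (a missing custom message) instead of two.
    let msg = prompt_guards
        .jailbreak_on_exception_message()
        .unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
    println!("{msg}");
}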
@@ -801,15 +801,17 @@ impl StreamContext {
             prompt_target_name: None,
             request_body: callout_context.request_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
             panic!(
                 "duplicate token_id={} in embedding server requests",
                 token_id
             )
         }
+
+        self.metrics.active_http_calls.increment(1);
     }
 
     fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
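This hunk also adds self.metrics.active_http_calls.increment(1) once the embeddings callout has been registered, presumably so the in-flight counter covers this callout path as well. A hedged sketch of that counter pattern (the Gauge type here is hypothetical, not the repo's WasmMetrics):

use std::cell::Cell;

// Hypothetical gauge-style metric: bumped when a callout is dispatched and,
// presumably, decremented again when its response handler runs.
struct Gauge {
    value: Cell<i64>,
}

impl Gauge {
    fn increment(&self, delta: i64) {
        self.value.set(self.value.get() + delta);
    }
}

fn main() {
    let active_http_calls = Gauge { value: Cell::new(0) };
    active_http_calls.increment(1); // callout dispatched
    active_http_calls.increment(-1); // response handled
    println!("in flight: {}", active_http_calls.value.get());
}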
@@ -971,39 +973,20 @@ impl HttpContext for StreamContext {
             }
         };
 
-        let prompt_guards = match self.prompt_guards.as_ref() {
-            Some(prompt_guards) => {
-                debug!("prompt guards: {:?}", prompt_guards);
-                prompt_guards
-            }
-            None => {
-                let callout_context = CallContext {
-                    response_handler_type: ResponseHandlerType::ArchGuard,
-                    user_message: Some(user_message),
-                    prompt_target_name: None,
-                    request_body: deserialized_body,
-                    similarity_scores: None,
-                    up_stream_cluster: None,
-                    up_stream_cluster_path: None,
-                };
-                self.get_embeddings(callout_context);
-                return Action::Pause;
-            }
-        };
-
-        let prompt_guard_jailbreak_task = prompt_guards
+        let prompt_guard_jailbreak_task = self
+            .prompt_guards
             .input_guards
             .contains_key(&public_types::configuration::GuardType::Jailbreak);
         if !prompt_guard_jailbreak_task {
-            info!("Input guards set but no prompt guards were found");
+            debug!("Missing input guard. Making inline call to retrieve");
             let callout_context = CallContext {
                 response_handler_type: ResponseHandlerType::ArchGuard,
                 user_message: Some(user_message),
                 prompt_target_name: None,
                 request_body: deserialized_body,
                 similarity_scores: None,
-                up_stream_cluster: None,
-                up_stream_cluster_path: None,
+                upstream_cluster: None,
+                upstream_cluster_path: None,
             };
             self.get_embeddings(callout_context);
             return Action::Pause;
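The removed None branch is no longer needed: assuming PromptGuards::default() carries an empty input_guards map, a deployment with no prompt-guard config now lands in the !prompt_guard_jailbreak_task branch and still goes straight to the embeddings callout, matching what the old None arm did. A small sketch of the membership check that drives the branch (type shapes are assumptions based on the identifiers in the diff):

use std::collections::HashMap;

// Assumed shapes, based only on the names visible in the diff.
#[derive(PartialEq, Eq, Hash)]
enum GuardType {
    Jailbreak,
}

#[derive(Default)]
struct PromptGuards {
    input_guards: HashMap<GuardType, String>, // value type is a placeholder
}

fn main() {
    let guards = PromptGuards::default();

    // With the default (empty) guards this is false, so the request skips the
    // ArchGuard jailbreak check and goes straight to the embeddings callout.
    let jailbreak_enabled = guards.input_guards.contains_key(&GuardType::Jailbreak);
    println!("jailbreak guard configured: {jailbreak_enabled}");
}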
@@ -1056,8 +1039,8 @@
             prompt_target_name: None,
             request_body: deserialized_body,
             similarity_scores: None,
-            up_stream_cluster: None,
-            up_stream_cluster_path: None,
+            upstream_cluster: None,
+            upstream_cluster_path: None,
         };
         if self.callouts.insert(token_id, call_context).is_some() {
             panic!(