Skip to content

Commit

Permalink
add system prompt (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
adilhafeez authored Oct 8, 2024
1 parent c1cfbcd commit 422efd3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
4 changes: 4 additions & 0 deletions arch/src/filter_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
Expand All @@ -60,6 +61,7 @@ impl FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
Expand Down Expand Up @@ -245,6 +247,7 @@ impl RootContext for FilterContext {
for pt in config.prompt_targets {
prompt_targets.insert(pt.name.clone(), pt.clone());
}
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets);

ratelimit::ratelimits(config.ratelimits);
Expand Down Expand Up @@ -273,6 +276,7 @@ impl RootContext for FilterContext {
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
Expand Down
66 changes: 43 additions & 23 deletions arch/src/stream_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ pub enum ServerError {
Jailbreak(String),
#[error("{why}")]
BadRequest { why: String },
#[error("{why}")]
NoMessagesFound { why: String },
}

pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
Expand All @@ -108,6 +111,7 @@ impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
Expand All @@ -117,6 +121,7 @@ impl StreamContext {
StreamContext {
context_id,
metrics,
system_prompt,
prompt_targets,
embeddings_store,
callouts: RefCell::new(HashMap::new()),
Expand Down Expand Up @@ -633,9 +638,12 @@ impl StreamContext {
} else {
warn!("http status code not found in api response");
}
let body_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(body_str.clone());
debug!("arch <= app response body: {}", body_str);
let app_function_call_response_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(app_function_call_response_str.clone());
debug!(
"arch <= app response body: {}",
app_function_call_response_str
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
Expand All @@ -644,36 +652,48 @@ impl StreamContext {
.clone();

let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
let user_message = match messages.pop() {
Some(user_message) => user_message,
None => {
return self.send_server_error(
ServerError::NoMessagesFound {
why: "no user messages found".to_string(),
},
None,
);
}
};

// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
Some(system_prompt) => {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
};
messages.push(system_prompt_message);
}
}
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => match self.system_prompt.as_ref() {
None => None,
Some(system_prompt) => Some(system_prompt.clone()),
},
Some(system_prompt) => Some(system_prompt.clone()),
};

// add data from function call response
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(body_str),
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
}
});
};
messages.push(system_prompt_message);
}

let final_prompt = format!(
"{}\nhere is context: {}",
user_message.content.unwrap(),
app_function_call_response_str
);

// add original user prompt
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()),
content: Some(final_prompt),
model: None,
tool_calls: None,
}
Expand Down

0 comments on commit 422efd3

Please sign in to comment.