Revive anthropic byok #512

Closed · wants to merge 3 commits

9 changes: 9 additions & 0 deletions bring_your_own_key/anthropic.yaml
@@ -0,0 +1,9 @@
cloud_name: Anthropic

chat_endpoint: "https://api.anthropic.com/v1/messages"
chat_endpoint_style: anthropic
chat_apikey: "$ANTHROPIC_API_KEY"
chat_model: claude-3-5-sonnet-20241022

running_models:
- claude-3-5-sonnet-20241022
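
Note for reviewers: chat_apikey is a "$ANTHROPIC_API_KEY" reference rather than a literal key, so the BYOK loader presumably expands it from the environment at startup. A minimal sketch of that substitution (illustrative only; the real lookup lives in the config loader, not in this file):

fn resolve_apikey(raw: &str) -> String {
    // "$VAR" means: read the key from the environment (assumed behavior).
    match raw.strip_prefix('$') {
        Some(var) => std::env::var(var).unwrap_or_default(),
        None => raw.to_string(),
    }
}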
6 changes: 3 additions & 3 deletions src/at_commands/execute_at.rs
@@ -64,7 +64,7 @@ pub async fn run_at_commands(
continue;
}
let mut content = msg.content.content_text_only();
-let content_n_tokens = msg.content.count_tokens(tokenizer.clone(), &None).unwrap_or(0) as usize;
+let content_n_tokens = msg.content.count_tokens(tokenizer.clone(), "openai").unwrap_or(0) as usize;

let mut context_limit = reserve_for_context / messages_with_at.max(1);
context_limit = context_limit.saturating_sub(content_n_tokens);
@@ -109,7 +109,7 @@ pub async fn run_at_commands(
plain_text_messages,
tokenizer.clone(),
tokens_limit_plain,
-&None,
+"openai",
).await;
for m in pp_plain_text {
// OUTPUT: plain text after all custom messages
@@ -159,7 +159,7 @@ pub async fn run_at_commands(

ccx.lock().await.pp_skeleton = false;

-return (rebuilt_messages.clone(), user_msg_starts, any_context_produced)
+(rebuilt_messages.clone(), user_msg_starts, any_context_produced)
}

pub async fn correct_at_arg(
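Note for reviewers: across this PR the old &Option<String> style argument becomes a plain style string ("openai" here, "anthropic" for the new endpoint style). An illustrative sketch of the kind of dispatch this enables downstream (the real logic lives on ChatMessage/ChatContent; the function below is hypothetical):

fn render_content(style: &str, text: &str) -> serde_json::Value {
    match style {
        // Anthropic messages carry content as a list of typed blocks.
        "anthropic" => serde_json::json!([{"type": "text", "text": text}]),
        // OpenAI-style messages accept a plain string.
        _ => serde_json::Value::String(text.to_string()),
    }
}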
11 changes: 5 additions & 6 deletions src/call_validation.rs
@@ -99,13 +99,13 @@ pub enum ContextEnum {
ChatMessage(ChatMessage),
}

-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatToolFunction {
pub arguments: String,
pub name: String,
}

-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatToolCall {
pub id: String,
pub function: ChatToolFunction,
@@ -126,14 +126,15 @@ impl Default for ChatContent {
}
}

-#[derive(Debug, Serialize, Deserialize, Clone, Default)]
+#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct ChatUsage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize, // TODO: remove (can produce self-contradictory data when prompt+completion != total)
}

-#[derive(Debug, Serialize, Clone, Default)]
+// deserialize_messages_from_post must be used to decode content as ChatContentRaw
+#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
pub struct ChatMessage {
pub role: String,
pub content: ChatContent,
@@ -192,8 +193,6 @@ pub struct ChatPost {
pub postprocess_parameters: PostprocessSettings,
#[serde(default)]
pub meta: ChatMeta,
-#[serde(default)]
-pub style: Option<String>,
}

#[derive(Debug, Deserialize, Clone, Default)]
124 changes: 124 additions & 0 deletions src/forward_to_anthropic_endpoint.rs
@@ -0,0 +1,124 @@
use reqwest::header::{HeaderMap, CONTENT_TYPE, HeaderValue};

use reqwest_eventsource::EventSource;
use serde_json::{json, Value};
use tracing::info;
use crate::call_validation::SamplingParameters;


fn embed_messages_and_tools_from_prompt(
    data: &mut Value, prompt: &str
) {
    assert!(prompt.starts_with("PASSTHROUGH "));
    let messages_str = &prompt[12..];
    let big_json: Value = serde_json::from_str(&messages_str).unwrap();

    if let Some(messages) = big_json["messages"].as_array() {
        data["messages"] = Value::Array(
            messages.iter().filter(|msg| msg["role"] != "system").cloned().collect()
        );
        let system_string = messages.iter()
            .filter(|msg| msg["role"] == "system")
            .map(|msg| msg["content"].as_str().unwrap_or(""))
            .collect::<Vec<_>>()
            .join("\n");

        if !system_string.is_empty() {
            data["system"] = Value::String(system_string);
        }

    }

    if let Some(tools) = big_json.get("tools") {
        data["tools"] = tools.clone();
    }
}

fn make_headers(bearer: &str) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();
    headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
    // see https://docs.anthropic.com/en/api/versioning
    headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));

    if !bearer.is_empty() {
        headers.insert("x-api-key", HeaderValue::from_str(bearer)
            .map_err(|e| format!("Failed to insert header: {}", e))?);
    }
    Ok(headers)
}

pub async fn forward_to_anthropic_endpoint(
    save_url: &mut String,
    bearer: String,
    model_name: &str,
    prompt: &str,
    client: &reqwest::Client,
    endpoint_chat_passthrough: &String,
    sampling_parameters: &SamplingParameters,
) -> Result<Value, String> {
    *save_url = endpoint_chat_passthrough.clone();
    let headers = make_headers(bearer.as_str())?;

    let mut data = json!({
        "model": model_name,
        "stream": false,
        "temperature": sampling_parameters.temperature,
        "max_tokens": sampling_parameters.max_new_tokens,
    });

    embed_messages_and_tools_from_prompt(&mut data, prompt);

    let req = client.post(save_url.as_str())
        .headers(headers)
        .body(data.to_string())
        .send()
        .await;
    let resp = req.map_err(|e| format!("{}", e))?;
    let status_code = resp.status().as_u16();
    let response_txt = resp.text().await.map_err(|e|
        format!("reading from socket {}: {}", save_url, e)
    )?;

    if status_code != 200 && status_code != 400 {
        return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
    }
    if status_code != 200 {
        info!("forward_to_anthropic_endpoint: {} {}\n{}", save_url, status_code, response_txt);
    }
    let parsed_json: Value = match serde_json::from_str(&response_txt) {
        Ok(json) => json,
        Err(e) => return Err(format!("Failed to parse JSON response: {}\n{}", e, response_txt)),
    };
    Ok(parsed_json)
}

pub async fn forward_to_anthropic_endpoint_streaming(
    save_url: &mut String,
    bearer: String,
    model_name: &str,
    prompt: &str,
    client: &reqwest::Client,
    endpoint_chat_passthrough: &String,
    sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
    *save_url = endpoint_chat_passthrough.clone();
    let headers = make_headers(bearer.as_str())?;

    let mut data = json!({
        "model": model_name,
        "stream": true,
        "temperature": sampling_parameters.temperature,
        "max_tokens": sampling_parameters.max_new_tokens,
    });

    embed_messages_and_tools_from_prompt(&mut data, prompt);

    let builder = client.post(save_url.as_str())
        .headers(headers)
        .body(data.to_string());
    let event_source: EventSource = EventSource::new(builder).map_err(|e|
        format!("can't stream from {}: {}", save_url, e)
    )?;

    Ok(event_source)
}
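
Note for reviewers: a minimal driver for the two entry points above, to make the expected call shape concrete. The "PASSTHROUGH <json>" prompt format comes from the assertion in embed_messages_and_tools_from_prompt; the helper itself is hypothetical and not part of this PR:

use futures::StreamExt;
use reqwest_eventsource::Event;

// Hypothetical: build the "PASSTHROUGH <json>" prompt, open the SSE stream,
// and print raw Anthropic events ("message_start", "content_block_delta", ...).
async fn demo_stream(params: &crate::call_validation::SamplingParameters) -> Result<(), String> {
    let prompt = format!("PASSTHROUGH {}", serde_json::json!({
        "messages": [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello."}
        ]
    }));
    let mut url = String::new();
    let mut es = forward_to_anthropic_endpoint_streaming(
        &mut url,
        std::env::var("ANTHROPIC_API_KEY").unwrap_or_default(),
        "claude-3-5-sonnet-20241022",
        &prompt,
        &reqwest::Client::new(),
        &"https://api.anthropic.com/v1/messages".to_string(),
        params,
    ).await?;
    while let Some(ev) = es.next().await {
        match ev {
            Ok(Event::Open) => {}
            Ok(Event::Message(m)) => println!("{}: {}", m.event, m.data),
            Err(e) => { eprintln!("stream closed: {}", e); break; }
        }
    }
    Ok(())
}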
4 changes: 2 additions & 2 deletions src/http/routers/v1/at_tools.rs
@@ -48,7 +48,7 @@ pub struct ToolsExecutePost {
pub postprocess_parameters: PostprocessSettings,
pub model_name: String,
pub chat_id: String,
-pub style: Option<String>,
+pub style: String,
}

#[derive(Debug, Serialize, Deserialize)]
@@ -81,7 +81,7 @@ pub async fn handle_v1_tools(
vec![]
});

-let tools_openai_stype = tool_desclist.into_iter().map(|x| x.into_openai_style()).collect::<Vec<_>>();
+let tools_openai_stype = tool_desclist.into_iter().map(|x| x.into_openai_style(true)).collect::<Vec<_>>();

let body = serde_json::to_string_pretty(&tools_openai_stype).map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?;
Ok(Response::builder()
17 changes: 9 additions & 8 deletions src/http/routers/v1/chat.rs
@@ -14,6 +14,7 @@ use crate::custom_error::ScratchError;
use crate::at_commands::at_commands::AtCommandsContext;
use crate::global_context::{GlobalContext, SharedGlobalContext};
use crate::integrations::docker::docker_container_manager::docker_container_check_status_or_start;
+use crate::scratchpads::multimodality::ChatMessages;


pub fn available_tools_by_chat_mode(current_tools: Vec<Value>, chat_mode: &ChatMode) -> Vec<Value> {
@@ -96,14 +97,14 @@ pub async fn handle_v1_chat(
}

pub fn deserialize_messages_from_post(messages: &Vec<serde_json::Value>) -> Result<Vec<ChatMessage>, ScratchError> {
-let messages: Vec<ChatMessage> = messages.iter()
-.map(|x| serde_json::from_value(x.clone()))
-.collect::<Result<Vec<_>, _>>()
-.map_err(|e| {
-tracing::error!("can't deserialize ChatMessage: {}", e);
-ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e))
-})?;
-Ok(messages)
+let messages_value = serde_json::Value::Array(messages.clone());
+
+let chat_messages: ChatMessages = serde_json::from_value(messages_value).map_err(|e| {
+tracing::error!("can't deserialize ChatMessages: {}", e);
+ScratchError::new(StatusCode::BAD_REQUEST, format!("can't deserialize ChatMessages: {}", e))
+})?;
+
+Ok(chat_messages.0)
}

async fn _chat(
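Note for reviewers: routing deserialization through the ChatMessages wrapper means message content is decoded by one custom Deserialize impl instead of ChatMessage's derived one (see the ChatContentRaw comment in call_validation.rs). A self-contained sketch of the newtype pattern, with hypothetical names:

use serde::{Deserialize, Deserializer};
use serde_json::Value;

struct Msg { role: String, text: String }
struct Msgs(Vec<Msg>);  // wrapper newtype, like ChatMessages

impl<'de> Deserialize<'de> for Msgs {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        let raw: Vec<Value> = Vec::deserialize(d)?;
        // Normalize each element here: accept either plain string content
        // or a structured [{"type": "text", "text": ...}] block list.
        let msgs = raw.into_iter().map(|v| {
            let role = v["role"].as_str().unwrap_or("").to_string();
            let text = match &v["content"] {
                Value::String(s) => s.clone(),
                other => other[0]["text"].as_str().unwrap_or("").to_string(),
            };
            Msg { role, text }
        }).collect();
        Ok(Msgs(msgs))
    }
}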
5 changes: 3 additions & 2 deletions src/http/routers/v1/subchat.rs
@@ -51,7 +51,8 @@ pub async fn handle_v1_subchat(
).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)))?;

let new_messages = new_messages.into_iter()
-.map(|msgs|msgs.iter().map(|msg|msg.into_value(&None)).collect::<Vec<_>>())
+// todo subchat does not support anthropic byok
+.map(|msgs|msgs.iter().map(|msg|msg.into_value("openai")).collect::<Vec<_>>())
.collect::<Vec<Vec<_>>>();
let resp_serialised = serde_json::to_string_pretty(&new_messages).unwrap();
Ok(
@@ -107,7 +108,7 @@ pub async fn handle_v1_subchat_single(
).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)))?;

let new_messages = new_messages.into_iter()
-.map(|msgs|msgs.iter().map(|msg|msg.into_value(&None)).collect::<Vec<_>>())
+.map(|msgs|msgs.iter().map(|msg|msg.into_value("openai")).collect::<Vec<_>>())
.collect::<Vec<Vec<_>>>();
let resp_serialised = serde_json::to_string_pretty(&new_messages).unwrap();
Ok(
1 change: 1 addition & 0 deletions src/integrations/integr_chrome.rs
@@ -293,6 +293,7 @@ impl Tool for ToolChrome {
name: "commands".to_string(),
param_type: "string".to_string(),
description,
+..Default::default()
}],
parameters_required: vec!["commands".to_string()],
}
1 change: 1 addition & 0 deletions src/integrations/integr_cmdline_service.rs
@@ -318,6 +318,7 @@ impl Tool for ToolService {
name: "action".to_string(),
param_type: "string".to_string(),
description: "Action to perform: start, restart, stop, status".to_string(),
+..Default::default()
});

let parameters_required = self.cfg.parameters_required.clone().unwrap_or_else(|| {
1 change: 1 addition & 0 deletions src/integrations/integr_pdb.rs
@@ -185,6 +185,7 @@ impl Tool for ToolPdb {
name: "command".to_string(),
param_type: "string".to_string(),
description: "Examples: 'python -m pdb script.py', 'break module_name.function_name', 'break 10', 'continue', 'print(variable_name)', 'list', 'quit'".to_string(),
+..Default::default()
},
],
parameters_required: vec!["command".to_string()],
2 changes: 2 additions & 0 deletions src/main.rs
@@ -52,6 +52,7 @@ mod scratchpads;

#[cfg(feature="vecdb")]
mod fetch_embedding;
+mod forward_to_anthropic_endpoint;
mod forward_to_hf_endpoint;
mod forward_to_openai_endpoint;
mod restream;
@@ -68,6 +69,7 @@ mod git;
mod agentic;
mod trajectories;


#[tokio::main]
async fn main() {
let cpu_num = std::thread::available_parallelism().unwrap().get();
2 changes: 1 addition & 1 deletion src/postprocessing/pp_plain_text.rs
@@ -33,7 +33,7 @@ pub async fn postprocess_plain_text(
plain_text_messages: Vec<&ChatMessage>,
tokenizer: Arc<RwLock<Tokenizer>>,
tokens_limit: usize,
-style: &Option<String>,
+style: &str,
) -> (Vec<ChatMessage>, usize) {
if plain_text_messages.is_empty() {
return (vec![], tokens_limit);