From 708ea8eaf325d4888e88960c46b097ff3c38acef Mon Sep 17 00:00:00 2001 From: Sergey Vakhreev Date: Mon, 11 Nov 2024 13:22:33 +1030 Subject: [PATCH 1/3] Add new model configurations (#419) * Add new model configurations * move to the bottom WIP Signed-off-by: V4LER11 WIP not streaming works basic streaming support enable system send tools list for anthropic WIP streaming tools WIP tools seem to work arguments empty -> empty hashmap executing tools in parallel safer code vision support rebase fixes fix: openai does not like empty tools # Conflicts: # src/call_validation.rs # src/http/routers/v1/chat.rs # src/integrations/integr_chrome.rs # src/integrations/integr_cmdline_service.rs # src/known_models.rs # src/restream.rs # src/scratchpads/chat_passthrough.rs # src/scratchpads/mod.rs # src/subchat.rs # src/tools/tools_description.rs --- bring_your_own_key/anthropic.yaml | 9 + src/at_commands/execute_at.rs | 6 +- src/call_validation.rs | 11 +- src/forward_to_anthropic_endpoint.rs | 124 ++++++ src/http/routers/v1/at_tools.rs | 4 +- src/http/routers/v1/chat.rs | 17 +- src/http/routers/v1/subchat.rs | 5 +- src/integrations/integr_pdb.rs | 1 + src/main.rs | 2 + src/postprocessing/pp_plain_text.rs | 2 +- src/restream.rs | 217 ++++++++--- src/scratchpads/chat_passthrough.rs | 82 ++-- src/scratchpads/chat_utils_limit_history.rs | 2 +- src/scratchpads/mod.rs | 10 +- src/scratchpads/multimodality.rs | 362 ++++++++++++++---- .../passthrough_convert_messages.rs | 4 +- src/scratchpads/scratchpad_utils.rs | 9 + src/subchat.rs | 4 +- .../model_based_edit/model_execution.rs | 4 +- src/tools/tools_description.rs | 74 ++-- src/tools/tools_execute.rs | 31 +- tests/test13_vision.py | 4 +- 22 files changed, 763 insertions(+), 221 deletions(-) create mode 100644 bring_your_own_key/anthropic.yaml create mode 100644 src/forward_to_anthropic_endpoint.rs diff --git a/bring_your_own_key/anthropic.yaml b/bring_your_own_key/anthropic.yaml new file mode 100644 index 000000000..f196926d9 --- /dev/null +++ b/bring_your_own_key/anthropic.yaml @@ -0,0 +1,9 @@ +cloud_name: Anthropic + +chat_endpoint: "https://api.anthropic.com/v1/messages" +chat_endpoint_style: anthropic +chat_apikey: "$ANTHROPIC_API_KEY" +chat_model: claude-3-5-sonnet-20241022 + +running_models: + - claude-3-5-sonnet-20241022 diff --git a/src/at_commands/execute_at.rs b/src/at_commands/execute_at.rs index cf77cd70a..60201281b 100644 --- a/src/at_commands/execute_at.rs +++ b/src/at_commands/execute_at.rs @@ -64,7 +64,7 @@ pub async fn run_at_commands( continue; } let mut content = msg.content.content_text_only(); - let content_n_tokens = msg.content.count_tokens(tokenizer.clone(), &None).unwrap_or(0) as usize; + let content_n_tokens = msg.content.count_tokens(tokenizer.clone(), "openai").unwrap_or(0) as usize; let mut context_limit = reserve_for_context / messages_with_at.max(1); context_limit = context_limit.saturating_sub(content_n_tokens); @@ -109,7 +109,7 @@ pub async fn run_at_commands( plain_text_messages, tokenizer.clone(), tokens_limit_plain, - &None, + "openai", ).await; for m in pp_plain_text { // OUTPUT: plain text after all custom messages @@ -159,7 +159,7 @@ pub async fn run_at_commands( ccx.lock().await.pp_skeleton = false; - return (rebuilt_messages.clone(), user_msg_starts, any_context_produced) + (rebuilt_messages.clone(), user_msg_starts, any_context_produced) } pub async fn correct_at_arg( diff --git a/src/call_validation.rs b/src/call_validation.rs index 85f6c5c57..f0994cd2e 100644 --- a/src/call_validation.rs +++ b/src/call_validation.rs 
@@ -99,13 +99,13 @@ pub enum ContextEnum {
     ChatMessage(ChatMessage),
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
 pub struct ChatToolFunction {
     pub arguments: String,
     pub name: String,
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
 pub struct ChatToolCall {
     pub id: String,
     pub function: ChatToolFunction,
@@ -126,14 +126,15 @@ impl Default for ChatContent {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, Clone, Default)]
+#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
 pub struct ChatUsage {
     pub prompt_tokens: usize,
     pub completion_tokens: usize,
     pub total_tokens: usize,  // TODO: remove (can produce self-contradictory data when prompt+completion != total)
 }
 
-#[derive(Debug, Serialize, Clone, Default)]
+// deserialize_messages_from_post must be used to decode content as ChatContentRaw
+#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
 pub struct ChatMessage {
     pub role: String,
     pub content: ChatContent,
@@ -192,8 +193,6 @@ pub struct ChatPost {
     pub postprocess_parameters: PostprocessSettings,
     #[serde(default)]
     pub meta: ChatMeta,
-    #[serde(default)]
-    pub style: Option<String>,
 }
 
 #[derive(Debug, Deserialize, Clone, Default)]
diff --git a/src/forward_to_anthropic_endpoint.rs b/src/forward_to_anthropic_endpoint.rs
new file mode 100644
index 000000000..cd13f7a6d
--- /dev/null
+++ b/src/forward_to_anthropic_endpoint.rs
@@ -0,0 +1,124 @@
+use reqwest::header::{HeaderMap, CONTENT_TYPE, HeaderValue};
+
+use reqwest_eventsource::EventSource;
+use serde_json::{json, Value};
+use tracing::info;
+use crate::call_validation::SamplingParameters;
+
+
+fn embed_messages_and_tools_from_prompt(
+    data: &mut Value, prompt: &str
+) {
+    assert!(prompt.starts_with("PASSTHROUGH "));
+    let messages_str = &prompt[12..];
+    let big_json: Value = serde_json::from_str(messages_str).unwrap();
+
+    if let Some(messages) = big_json["messages"].as_array() {
+        data["messages"] = Value::Array(
+            messages.iter().filter(|msg| msg["role"] != "system").cloned().collect()
+        );
+        let system_string = messages.iter()
+            .filter(|msg| msg["role"] == "system")
+            .map(|msg| msg["content"].as_str().unwrap_or(""))
+            .collect::<Vec<&str>>()
+            .join("\n");
+
+        if !system_string.is_empty() {
+            data["system"] = Value::String(system_string);
+        }
+    }
+
+    if let Some(tools) = big_json.get("tools") {
+        data["tools"] = tools.clone();
+    }
+}
+
+fn make_headers(bearer: &str) -> Result<HeaderMap, String> {
+    let mut headers = HeaderMap::new();
+    headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
+    // see https://docs.anthropic.com/en/api/versioning
+    headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
+
+    if !bearer.is_empty() {
+        headers.insert("x-api-key", HeaderValue::from_str(bearer)
+            .map_err(|e| format!("Failed to insert header: {}", e))?);
+    }
+    Ok(headers)
+}
+
+pub async fn forward_to_anthropic_endpoint(
+    save_url: &mut String,
+    bearer: String,
+    model_name: &str,
+    prompt: &str,
+    client: &reqwest::Client,
+    endpoint_chat_passthrough: &String,
+    sampling_parameters: &SamplingParameters,
+) -> Result<Value, String> {
+    *save_url = endpoint_chat_passthrough.clone();
+    let headers = make_headers(bearer.as_str())?;
+
+    let mut data = json!({
+        "model": model_name,
+        "stream": false,
+        "temperature": sampling_parameters.temperature,
+        "max_tokens": sampling_parameters.max_new_tokens,
+    });
+
+    embed_messages_and_tools_from_prompt(&mut data, prompt);
+
+    let req = 
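+        // A sketch of the body assembled above, following the Anthropic
+        // Messages API (https://docs.anthropic.com/en/api/messages);
+        // values here are illustrative, not taken from a real request:
+        //   {
+        //     "model": "claude-3-5-sonnet-20241022",
+        //     "stream": false, "temperature": 0.0, "max_tokens": 1024,
+        //     "system": "joined system messages, omitted when empty",
+        //     "messages": [{"role": "user", "content": "..."}],
+        //     "tools": [...]          // only when the prompt carried tools
+        //   }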
+        client.post(save_url.as_str())
+        .headers(headers)
+        .body(data.to_string())
+        .send()
+        .await;
+    let resp = req.map_err(|e| format!("{}", e))?;
+    let status_code = resp.status().as_u16();
+    let response_txt = resp.text().await.map_err(|e|
+        format!("reading from socket {}: {}", save_url, e)
+    )?;
+
+    if status_code != 200 && status_code != 400 {
+        return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
+    }
+    if status_code != 200 {
+        info!("forward_to_anthropic_endpoint: {} {}\n{}", save_url, status_code, response_txt);
+    }
+    let parsed_json: Value = match serde_json::from_str(&response_txt) {
+        Ok(json) => json,
+        Err(e) => return Err(format!("Failed to parse JSON response: {}\n{}", e, response_txt)),
+    };
+    Ok(parsed_json)
+}
+
+pub async fn forward_to_anthropic_endpoint_streaming(
+    save_url: &mut String,
+    bearer: String,
+    model_name: &str,
+    prompt: &str,
+    client: &reqwest::Client,
+    endpoint_chat_passthrough: &String,
+    sampling_parameters: &SamplingParameters,
+) -> Result<EventSource, String> {
+    *save_url = endpoint_chat_passthrough.clone();
+    let headers = make_headers(bearer.as_str())?;
+
+    let mut data = json!({
+        "model": model_name,
+        "stream": true,
+        "temperature": sampling_parameters.temperature,
+        "max_tokens": sampling_parameters.max_new_tokens,
+    });
+
+    embed_messages_and_tools_from_prompt(&mut data, prompt);
+
+    let builder = client.post(save_url.as_str())
+        .headers(headers)
+        .body(data.to_string());
+    let event_source: EventSource = EventSource::new(builder).map_err(|e|
+        format!("can't stream from {}: {}", save_url, e)
+    )?;
+
+    Ok(event_source)
+}
diff --git a/src/http/routers/v1/at_tools.rs b/src/http/routers/v1/at_tools.rs
index 48d934bc1..265af8c67 100644
--- a/src/http/routers/v1/at_tools.rs
+++ b/src/http/routers/v1/at_tools.rs
@@ -48,7 +48,7 @@ pub struct ToolsExecutePost {
     pub postprocess_parameters: PostprocessSettings,
     pub model_name: String,
     pub chat_id: String,
-    pub style: Option<String>,
+    pub style: String,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -81,7 +81,7 @@ pub async fn handle_v1_tools(
         vec![]
     });
 
-    let tools_openai_stype = tool_desclist.into_iter().map(|x| x.into_openai_style()).collect::<Vec<Value>>();
+    let tools_openai_stype = tool_desclist.into_iter().map(|x| x.into_openai_style(true)).collect::<Vec<Value>>();
 
     let body = serde_json::to_string_pretty(&tools_openai_stype).map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)))?;
     Ok(Response::builder()
diff --git a/src/http/routers/v1/chat.rs b/src/http/routers/v1/chat.rs
index caac66740..e831b6ccf 100644
--- a/src/http/routers/v1/chat.rs
+++ b/src/http/routers/v1/chat.rs
@@ -14,6 +14,7 @@ use crate::custom_error::ScratchError;
 use crate::at_commands::at_commands::AtCommandsContext;
 use crate::global_context::{GlobalContext, SharedGlobalContext};
 use crate::integrations::docker::docker_container_manager::docker_container_check_status_or_start;
+use crate::scratchpads::multimodality::ChatMessages;
 
 
 pub fn available_tools_by_chat_mode(current_tools: Vec<Value>, chat_mode: &ChatMode) -> Vec<Value> {
@@ -96,14 +97,14 @@ pub async fn handle_v1_chat(
 }
 
 pub fn deserialize_messages_from_post(messages: &Vec<serde_json::Value>) -> Result<Vec<ChatMessage>, ScratchError> {
-    let messages: Vec<ChatMessage> = messages.iter()
-        .map(|x| serde_json::from_value(x.clone()))
-        .collect::<Result<Vec<_>, _>>()
-        .map_err(|e| {
-            tracing::error!("can't deserialize ChatMessage: {}", e);
-            ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e))
-        })?;
-    Ok(messages)
+    let messages_value = serde_json::Value::Array(messages.clone());
+
+    let
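+    // ChatMessages is a newtype over Vec<ChatMessage>; its Deserialize impl
+    // (see multimodality.rs) accepts OpenAI-style content as well as Anthropic
+    // tool_use / tool_result blocks, lifting the latter into separate messages: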
chat_messages: ChatMessages = serde_json::from_value(messages_value).map_err(|e| { + tracing::error!("can't deserialize ChatMessages: {}", e); + ScratchError::new(StatusCode::BAD_REQUEST, format!("can't deserialize ChatMessages: {}", e)) + })?; + + Ok(chat_messages.0) } async fn _chat( diff --git a/src/http/routers/v1/subchat.rs b/src/http/routers/v1/subchat.rs index 44785e4d7..06d707efe 100644 --- a/src/http/routers/v1/subchat.rs +++ b/src/http/routers/v1/subchat.rs @@ -51,7 +51,8 @@ pub async fn handle_v1_subchat( ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)))?; let new_messages = new_messages.into_iter() - .map(|msgs|msgs.iter().map(|msg|msg.into_value(&None)).collect::>()) + // todo subchat does not support anthropic byok + .map(|msgs|msgs.iter().map(|msg|msg.into_value("openai")).collect::>()) .collect::>>(); let resp_serialised = serde_json::to_string_pretty(&new_messages).unwrap(); Ok( @@ -107,7 +108,7 @@ pub async fn handle_v1_subchat_single( ).await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)))?; let new_messages = new_messages.into_iter() - .map(|msgs|msgs.iter().map(|msg|msg.into_value(&None)).collect::>()) + .map(|msgs|msgs.iter().map(|msg|msg.into_value("openai")).collect::>()) .collect::>>(); let resp_serialised = serde_json::to_string_pretty(&new_messages).unwrap(); Ok( diff --git a/src/integrations/integr_pdb.rs b/src/integrations/integr_pdb.rs index 9b99e1580..b4e268bd7 100644 --- a/src/integrations/integr_pdb.rs +++ b/src/integrations/integr_pdb.rs @@ -185,6 +185,7 @@ impl Tool for ToolPdb { name: "command".to_string(), param_type: "string".to_string(), description: "Examples: 'python -m pdb script.py', 'break module_name.function_name', 'break 10', 'continue', 'print(variable_name)', 'list', 'quit'".to_string(), + ..Default::default() }, ], parameters_required: vec!["command".to_string()], diff --git a/src/main.rs b/src/main.rs index cf8cd3071..a541bf9d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,6 +52,7 @@ mod scratchpads; #[cfg(feature="vecdb")] mod fetch_embedding; +mod forward_to_anthropic_endpoint; mod forward_to_hf_endpoint; mod forward_to_openai_endpoint; mod restream; @@ -68,6 +69,7 @@ mod git; mod agentic; mod trajectories; + #[tokio::main] async fn main() { let cpu_num = std::thread::available_parallelism().unwrap().get(); diff --git a/src/postprocessing/pp_plain_text.rs b/src/postprocessing/pp_plain_text.rs index e4cbc6b9a..e9f005617 100644 --- a/src/postprocessing/pp_plain_text.rs +++ b/src/postprocessing/pp_plain_text.rs @@ -33,7 +33,7 @@ pub async fn postprocess_plain_text( plain_text_messages: Vec<&ChatMessage>, tokenizer: Arc>, tokens_limit: usize, - style: &Option, + style: &str, ) -> (Vec, usize) { if plain_text_messages.is_empty() { return (vec![], tokens_limit); diff --git a/src/restream.rs b/src/restream.rs index 0df24e866..d0d6d9ad5 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -5,17 +5,22 @@ use tokio::sync::mpsc; use async_stream::stream; use futures::StreamExt; use hyper::{Body, Response, StatusCode}; +use reqwest::Client; use reqwest_eventsource::Event; use serde_json::json; use tracing::info; -use crate::call_validation::SamplingParameters; +use crate::call_validation::{ChatMessage, SamplingParameters}; use crate::custom_error::ScratchError; use crate::nicer_logs; use crate::scratchpad_abstract::{FinishReason, ScratchpadAbstract}; use crate::telemetry::telemetry_structs; use crate::at_commands::at_commands::AtCommandsContext; use 
crate::caps::get_api_key; +use crate::forward_to_anthropic_endpoint::{forward_to_anthropic_endpoint, forward_to_anthropic_endpoint_streaming}; +use crate::forward_to_hf_endpoint::{forward_to_hf_style_endpoint, forward_to_hf_style_endpoint_streaming}; +use crate::forward_to_openai_endpoint::{forward_to_openai_style_endpoint, forward_to_openai_style_endpoint_streaming}; +use crate::scratchpads::multimodality::AnthropicInputElement; async fn _get_endpoint_and_stuff_from_model_name( @@ -61,7 +66,7 @@ async fn _get_endpoint_and_stuff_from_model_name( if !custom_endpoint_template.is_empty() { endpoint_template = custom_endpoint_template; } - return ( + ( api_key, endpoint_template, endpoint_style, @@ -69,6 +74,38 @@ async fn _get_endpoint_and_stuff_from_model_name( ) } +async fn get_model_says( + only_deterministic_messages: bool, + endpoint_style: String, + bearer: String, + model_name: &String, + prompt: &str, + client: &Client, + endpoint_template: &String, + endpoint_chat_passthrough: &String, + parameters: &SamplingParameters, + save_url: &mut String, +) -> Result { + if only_deterministic_messages { + *save_url = "only-det-messages".to_string(); + return Ok(serde_json::Value::Object(serde_json::Map::new())); + } + + match endpoint_style.as_str() { + "hf" => forward_to_hf_style_endpoint( + save_url, bearer.clone(), model_name, prompt, client, &endpoint_template, ¶meters + ).await, + "anthropic" => forward_to_anthropic_endpoint( + save_url, bearer.clone(), model_name, prompt, client, endpoint_chat_passthrough, ¶meters + ).await, + _ => { + forward_to_openai_style_endpoint( + save_url, bearer.clone(), model_name, prompt, client, &endpoint_template, &endpoint_chat_passthrough, ¶meters // includes n + ).await + } + } +} + pub async fn scratchpad_interaction_not_stream_json( ccx: Arc>, scratchpad: &mut Box, @@ -100,39 +137,19 @@ pub async fn scratchpad_interaction_not_stream_json( let mut save_url: String = String::new(); let _ = slowdown_arc.acquire().await; - let mut model_says = if only_deterministic_messages { - save_url = "only-det-messages".to_string(); - Ok(serde_json::Value::Object(serde_json::Map::new())) - } else if endpoint_style == "hf" { - crate::forward_to_hf_endpoint::forward_to_hf_style_endpoint( - &mut save_url, - bearer.clone(), - &model_name, - &prompt, - &client, - &endpoint_template, - ¶meters, - ).await - } else { - crate::forward_to_openai_endpoint::forward_to_openai_style_endpoint( - &mut save_url, - bearer.clone(), - &model_name, - &prompt, - &client, - &endpoint_template, - &endpoint_chat_passthrough, - ¶meters, // includes n - ).await - }.map_err(|e| { + + let mut model_says = get_model_says( + only_deterministic_messages, endpoint_style, bearer, &model_name, prompt, &client, &endpoint_template, &endpoint_chat_passthrough, ¶meters, &mut save_url + ).await.map_err(|e|{ tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( - save_url.clone(), - scope.clone(), - false, - e.to_string(), - )); + save_url.clone(), + scope.clone(), + false, + e.to_string(), + )); ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e)) })?; + tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( save_url.clone(), scope.clone(), @@ -281,7 +298,7 @@ pub async fn scratchpad_interaction_not_stream( .header("Content-Type", "application/json") .body(Body::from(txt)) .unwrap(); - return Ok(response); + Ok(response) } pub async fn scratchpad_interaction_stream( @@ -386,9 +403,9 @@ pub 
async fn scratchpad_interaction_stream( if only_deterministic_messages { break; } - // info!("prompt: {:?}", prompt); - let event_source_maybe = if endpoint_style == "hf" { - crate::forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( + + let event_source = match endpoint_style.as_str() { + "hf" => forward_to_hf_style_endpoint_streaming( &mut save_url, bearer.clone(), &model_name, @@ -396,9 +413,17 @@ pub async fn scratchpad_interaction_stream( &client, &endpoint_template, ¶meters, - ).await - } else { - crate::forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming( + ).await, + "anthropic" => forward_to_anthropic_endpoint_streaming( + &mut save_url, + bearer.clone(), + &model_name, + prompt.as_str(), + &client, + &endpoint_chat_passthrough, + ¶meters, + ).await, + _ => forward_to_openai_style_endpoint_streaming( &mut save_url, bearer.clone(), &model_name, @@ -409,7 +434,8 @@ pub async fn scratchpad_interaction_stream( ¶meters, ).await }; - let mut event_source = match event_source_maybe { + + let mut event_source = match event_source { Ok(event_source) => event_source, Err(e) => { let e_str = format!("forward_to_endpoint: {:?}", e); @@ -427,17 +453,54 @@ pub async fn scratchpad_interaction_stream( }; let mut was_correct_output_even_if_error = false; let mut last_finish_reason = FinishReason::None; + let mut message_template: serde_json::Value = serde_json::Value::Null; + let mut ant_tool_call_index: i32 = -1; + // let mut test_countdown = 250; while let Some(event) = event_source.next().await { match event { Ok(Event::Open) => {}, Ok(Event::Message(message)) => { - // info!("Message: {:#?}", message); if message.data.starts_with("[DONE]") { break; } let json = serde_json::from_str::(&message.data).unwrap(); + + // for anthropic + match message.event.as_str() { + "message_start" => { + message_template = json!({ + "id": json["message"]["id"], + "object": "chat.completion.chunk", + "model": json["message"]["model"], + // "usage": json["message"]["usage"], todo: implement usage (event: message_delta) + "choices": [ + {"index": 0, "delta": {"role": json["message"]["role"], "content": ""}} + ] + }); + }, + "content_block_start" => { + + }, + "message_stop" => { + finished = true; + break; + }, + "ping" | "content_block_stop" => { + continue; + } + _ => {} + } + crate::global_context::look_for_piggyback_fields(gcx.clone(), &json).await; + + let value = _push_streaming_json_into_scratchpad( + my_scratchpad, + &json, + &mut model_name, + &mut was_correct_output_even_if_error, + ); + match _push_streaming_json_into_scratchpad( my_scratchpad, &json, @@ -449,8 +512,6 @@ pub async fn scratchpad_interaction_stream( try_insert_usage(&mut value); value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0); let value_str = format!("data: {}\n\n", serde_json::to_string(&value).unwrap()); - // let last_60_chars: String = crate::nicer_logs::first_n_chars(&value_str, 60); - // info!("yield: {:?}", last_60_chars); yield Result::<_, String>::Ok(value_str); }, Err(err_str) => { @@ -550,6 +611,76 @@ pub fn try_insert_usage(msg_value: &mut serde_json::Value) -> bool { return false; } +fn _push_streaming_json_anthropic( + json: &serde_json::Value, + message_template: &serde_json::Value, + event_type: &str, + ant_tool_call_index: &mut i32, +) -> Result, String> { + if !message_template.is_object() { + return Ok(None); + } + let mut value = message_template.clone(); + + match event_type { + "error" => { + let error_type = json.get("error") + 
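+        // Anthropic reports stream failures as a dedicated SSE event, roughly:
+        //   event: error
+        //   data: {"type": "error", "error": {"type": "overloaded_error", "message": "..."}}
+        // (shape per the Anthropic streaming docs, shown here for reference)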
.and_then(|e| e.get("type")) + .and_then(|t| t.as_str()) + .unwrap_or("Unknown error type"); + + let error_message = json.get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error message"); + + Err(format!("{}: {}", error_type, error_message)) + }, + "message_start" => { + Ok(Some(value)) + }, + "content_block_start" => { + if json["content_block"]["type"] == "tool_use" { + let tool_id = json["content_block"]["id"].clone(); + let tool_name = json["content_block"]["name"].clone(); + *ant_tool_call_index += 1; + value["choices"][0]["delta"] = json!({ + "tool_calls": [{ + "index": *ant_tool_call_index, + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": "" + } + }] + }); + Ok(Some(value)) + } else { + Ok(None) + } + }, + _ => { + value["choices"][0]["delta"] = json!({}); + let delta = json["delta"].clone(); + if delta["type"] == "text_delta" { + value["choices"][0]["delta"]["content"] = delta["text"].clone(); + } else if delta["type"] == "input_json_delta" { + let partial_json = delta["partial_json"].clone(); + value["choices"][0]["delta"] = json!({ + "tool_calls": [{ + "index": *ant_tool_call_index, + "function": { + "arguments": partial_json + } + }] + }); + } + Ok(Some(value)) + } + } +} + fn _push_streaming_json_into_scratchpad( scratch: &mut Box, json: &serde_json::Value, diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index ede7fa4a1..55bb395fc 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -59,6 +59,7 @@ pub struct ChatPassthrough { pub allow_at: bool, pub supports_tools: bool, pub supports_clicks: bool, + pub endpoint_style: String, } impl ChatPassthrough { @@ -69,6 +70,7 @@ impl ChatPassthrough { allow_at: bool, supports_tools: bool, supports_clicks: bool, + endpoint_style: &str, ) -> Self { ChatPassthrough { t: HasTokenizerAndEot::new(tokenizer), @@ -79,6 +81,7 @@ impl ChatPassthrough { allow_at, supports_tools, supports_clicks, + endpoint_style: endpoint_style.to_string(), } } } @@ -103,7 +106,8 @@ impl ScratchpadAbstract for ChatPassthrough { let ccx_locked = ccx.lock().await; (ccx_locked.global_context.clone(), ccx_locked.n_ctx, ccx_locked.should_execute_remotely) }; - let style = self.post.style.clone(); + let style = self.endpoint_style.clone(); + let allow_experimental = gcx.read().await.cmdline.experimental; let at_tools = tools_merged_and_filtered(gcx.clone(), self.supports_clicks).await?; let messages = prepend_the_right_system_prompt_and_maybe_more_initial_messages(gcx.clone(), self.messages.clone(), &self.post, &mut self.has_rag_results).await; @@ -125,50 +129,49 @@ impl ScratchpadAbstract for ChatPassthrough { }); let converted_messages = convert_messages_to_openai_format(limited_msgs, &style); + let converted_messages = if style.as_str() == "anthropic" { + format_messages_anthropic(converted_messages) + } else { + converted_messages + }; let mut big_json = serde_json::json!({ "messages": converted_messages, }); if self.supports_tools { - let post_tools = self.post.tools.as_ref().and_then(|tools| { + let tools = if let Some(tools) = &self.post.tools { + // if tools.is_empty() || any_context_produced { if tools.is_empty() { None } else { - Some(tools.clone()) + Some(tools) } - }); - - let mut tools = if let Some(t) = post_tools { - // here we only use names from the tools in `post` - let turned_on = t.iter().filter_map(|x| { - if let Value::Object(map) = x { - map.get("function").and_then(|f| 
f.get("name")).and_then(|name| name.as_str().map(|s| s.to_string())) - } else { - None - } - }).collect::>(); - let allow_experimental = gcx.read().await.cmdline.experimental; - // and take descriptions of tools from the official source - let tool_descriptions = tool_description_list_from_yaml(at_tools, &turned_on, allow_experimental).await?; - Some(tool_descriptions.into_iter().map(|x|x.into_openai_style()).collect::>()) } else { None }; - // remove "agentic" - if let Some(tools) = &mut tools { - for tool in tools { - if let Some(function) = tool.get_mut("function") { - function.as_object_mut().unwrap().remove("agentic"); - } + let tools_enabled = match tools { + Some(tools) => { + tools.iter().map(|t|t["function"]["name"].as_str().unwrap().to_string()).collect::>() + }, + None => vec![] + }; + + let tools_desc_list = tool_description_list_from_yaml(at_tools, &tools_enabled, allow_experimental).await?; + let tools_filtered = tools_desc_list.iter().filter(|t|tools_enabled.contains(&t.name)).cloned().collect::>(); + + if !tools_filtered.is_empty() { + if self.endpoint_style == "anthropic" { + big_json["tools"] = serde_json::json!(tools_filtered.iter().map(|t|t.clone().into_anthropic_style()).collect::>()); + } else { + big_json["tools"] = serde_json::json!(tools_filtered.iter().map(|t|t.clone().into_openai_style(false)).collect::>()); + big_json["tool_choice"] = serde_json::json!(self.post.tool_choice); } } - big_json["tools"] = json!(tools); - big_json["tool_choice"] = json!(self.post.tool_choice); if DEBUG { - info!("PASSTHROUGH TOOLS ENABLED CNT: {:?}", tools.unwrap_or(vec![]).len()); + info!("PASSTHROUGH TOOLS ENABLED CNT: {:?}", tools.unwrap_or(&vec![]).len()); } } else { if DEBUG { @@ -232,3 +235,28 @@ impl ScratchpadAbstract for ChatPassthrough { })) } } + +// for anthropic: +// tool answers must be located in the same message.content (if tools executed in parallel) +fn format_messages_anthropic(messages: Vec) -> Vec { + let mut res: Vec = vec![]; + for m in messages { + match m.get("content") { + Some(Value::Array(cont)) => { + if let Some(prev_el) = res.last_mut() { + if let Some(Value::Array(prev_cont)) = prev_el.get_mut("content") { + if cont.iter().any(|c| c.get("type") == Some(&Value::String("tool_result".to_string()))) + && prev_cont.iter().any(|p| p.get("type") == Some(&Value::String("tool_result".to_string()))) + { + prev_cont.extend(cont.iter().cloned()); + continue; + } + } + } + res.push(m); + } + _ => res.push(m), + } + } + res +} diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 148f6d59a..853bbf74e 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -17,7 +17,7 @@ pub fn limit_messages_history( let mut message_token_count: Vec = vec![0; messages.len()]; let mut message_take: Vec = vec![false; messages.len()]; for (i, msg) in messages.iter().enumerate() { - let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), &None)?; + let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), "openai")?; message_token_count[i] = tcnt; if i==0 && msg.role == "system" { message_take[i] = true; diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index d04b5e9de..9e686988d 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -81,7 +81,7 @@ pub async fn create_chat_scratchpad( supports_clicks: bool, ) -> Result, String> { let mut result: Box; - let tokenizer_arc = cached_tokenizers::cached_tokenizer(caps, global_context.clone(), 
model_name_for_tokenizer).await?; + let tokenizer_arc = cached_tokenizers::cached_tokenizer(caps.clone(), global_context.clone(), model_name_for_tokenizer).await?; if scratchpad_name == "CHAT-GENERIC" { result = Box::new(chat_generic::GenericChatScratchpad::new( tokenizer_arc.clone(), post, messages, allow_at @@ -91,9 +91,15 @@ pub async fn create_chat_scratchpad( tokenizer_arc.clone(), post, messages, allow_at )); } else if scratchpad_name == "PASSTHROUGH" { + let style = caps.read().unwrap().endpoint_style.clone(); + let style = match style.as_str() { + "hf" => "hf", + "anthropic" => "anthropic", + _ => "openai" + }; post.stream = Some(true); // this should be passed from the request result = Box::new(chat_passthrough::ChatPassthrough::new( - tokenizer_arc.clone(), post, messages, allow_at, supports_tools, supports_clicks + tokenizer_arc.clone(), post, messages, allow_at, supports_tools, supports_clicks, style )); } else { return Err(format!("This rust binary doesn't have chat scratchpad \"{}\" compiled in", scratchpad_name)); diff --git a/src/scratchpads/multimodality.rs b/src/scratchpads/multimodality.rs index 0a2c9107a..ba48a642f 100644 --- a/src/scratchpads/multimodality.rs +++ b/src/scratchpads/multimodality.rs @@ -2,8 +2,8 @@ use serde::{Deserialize, Deserializer, Serialize}; use std::sync::{Arc, RwLock, RwLockReadGuard}; use serde_json::{json, Value}; use tokenizers::Tokenizer; -use crate::call_validation::{ChatContent, ChatMessage, ChatToolCall}; -use crate::scratchpads::scratchpad_utils::{calculate_image_tokens_openai, count_tokens as count_tokens_simple_text, image_reader_from_b64string, parse_image_b64_from_image_url_openai}; +use crate::call_validation::{ChatContent, ChatMessage, ChatToolCall, ChatToolFunction}; +use crate::scratchpads::scratchpad_utils::{calculate_image_tokens_anthropic, calculate_image_tokens_openai, count_tokens as count_tokens_simple_text, image_reader_from_b64string, parse_image_b64_from_image_url_openai}; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] @@ -32,28 +32,70 @@ impl MultimodalElement { self.m_type.starts_with("image/") } - pub fn from_openai_image(openai_image: MultimodalElementImageOpenAI) -> Result { - let (image_type, _, image_content) = parse_image_b64_from_image_url_openai(&openai_image.image_url.url) - .ok_or(format!("Failed to parse image URL: {}", openai_image.image_url.url))?; - MultimodalElement::new(image_type, image_content) + pub fn from_openai_image(image: MultimodalElementImageOpenAI) -> Result { + let (m_type, _, m_content) = parse_image_b64_from_image_url_openai(&image.image_url.url) + .ok_or(format!("Failed to parse image URL: {}", image.image_url.url))?; + MultimodalElement::new(m_type, m_content) } - pub fn from_openai_text(openai_text: MultimodalElementTextOpenAI) -> Result { + pub fn from_text(openai_text: MultimodalElementText) -> Result { MultimodalElement::new("text".to_string(), openai_text.text) } - pub fn to_orig(&self, style: &Option) -> ChatMultimodalElement { - let style = style.clone().unwrap_or("openai".to_string()); - match style.as_str() { + pub fn from_anthropic_image(image: MultimodalElementImageAnthropic) -> Result { + let m_type = format!("image/{}", image.source.media_type); + let m_content = image.source.data; + MultimodalElement::new(m_type, m_content) + } + + pub fn from_anthropic_tool_use(el: MultimodalElementToolUseAnthropic) -> ChatMessage { + ChatMessage { + role: "assistant".to_string(), + content: ChatContent::SimpleText("".to_string()), + tool_calls: Some(vec![ChatToolCall 
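+            // An Anthropic tool_use block, roughly
+            //   {"type": "tool_use", "id": "toolu_...", "name": "...", "input": {...}}
+            // maps onto an OpenAI-style tool call; `input` is re-serialized
+            // into the `arguments` JSON string: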
{ + id: el.id.clone(), + tool_type: "function".to_string(), + function: ChatToolFunction { + arguments: el.input.to_string(), + name: el.name.clone(), + } + } + ]), + tool_call_id: "".to_string(), + usage: None, + } + } + + pub fn from_anthropic_tool_result(el: MultimodalElementToolResultAnthropic) -> ChatMessage { + ChatMessage { + role: "tool".to_string(), + content: ChatContent::SimpleText("".to_string()), + tool_calls: None, + tool_call_id: el.tool_use_id.clone(), + usage: None, + } + } + + pub fn to_orig(&self, style: &str) -> ChatMultimodalElement { + match style { "openai" => { if self.is_text() { - self.to_openai_text() + self.to_text() } else if self.is_image() { self.to_openai_image() } else { unreachable!() } }, + "anthropic" => { + if self.is_text() { + self.to_text() + } else if self.is_image() { + self.to_anthropic_image() + } else { + unreachable!() + } + } _ => unreachable!() } } @@ -69,27 +111,40 @@ impl MultimodalElement { }) } - fn to_openai_text(&self) -> ChatMultimodalElement { - ChatMultimodalElement::MultimodalElementTextOpenAI(MultimodalElementTextOpenAI { + fn to_anthropic_image(&self) -> ChatMultimodalElement { + ChatMultimodalElement::MultimodalElementImageAnthropic(MultimodalElementImageAnthropic { + content_type: "image".to_string(), + source: MultimodalElementImageAnthropicSource { + content_type: "base64".to_string(), + media_type: self.m_type.clone(), + data: self.m_content.clone(), + }, + }) + } + + fn to_text(&self) -> ChatMultimodalElement { + ChatMultimodalElement::MultimodalElementText(MultimodalElementText { content_type: "text".to_string(), text: self.m_content.clone(), }) } - pub fn count_tokens(&self, tokenizer: Option<&RwLockReadGuard>, style: &Option) -> Result { + pub fn count_tokens(&self, tokenizer: Option<&RwLockReadGuard>, style: &str) -> Result { if self.is_text() { if let Some(tokenizer) = tokenizer { Ok(count_tokens_simple_text(&tokenizer, &self.m_content) as i32) } else { - return Err("count_tokens() received no tokenizer".to_string()); + Err("count_tokens() received no tokenizer".to_string()) } } else if self.is_image() { - let style = style.clone().unwrap_or("openai".to_string()); - match style.as_str() { + match style { "openai" => { calculate_image_tokens_openai(&self.m_content, "high") }, - _ => unreachable!(), + "anthropic" => { + calculate_image_tokens_anthropic(&self.m_content) + }, + _ => unreachable!() } } else { unreachable!() @@ -98,7 +153,7 @@ impl MultimodalElement { } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] -pub struct MultimodalElementTextOpenAI { +pub struct MultimodalElementText { #[serde(rename = "type")] pub content_type: String, pub text: String, @@ -111,6 +166,38 @@ pub struct MultimodalElementImageOpenAI { pub image_url: MultimodalElementImageOpenAIImageURL, } +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +pub struct MultimodalElementImageAnthropicSource { + #[serde(rename = "type")] + pub content_type: String, + pub media_type: String, + pub data: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +pub struct MultimodalElementImageAnthropic { + #[serde(rename = "type")] + pub content_type: String, + pub source: MultimodalElementImageAnthropicSource, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +pub struct MultimodalElementToolResultAnthropic { + #[serde(rename = "type")] + pub content_type: String, // type="tool_result" + pub tool_use_id: String, + pub content: ChatContentRaw, +} + +#[derive(Debug, 
Serialize, Deserialize, Clone, PartialEq, Default)] +pub struct MultimodalElementToolUseAnthropic { + #[serde(rename = "type")] + pub content_type: String, // type="tool_use" + pub id: String, + pub name: String, + pub input: Value, +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] pub struct MultimodalElementImageOpenAIImageURL { pub url: String, @@ -122,38 +209,90 @@ fn default_detail() -> String { "high".to_string() } +#[derive(Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum AnthropicInputElement { + MultimodalElementText(MultimodalElementText), + MultimodalElementToolUseAnthropic(MultimodalElementToolUseAnthropic), +} + +pub fn split_anthropic_input_elements(els: Vec) -> (Vec, Vec) { + let mut text_elements = Vec::new(); + let mut tool_use_elements = Vec::new(); + + for el in els { + match el { + AnthropicInputElement::MultimodalElementText(text_el) => text_elements.push(text_el), + AnthropicInputElement::MultimodalElementToolUseAnthropic(tool_use_el) => tool_use_elements.push(tool_use_el), + } + } + + (text_elements, tool_use_elements) +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] // tries to deserialize each enum variant in order pub enum ChatMultimodalElement { - MultimodalElementTextOpenAI(MultimodalElementTextOpenAI), + MultimodalElementText(MultimodalElementText), MultimodalElementImageOpenAI(MultimodalElementImageOpenAI), + MultimodalElementToolUseAnthropic(MultimodalElementToolUseAnthropic), + MultimodalElementToolResultAnthropic(MultimodalElementToolResultAnthropic), + MultimodalElementImageAnthropic(MultimodalElementImageAnthropic), + MultimodalElement(MultimodalElement), } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] pub enum ChatContentRaw { SimpleText(String), Multimodal(Vec), } +impl Default for ChatContentRaw { + fn default() -> Self { + ChatContentRaw::SimpleText(String::new()) + } +} + impl ChatContentRaw { - pub fn to_internal_format(&self) -> Result { + pub fn to_internal_format(&self) -> Result<(ChatContent, Vec), String> { match self { - ChatContentRaw::SimpleText(text) => Ok(ChatContent::SimpleText(text.clone())), + ChatContentRaw::SimpleText(text) => Ok((ChatContent::SimpleText(text.clone()), vec![])), ChatContentRaw::Multimodal(elements) => { - let internal_elements: Result, String> = elements.iter() - .map(|el| match el { - ChatMultimodalElement::MultimodalElementTextOpenAI(text_el) => { - MultimodalElement::from_openai_text(text_el.clone()) + let mut internal_elements = Vec::new(); + let mut chat_messages: Vec = vec![]; + + for el in elements { + match el { + ChatMultimodalElement::MultimodalElementText(text_el) => { + let element = MultimodalElement::from_text(text_el.clone())?; + internal_elements.push(element); }, ChatMultimodalElement::MultimodalElementImageOpenAI(image_el) => { - MultimodalElement::from_openai_image(image_el.clone()) + let element = MultimodalElement::from_openai_image(image_el.clone())?; + internal_elements.push(element); }, - ChatMultimodalElement::MultimodalElement(el) => Ok(el.clone()), - }) - .collect(); - internal_elements.map(ChatContent::Multimodal) + ChatMultimodalElement::MultimodalElementToolUseAnthropic(el) => { + let message = MultimodalElement::from_anthropic_tool_use(el.clone()); + chat_messages.push(message); + }, + ChatMultimodalElement::MultimodalElementToolResultAnthropic(el) => { + let message = MultimodalElement::from_anthropic_tool_result(el.clone()); + 
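+                        // tool_use / tool_result blocks have no ChatContent
+                        // representation, so they are lifted out as standalone
+                        // ChatMessage objects, returned next to the content: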
chat_messages.push(message); + }, + ChatMultimodalElement::MultimodalElementImageAnthropic(el) => { + let element = MultimodalElement::from_anthropic_image(el.clone())?; + internal_elements.push(element); + } + + ChatMultimodalElement::MultimodalElement(el) => { + internal_elements.push(el.clone()); + }, + } + } + + Ok((ChatContent::Multimodal(internal_elements), chat_messages)) } } } @@ -171,7 +310,7 @@ impl ChatContent { } } - pub fn size_estimate(&self, tokenizer: Arc>, style: &Option) -> usize { + pub fn size_estimate(&self, tokenizer: Arc>, style: &str) -> usize { match self { ChatContent::SimpleText(text) => text.len(), ChatContent::Multimodal(_elements) => { @@ -181,7 +320,7 @@ impl ChatContent { } } - pub fn count_tokens(&self, tokenizer: Arc>, style: &Option) -> Result { + pub fn count_tokens(&self, tokenizer: Arc>, style: &str) -> Result { let tokenizer_lock = tokenizer.read().unwrap(); match self { ChatContent::SimpleText(text) => Ok(count_tokens_simple_text(&tokenizer_lock, text) as i32), @@ -192,7 +331,7 @@ impl ChatContent { } } - pub fn into_raw(&self, style: &Option) -> ChatContentRaw { + pub fn into_raw(&self, style: &str) -> ChatContentRaw { match self { ChatContent::SimpleText(text) => ChatContentRaw::SimpleText(text.clone()), ChatContent::Multimodal(elements) => { @@ -208,7 +347,7 @@ impl ChatContent { pub fn chat_content_raw_from_value(value: Value) -> Result { fn validate_multimodal_element(element: &ChatMultimodalElement) -> Result<(), String> { match element { - ChatMultimodalElement::MultimodalElementTextOpenAI(el) => { + ChatMultimodalElement::MultimodalElementText(el) => { if el.content_type != "text" { return Err("Invalid multimodal element: type must be `text`".to_string()); } @@ -220,8 +359,12 @@ pub fn chat_content_raw_from_value(value: Value) -> Result {} + }, + ChatMultimodalElement::MultimodalElementToolUseAnthropic(_el) => {}, + ChatMultimodalElement::MultimodalElementToolResultAnthropic(_el) => {}, + ChatMultimodalElement::MultimodalElementImageAnthropic(_el) => {}, + + ChatMultimodalElement::MultimodalElement(_el) => {}, }; Ok(()) } @@ -232,8 +375,8 @@ pub fn chat_content_raw_from_value(value: Value) -> Result { let mut elements = vec![]; for (idx, item) in array.into_iter().enumerate() { - let element: ChatMultimodalElement = serde_json::from_value(item) - .map_err(|e| format!("Error deserializing element at index {}: {}", idx, e))?; + let element: ChatMultimodalElement = serde_json::from_value(item.clone()) + .map_err(|e| format!("Error deserializing element at index {}:\n{:#?}\n\nError: {}", idx, item, e))?; validate_multimodal_element(&element) .map_err(|e| format!("Validation error for element at index {}: {}", idx, e))?; elements.push(element); @@ -254,50 +397,123 @@ impl ChatMessage { } } - pub fn into_value(&self, style: &Option) -> Value { + pub fn into_value(&self, style: &str) -> Value { let mut dict = serde_json::Map::new(); let chat_content_raw = self.content.into_raw(style); dict.insert("role".to_string(), Value::String(self.role.clone())); dict.insert("content".to_string(), json!(chat_content_raw)); - dict.insert("tool_calls".to_string(), json!(self.tool_calls.clone())); - dict.insert("tool_call_id".to_string(), Value::String(self.tool_call_id.clone())); + + match style { + "openai" => { + dict.insert("tool_calls".to_string(), json!(self.tool_calls.clone())); + dict.insert("tool_call_id".to_string(), Value::String(self.tool_call_id.clone())); + }, + "anthropic" => { + if self.role == "tool" { + let content = vec![json!({ + "type": 
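+                    // Anthropic has no "tool" role: the result is re-wrapped as
+                    // a user message whose content is a tool_result block, e.g.
+                    //   {"role": "user", "content": [{"type": "tool_result",
+                    //    "tool_use_id": "toolu_...", "content": "..."}]}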
"tool_result", + "tool_use_id": self.tool_call_id.clone(), + "content": self.content.clone().into_raw(style), + })]; + dict.insert("role".to_string(), Value::String("user".to_string())); + dict.insert("content".to_string(), Value::Array(content)); + } + + if self.role == "assistant" && self.tool_calls.is_some() { + let tool_calls = self.tool_calls.clone().unwrap_or_default(); + let content = tool_calls.iter().map(|call| { + let input_map: serde_json::Map = serde_json::from_str(&call.function.arguments) + .unwrap_or_else(|_| serde_json::Map::new()); + json!({ + "type": "tool_use", + "id": call.id.clone(), + "name": call.function.name.clone(), + "input": input_map, + }) + }).collect::>(); + dict.insert("content".to_string(), Value::Array(content)); + } + }, + _ => unreachable!(), + } Value::Object(dict) } + + pub fn from_anthropic_input(els: Vec, role: &str) -> Self { + let (text_elements, tool_use_elements) = split_anthropic_input_elements(els); + let content = text_elements.iter().map(|x|x.text.clone()).collect::>().join("\n\n"); + + if !tool_use_elements.is_empty() { + ChatMessage { + role: role.to_string(), + content: ChatContent::SimpleText(content), + tool_calls: Some(tool_use_elements.iter().map(|m| ChatToolCall { + id: m.id.clone(), + function: ChatToolFunction { + arguments: m.input.to_string(), + name: m.name.clone() + }, + tool_type: "function".to_string(), + }).collect::>()), + tool_call_id: "".to_string(), + usage: None, + } + } else { + ChatMessage::new(role.to_string(), content) + } + } } -impl<'de> Deserialize<'de> for ChatMessage { - fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { - let value: Value = Deserialize::deserialize(deserializer)?; - let role = value.get("role") - .and_then(|s| s.as_str()) - .ok_or_else(|| serde::de::Error::missing_field("role"))? - .to_string(); - - let content = match value.get("content") { - Some(content_value) => { - let content_raw: ChatContentRaw = chat_content_raw_from_value(content_value.clone()) - .map_err(|e| serde::de::Error::custom(e))?; - content_raw.to_internal_format() - .map_err(|e| serde::de::Error::custom(e))? - }, - None => ChatContent::SimpleText(String::new()), - }; +#[derive(Debug, Serialize, Clone, PartialEq, Default)] +pub struct ChatMessages(pub Vec); - let tool_calls: Option> = value.get("tool_calls") - .and_then(|v| v.as_array()) - .map(|v| v.iter().map(|v| serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)).collect::, _>>()) - .transpose()?; - let tool_call_id: Option = value.get("tool_call_id") - .and_then(|s| s.as_str()).map(|s| s.to_string()); +impl<'de> Deserialize<'de> for ChatMessages { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value: Value = Deserialize::deserialize(deserializer)?; + let mut messages = Vec::new(); + + if let Value::Array(array) = value { + for item in array { + let role = item.get("role") + .and_then(|s| s.as_str()) + .ok_or_else(|| serde::de::Error::missing_field("role"))? + .to_string(); + + let (content, chat_messages) = match item.get("content") { + Some(content_value) => { + let content_raw: ChatContentRaw = chat_content_raw_from_value(content_value.clone()) + .map_err(|e| serde::de::Error::custom(e))?; + content_raw.to_internal_format() + .map_err(|e| serde::de::Error::custom(e))? 
+ }, + None => (ChatContent::SimpleText(String::new()), vec![]), + }; + + let tool_calls: Option> = item.get("tool_calls") + .and_then(|v| v.as_array()) + .map(|v| v.iter().map(|v| serde_json::from_value(v.clone()).map_err(serde::de::Error::custom)).collect::, _>>()) + .transpose()?; + let tool_call_id: Option = item.get("tool_call_id") + .and_then(|s| s.as_str()).map(|s| s.to_string()); + + messages.push(ChatMessage { + role, + content, + tool_calls, + tool_call_id: tool_call_id.unwrap_or_default(), + ..Default::default() + }); + messages.extend(chat_messages); + } + } else { + return Err(serde::de::Error::custom("Expected an array of chat messages")); + } - Ok(ChatMessage { - role, - content, - tool_calls, - tool_call_id: tool_call_id.unwrap_or_default(), - ..Default::default() - }) + Ok(ChatMessages(messages)) } } diff --git a/src/scratchpads/passthrough_convert_messages.rs b/src/scratchpads/passthrough_convert_messages.rs index 1965763bc..0d6f2366b 100644 --- a/src/scratchpads/passthrough_convert_messages.rs +++ b/src/scratchpads/passthrough_convert_messages.rs @@ -3,7 +3,7 @@ use tracing::{error, warn}; use crate::call_validation::{ChatContent, ChatMessage, ContextFile}; -pub fn convert_messages_to_openai_format(messages: Vec, style: &Option) -> Vec { +pub fn convert_messages_to_openai_format(messages: Vec, style: &str) -> Vec { let mut results = vec![]; let mut delay_images = vec![]; @@ -179,7 +179,7 @@ mod tests { let roles_out_expected = expected_output.iter().map(|x| x.get("role").unwrap().as_str().unwrap().to_string()).collect::>(); - let style = Some("openai".to_string()); + let style = "openai"; let output = convert_messages_to_openai_format(messages, &style); // println!("OUTPUT: {:#?}", output); diff --git a/src/scratchpads/scratchpad_utils.rs b/src/scratchpads/scratchpad_utils.rs index b20420d6c..a6833bd25 100644 --- a/src/scratchpads/scratchpad_utils.rs +++ b/src/scratchpads/scratchpad_utils.rs @@ -99,6 +99,15 @@ pub fn calculate_image_tokens_openai(image_string: &String, detail: &str) -> Res } } +pub fn calculate_image_tokens_anthropic(image_string: &String) -> Result { + let reader = image_reader_from_b64string(&image_string).map_err(|_| "Failed to read image".to_string())?; + let (width, height) = reader.into_dimensions().map_err(|_| "Failed to get dimensions".to_string())?; + width.checked_mul(height) + .and_then(|area| area.checked_div(750)) + .ok_or_else(|| "Overflow or division by zero occurred".to_string()) + .map(|x|x as i32) +} + // cargo test scratchpads::scratchpad_utils #[cfg(test)] mod tests { diff --git a/src/subchat.rs b/src/subchat.rs index 2ca6b739e..79b386d04 100644 --- a/src/subchat.rs +++ b/src/subchat.rs @@ -185,7 +185,7 @@ async fn chat_interaction_non_stream( ) }; - let content = chat_content_raw_from_value(content_value).and_then(|c|c.to_internal_format()) + let (content, _chat_messages) = chat_content_raw_from_value(content_value).and_then(|c|c.to_internal_format()) .map_err(|e| format!("error parsing model's output: {}", e))?; let mut ch_results = vec![]; @@ -267,7 +267,7 @@ pub async fn subchat_single( error!("Error loading compiled_in_tools: {:?}", e); vec![] }); - let tools = tools_desclist.into_iter().map(|x|x.into_openai_style()).collect::>(); + let tools = tools_desclist.into_iter().map(|x|x.into_openai_style(true)).collect::>(); info!("tools_subset {:?}", tools_subset); info!("tools_turned_on_by_cmdline_set {:?}", tools_turned_on_by_cmdline_set); info!("tools_on_intersection {:?}", tools_on_intersection); diff --git 
a/src/tools/tool_patch_aux/model_based_edit/model_execution.rs b/src/tools/tool_patch_aux/model_based_edit/model_execution.rs index f9007ef4f..77aea43d8 100644 --- a/src/tools/tool_patch_aux/model_based_edit/model_execution.rs +++ b/src/tools/tool_patch_aux/model_based_edit/model_execution.rs @@ -76,7 +76,7 @@ async fn make_chat_history( } let tokens = messages.iter().map(|x| - 3 + x.content.count_tokens(tokenizer_arc.clone(), &None).unwrap_or(0) as usize + 3 + x.content.count_tokens(tokenizer_arc.clone(), "openai").unwrap_or(0) as usize ).sum::(); if tokens > max_tokens { return Err(format!( @@ -118,7 +118,7 @@ async fn make_follow_up_chat_history( } let tokens = messages.iter().map(|x| - 3 + x.content.count_tokens(tokenizer_arc.clone(), &None).unwrap_or(0) as usize + 3 + x.content.count_tokens(tokenizer_arc.clone(), "openai").unwrap_or(0) as usize ).sum::(); if tokens > max_tokens { return Err(format!( diff --git a/src/tools/tools_description.rs b/src/tools/tools_description.rs index a2976d591..592b60dab 100644 --- a/src/tools/tools_description.rs +++ b/src/tools/tools_description.rs @@ -13,7 +13,6 @@ use crate::call_validation::{ChatUsage, ContextEnum}; use crate::global_context::GlobalContext; use crate::integrations::integr_abstract::IntegrationConfirmation; use crate::tools::tools_execute::{command_should_be_confirmed_by_user, command_should_be_denied}; -// use crate::integrations::docker::integr_docker::ToolDocker; #[derive(Clone, Debug)] @@ -459,60 +458,71 @@ pub struct ToolDesc { pub parameters_required: Vec, } -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug, Default)] pub struct ToolParam { pub name: String, #[serde(rename = "type", default = "default_param_type")] pub param_type: String, pub description: String, + #[serde(rename = "enum", default)] + pub param_enum: Vec, // anthropic; learn more https://docs.anthropic.com/en/docs/build-with-claude/tool-use#example-simple-tool-definition } fn default_param_type() -> String { "string".to_string() } -pub fn make_openai_tool_value( - name: String, - agentic: bool, - description: String, - parameters_required: Vec, - parameters: Vec, -) -> Value { - let params_properties = parameters.iter().map(|param| { + +fn map_parameters_to_properties(parameters: Vec, style: &str) -> serde_json::Map { + parameters.iter().map(|param| { + let mut param_json = json!({ + "type": param.param_type, + "description": param.description + }); + + if style == "anthropic" && !param.param_enum.is_empty() { + param_json["enum"] = json!(param.param_enum); + } + ( param.name.clone(), - json!({ - "type": param.param_type, - "description": param.description - }) + param_json ) - }).collect::>(); + }).collect::>() +} - let function_json = json!({ +impl ToolDesc { + pub fn into_openai_style(self, internal_style: bool) -> Value { + let mut function_json = json!({ "type": "function", "function": { - "name": name, - "agentic": agentic, // this field is not OpenAI's - "description": description, + "name": self.name, + "description": self.description, "parameters": { "type": "object", - "properties": params_properties, - "required": parameters_required + "properties": map_parameters_to_properties(self.parameters, "openai"), + "required": self.parameters_required } } }); - function_json -} -impl ToolDesc { - pub fn into_openai_style(self) -> Value { - make_openai_tool_value( - self.name, - self.agentic, - self.description, - self.parameters_required, - self.parameters, - ) + if internal_style { + 
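+            // "agentic" is a refact-internal field, not part of the OpenAI
+            // function-calling schema, so it is only attached when the caller
+            // asks for the internal representation: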
function_json["function"]["agentic"] = json!(self.agentic); + } + + function_json + } + + pub fn into_anthropic_style(self) -> Value { + json!({ + "name": self.name, + "description": self.description, + "input_schema": { + "type": "object", + "properties": map_parameters_to_properties(self.parameters, "anthropic"), + "required": self.parameters_required + } + }) } } diff --git a/src/tools/tools_execute.rs b/src/tools/tools_execute.rs index 643869233..a6946a5ce 100644 --- a/src/tools/tools_execute.rs +++ b/src/tools/tools_execute.rs @@ -54,7 +54,7 @@ pub async fn run_tools_remotely( maxgen: usize, original_messages: &Vec, stream_back_to_user: &mut HasRagResults, - style: &Option, + style: &str, ) -> Result<(Vec, bool), String> { let (n_ctx, subchat_tool_parameters, postprocess_parameters, gcx, chat_id) = { let ccx_locked = ccx.lock().await; @@ -75,7 +75,7 @@ pub async fn run_tools_remotely( postprocess_parameters, model_name: model_name.to_string(), chat_id: chat_id.clone(), - style: style.clone(), + style: style.to_string(), }; let port = docker_container_get_host_lsp_port_to_connect(gcx.clone(), &chat_id).await?; @@ -101,7 +101,7 @@ pub async fn run_tools_locally( maxgen: usize, original_messages: &Vec, stream_back_to_user: &mut HasRagResults, - style: &Option, + style: &str, ) -> Result<(Vec, bool), String> { let (new_messages, tools_runned) = run_tools( // todo: fix typo "runned" ccx, tools, tokenizer, maxgen, original_messages, style @@ -122,7 +122,7 @@ pub async fn run_tools( tokenizer: Arc>, maxgen: usize, original_messages: &Vec, - style: &Option, + style: &str, ) -> Result<(Vec, bool), String> { let n_ctx = ccx.lock().await.n_ctx; let reserve_for_context = max_tokens_for_rag_chat(n_ctx, maxgen); @@ -161,16 +161,21 @@ pub async fn run_tools( } }; - let args = match serde_json::from_str::>(&t_call.function.arguments) { - Ok(args) => args, - Err(e) => { - let tool_failed_message = tool_answer( - format!("Tool use: couldn't parse arguments: {}. Error:\n{}", t_call.function.arguments, e), t_call.id.to_string() - ); - generated_tool.push(tool_failed_message); - continue; + let args = if t_call.function.arguments.trim().is_empty() { + HashMap::new() + } else { + match serde_json::from_str::>(&t_call.function.arguments) { + Ok(args) => args, + Err(e) => { + let tool_failed_message = tool_answer( + format!("Tool use: couldn't parse arguments: {}. 
Error:\n{}", t_call.function.arguments, e), t_call.id.to_string() + ); + generated_tool.push(tool_failed_message); + continue; + } } }; + info!("tool use {}({:?})", &t_call.function.name, args); { @@ -263,7 +268,7 @@ async fn pp_run_tools( context_files_for_pp: &mut Vec, tokens_for_rag: usize, tokenizer: Arc>, - style: &Option, + style: &str, ) -> (Vec, Vec) { let mut generated_tool = generated_tool.to_vec(); let mut generated_other = generated_other.to_vec(); diff --git a/tests/test13_vision.py b/tests/test13_vision.py index a3054deb2..bf0a24d7a 100644 --- a/tests/test13_vision.py +++ b/tests/test13_vision.py @@ -24,7 +24,7 @@ def encode_image(image_path): def chat_request(msgs, max_tokens: int = 200): url = "http://localhost:8001/v1/chat" payload = { - "model": "gpt-4o", + "model": "claude-3-5-sonnet-20241022", "messages": msgs, "stream": False, "max_tokens": max_tokens, @@ -108,6 +108,6 @@ def test_multiple_images_sending(): if __name__ == "__main__": - test_format() +# test_format() test_image_sending() test_multiple_images_sending() From 716eb37acfc3b48526e6bf95da251e1c2c4e2dc6 Mon Sep 17 00:00:00 2001 From: V4LER11 Date: Thu, 19 Dec 2024 17:49:48 +0000 Subject: [PATCH 2/3] fixes after rebase --- src/integrations/integr_chrome.rs | 1 + src/integrations/integr_cmdline_service.rs | 1 + src/restream.rs | 295 ++++++++++++--------- src/scratchpads/code_completion_replace.rs | 2 +- src/scratchpads/mod.rs | 13 +- 5 files changed, 176 insertions(+), 136 deletions(-) diff --git a/src/integrations/integr_chrome.rs b/src/integrations/integr_chrome.rs index 45f131f4c..c5fea79ed 100644 --- a/src/integrations/integr_chrome.rs +++ b/src/integrations/integr_chrome.rs @@ -293,6 +293,7 @@ impl Tool for ToolChrome { name: "commands".to_string(), param_type: "string".to_string(), description, + ..Default::default() }], parameters_required: vec!["commands".to_string()], } diff --git a/src/integrations/integr_cmdline_service.rs b/src/integrations/integr_cmdline_service.rs index 4336210cb..4cbb687e4 100644 --- a/src/integrations/integr_cmdline_service.rs +++ b/src/integrations/integr_cmdline_service.rs @@ -318,6 +318,7 @@ impl Tool for ToolService { name: "action".to_string(), param_type: "string".to_string(), description: "Action to perform: start, restart, stop, status".to_string(), + ..Default::default() }); let parameters_required = self.cfg.parameters_required.clone().unwrap_or_else(|| { diff --git a/src/restream.rs b/src/restream.rs index d0d6d9ad5..802d3544d 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -21,6 +21,7 @@ use crate::forward_to_anthropic_endpoint::{forward_to_anthropic_endpoint, forwar use crate::forward_to_hf_endpoint::{forward_to_hf_style_endpoint, forward_to_hf_style_endpoint_streaming}; use crate::forward_to_openai_endpoint::{forward_to_openai_style_endpoint, forward_to_openai_style_endpoint_streaming}; use crate::scratchpads::multimodality::AnthropicInputElement; +use crate::scratchpads::resolve_endpoint_style; async fn _get_endpoint_and_stuff_from_model_name( @@ -69,7 +70,7 @@ async fn _get_endpoint_and_stuff_from_model_name( ( api_key, endpoint_template, - endpoint_style, + resolve_endpoint_style(&endpoint_style).to_string(), endpoint_chat_passthrough, ) } @@ -106,66 +107,17 @@ async fn get_model_says( } } -pub async fn scratchpad_interaction_not_stream_json( - ccx: Arc>, +fn scratchpad_result_not_stream( + model_says: &mut serde_json::Value, scratchpad: &mut Box, - scope: String, - prompt: &str, - model_name: String, - parameters: &SamplingParameters, // includes n 
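+    // the response handling below is factored out of
+    // scratchpad_interaction_not_stream_json so the same code can digest both
+    // OpenAI-style and Anthropic-style replies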
only_deterministic_messages: bool, ) -> Result { - let t2 = std::time::SystemTime::now(); - let gcx = ccx.lock().await.global_context.clone(); - let (client, caps, tele_storage, slowdown_arc) = { - let gcx_locked = gcx.write().await; - let caps = gcx_locked.caps.clone() - .ok_or(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, "No caps available".to_string()))?; - ( - gcx_locked.http_client.clone(), - caps, - gcx_locked.telemetry.clone(), - gcx_locked.http_client_slowdown.clone() - ) - }; - let ( - bearer, - endpoint_template, - endpoint_style, - endpoint_chat_passthrough, - ) = _get_endpoint_and_stuff_from_model_name(gcx.clone(), caps.clone(), model_name.clone()).await; - - let mut save_url: String = String::new(); - let _ = slowdown_arc.acquire().await; - - let mut model_says = get_model_says( - only_deterministic_messages, endpoint_style, bearer, &model_name, prompt, &client, &endpoint_template, &endpoint_chat_passthrough, ¶meters, &mut save_url - ).await.map_err(|e|{ - tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( - save_url.clone(), - scope.clone(), - false, - e.to_string(), - )); - ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e)) - })?; - - tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( - save_url.clone(), - scope.clone(), - true, - "".to_string(), - )); - info!("forward to endpoint {:.2}ms, url was {}", t2.elapsed().unwrap().as_millis() as f64, save_url); - crate::global_context::look_for_piggyback_fields(gcx.clone(), &model_says).await; - - let scratchpad_result: Result; - if only_deterministic_messages { + let scratchpad_result = if only_deterministic_messages { if let Ok(det_msgs) = scratchpad.response_spontaneous() { model_says["deterministic_messages"] = json!(det_msgs); model_says["choices"] = serde_json::Value::Array(vec![]); } - scratchpad_result = Ok(model_says.clone()); + Ok(model_says.clone()) } else if let Some(hf_arr) = model_says.as_array() { let choices = hf_arr.iter().map(|x| { @@ -178,9 +130,9 @@ pub async fn scratchpad_interaction_not_stream_json( }) }).collect::>(); let finish_reasons = vec![FinishReason::Length; choices.len()]; - scratchpad_result = scratchpad.response_n_choices(choices, finish_reasons); + scratchpad.response_n_choices(choices, finish_reasons) - } else if let Some(oai_choices) = model_says.clone().get("choices") { + } else if let Some(oai_choices) = model_says.get("choices").cloned() { let choice0 = oai_choices.as_array().unwrap().get(0).unwrap(); let finish_reasons = oai_choices.clone().as_array().unwrap().iter().map( |x| FinishReason::from_json_val(x.get("finish_reason").unwrap_or(&json!(""))).unwrap_or_else(|err| { @@ -204,7 +156,7 @@ pub async fn scratchpad_interaction_not_stream_json( } } }).collect::>(); - scratchpad_result = match scratchpad.response_message_n_choices(choices, finish_reasons) { + match scratchpad.response_message_n_choices(choices, finish_reasons) { Ok(res) => Ok(res), Err(err) => { if err == "not implemented" { @@ -214,7 +166,7 @@ pub async fn scratchpad_interaction_not_stream_json( Err(err) } } - }; + } } else { // TODO: restore order using 'index' // for oai_choice in oai_choices.as_array().unwrap() { @@ -229,36 +181,112 @@ pub async fn scratchpad_interaction_not_stream_json( "".to_string() }) }).collect::>(); - scratchpad_result = scratchpad.response_n_choices(choices, finish_reasons); + scratchpad.response_n_choices(choices, finish_reasons) } - } else if let Some(err) = 
model_says.get("error") { - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("{}", err) - )); + } else if let Some(content) = model_says.get("content") { // anthropic style + let content_elements = if let Some(content_arr) = content.as_array() { + content_arr.clone() + } else { + vec![content.clone()] + }; + let multimodal_elements: Vec = serde_json::from_value(serde_json::Value::Array(content_elements)).map_err(|e|{ + ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("Anthropic: failed to deserialize content: {:?}", e)) + })?; + let role = model_says["role"].as_str().unwrap(); + let chat_message = ChatMessage::from_anthropic_input(multimodal_elements, role); + + let mut message = json!({"role": role, "content": chat_message.content}); + if let Some(t_calls) = chat_message.tool_calls { + message["tool_calls"] = json!(t_calls); + } - } else if let Some(msg) = model_says.get("human_readable_message") { - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("{}", msg) - )); + let mut response = json!({ + "choices": [{ + "index": 0, + "message": message, + // todo: maybe provide real finish reason in else case + "finish_reason": if model_says["stop_reason"] == "end_turn" { "stop" } else { "length" }, + }] + }); + if let Ok(det_msgs) = scratchpad.response_spontaneous() { + response["deterministic_messages"] = json!(det_msgs); + } + Ok(response) + } else if let Some(err) = model_says.get("error") { + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", err))); + } else if let Some(msg) = model_says.get("human_readable_message") { + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", msg))); } else if let Some(msg) = model_says.get("detail") { - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("{}", msg) - )); - + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", msg))); } else { - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("unrecognized response (1): {:?}", model_says)) - ); + return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("unrecognized response (1): {:?}", model_says))); + }; + match scratchpad_result { + Ok(x) => Ok(x), + Err(e) => Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("scratchpad: {}", e))) } +} - if let Err(problem) = scratchpad_result { - return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, - format!("scratchpad: {}", problem)) - ); - } - return Ok(scratchpad_result.unwrap()); +pub async fn scratchpad_interaction_not_stream_json( + ccx: Arc>, + scratchpad: &mut Box, + scope: String, + prompt: &str, + model_name: String, + parameters: &SamplingParameters, // includes n + only_deterministic_messages: bool, +) -> Result { + let t2 = std::time::SystemTime::now(); + let gcx = ccx.lock().await.global_context.clone(); + let (client, caps, tele_storage, slowdown_arc) = { + let gcx_locked = gcx.write().await; + let caps = gcx_locked.caps.clone() + .ok_or(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, "No caps available".to_string()))?; + ( + gcx_locked.http_client.clone(), + caps, + gcx_locked.telemetry.clone(), + gcx_locked.http_client_slowdown.clone() + ) + }; + let ( + bearer, + endpoint_template, + endpoint_style, + endpoint_chat_passthrough, + ) = _get_endpoint_and_stuff_from_model_name(gcx.clone(), caps.clone(), model_name.clone()).await; + + let mut save_url: String = String::new(); + let _ = slowdown_arc.acquire().await; + + let mut 
model_says = get_model_says( + only_deterministic_messages, endpoint_style, bearer, &model_name, prompt, &client, &endpoint_template, &endpoint_chat_passthrough, ¶meters, &mut save_url + ).await.map_err(|e|{ + tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( + save_url.clone(), + scope.clone(), + false, + e.to_string(), + )); + ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e)) + })?; + + tele_storage.write().unwrap().tele_net.push(telemetry_structs::TelemetryNetwork::new( + save_url.clone(), + scope.clone(), + true, + "".to_string(), + )); + info!("forward to endpoint {:.2}ms, url was {}", t2.elapsed().unwrap().as_millis() as f64, save_url); + crate::global_context::look_for_piggyback_fields(gcx.clone(), &model_says).await; + + scratchpad_result_not_stream( + &mut model_says, + scratchpad, + only_deterministic_messages, + ) } pub async fn scratchpad_interaction_not_stream( @@ -466,48 +494,54 @@ pub async fn scratchpad_interaction_stream( } let json = serde_json::from_str::(&message.data).unwrap(); - // for anthropic - match message.event.as_str() { - "message_start" => { - message_template = json!({ - "id": json["message"]["id"], - "object": "chat.completion.chunk", - "model": json["message"]["model"], - // "usage": json["message"]["usage"], todo: implement usage (event: message_delta) - "choices": [ - {"index": 0, "delta": {"role": json["message"]["role"], "content": ""}} - ] - }); - }, - "content_block_start" => { - - }, - "message_stop" => { - finished = true; - break; - }, - "ping" | "content_block_stop" => { - continue; - } - _ => {} - } - crate::global_context::look_for_piggyback_fields(gcx.clone(), &json).await; - - let value = _push_streaming_json_into_scratchpad( - my_scratchpad, - &json, - &mut model_name, - &mut was_correct_output_even_if_error, - ); - - match _push_streaming_json_into_scratchpad( - my_scratchpad, - &json, - &mut model_name, - &mut was_correct_output_even_if_error, - ) { - Ok((mut value, finish_reason)) => { + + let value = match endpoint_style.as_str() { + "anthropic" => { + match message.event.as_str() { + "message_start" => { + message_template = json!({ + "id": json["message"]["id"], + "object": "chat.completion.chunk", + "model": json["message"]["model"], + // "usage": json["message"]["usage"], todo: implement usage (event: message_delta) + "choices": [ + {"index": 0, "delta": {"role": json["message"]["role"], "content": ""}} + ] + }); + }, + "content_block_start" => {}, + "message_stop" => { + last_finish_reason = FinishReason::Stop; + break; + }, + "ping" | "content_block_stop" => { + continue; + }, + _ => {} + } + _push_streaming_json_anthropic( + &json, + &message_template, + message.event.as_str(), + &mut ant_tool_call_index, + ) + }, + // openai, hf + _ => _push_streaming_json_into_scratchpad( + my_scratchpad, + &json, + &mut model_name, + &mut was_correct_output_even_if_error, + ) + }; + + match value { + Ok((value_mb, finish_reason)) => { + let mut value = match value_mb { + Some(v) => v, + None => continue + }; last_finish_reason = finish_reason; try_insert_usage(&mut value); value["created"] = json!(t1.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as f64 / 1000.0); @@ -524,6 +558,7 @@ pub async fn scratchpad_interaction_stream( } }, + Err(err) => { if was_correct_output_even_if_error { // "restream error: Stream ended" @@ -616,9 +651,9 @@ fn _push_streaming_json_anthropic( message_template: &serde_json::Value, event_type: &str, 
ant_tool_call_index: &mut i32, -) -> Result, String> { +) -> Result<(Option, FinishReason), String> { if !message_template.is_object() { - return Ok(None); + return Ok((None, FinishReason::None)); } let mut value = message_template.clone(); @@ -637,7 +672,7 @@ fn _push_streaming_json_anthropic( Err(format!("{}: {}", error_type, error_message)) }, "message_start" => { - Ok(Some(value)) + Ok((Some(value), FinishReason::None)) }, "content_block_start" => { if json["content_block"]["type"] == "tool_use" { @@ -655,9 +690,9 @@ fn _push_streaming_json_anthropic( } }] }); - Ok(Some(value)) + Ok((Some(value), FinishReason::None)) } else { - Ok(None) + Ok((None, FinishReason::None)) } }, _ => { @@ -676,7 +711,7 @@ fn _push_streaming_json_anthropic( }] }); } - Ok(Some(value)) + Ok((Some(value), FinishReason::None)) } } } @@ -686,14 +721,14 @@ fn _push_streaming_json_into_scratchpad( json: &serde_json::Value, model_name: &mut String, was_correct_output_even_if_error: &mut bool, -) -> Result<(serde_json::Value, FinishReason), String> { +) -> Result<(Option, FinishReason), String> { if let Some(token) = json.get("token") { // hf style produces this let text = token.get("text").unwrap_or(&json!("")).as_str().unwrap_or("").to_string(); // TODO: probably we must retrieve the correct `finish_reason` from the json somehow let (mut value, finish_reason) = scratch.response_streaming(text, FinishReason::None)?; value["model"] = json!(model_name.clone()); *was_correct_output_even_if_error |= json.get("generated_text").is_some(); - Ok((value, finish_reason)) + Ok((Some(value), finish_reason)) } else if let Some(choices) = json.get("choices") { // openai style let choice0 = &choices[0]; let mut value: serde_json::Value; @@ -723,7 +758,7 @@ fn _push_streaming_json_into_scratchpad( model_name.clone_from(&model_value.as_str().unwrap_or("").to_string()); } value["model"] = json!(model_name.clone()); - Ok((value, finish_reason)) + Ok((Some(value), finish_reason)) } else if let Some(err) = json.get("error") { Err(format!("{}", err)) } else if let Some(msg) = json.get("human_readable_message") { diff --git a/src/scratchpads/code_completion_replace.rs b/src/scratchpads/code_completion_replace.rs index a6f3068a6..f11b5d825 100644 --- a/src/scratchpads/code_completion_replace.rs +++ b/src/scratchpads/code_completion_replace.rs @@ -1009,7 +1009,7 @@ impl ScratchpadAbstract for CodeCompletionReplacePassthroughScratchpad { }); let json_messages = &serde_json::to_string(&json!({ - "messages": messages.iter().map(|x| { x.into_value(&None) }).collect::>(), + "messages": messages.iter().map(|x| { x.into_value("openai") }).collect::>(), })) .unwrap(); let prompt = format!("PASSTHROUGH {json_messages}").to_string(); diff --git a/src/scratchpads/mod.rs b/src/scratchpads/mod.rs index 9e686988d..b8a262b3f 100644 --- a/src/scratchpads/mod.rs +++ b/src/scratchpads/mod.rs @@ -30,6 +30,13 @@ use crate::cached_tokenizers; fn verify_has_send(_x: &T) {} +pub fn resolve_endpoint_style(endpoint_style: &str) -> &str { + match endpoint_style { + "hf" => "hf", + "anthropic" => "anthropic", + _ => "openai" + } +} pub async fn create_code_completion_scratchpad( global_context: Arc>, @@ -92,11 +99,7 @@ pub async fn create_chat_scratchpad( )); } else if scratchpad_name == "PASSTHROUGH" { let style = caps.read().unwrap().endpoint_style.clone(); - let style = match style.as_str() { - "hf" => "hf", - "anthropic" => "anthropic", - _ => "openai" - }; + let style = resolve_endpoint_style(&style); post.stream = Some(true); // this should be passed from the 
request result = Box::new(chat_passthrough::ChatPassthrough::new( tokenizer_arc.clone(), post, messages, allow_at, supports_tools, supports_clicks, style From 7ede5b17ae834911dbee50dd6cb843c6bfb7a0fd Mon Sep 17 00:00:00 2001 From: V4LER11 Date: Thu, 19 Dec 2024 18:05:16 +0000 Subject: [PATCH 3/3] limit_messages_history uses style --- src/scratchpads/chat_generic.rs | 2 +- src/scratchpads/chat_llama2.rs | 2 +- src/scratchpads/chat_passthrough.rs | 29 ++----------------- src/scratchpads/chat_utils_limit_history.rs | 3 +- .../passthrough_convert_messages.rs | 24 +++++++++++++++ 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/scratchpads/chat_generic.rs b/src/scratchpads/chat_generic.rs index ab8f14ee3..4e4e503d6 100644 --- a/src/scratchpads/chat_generic.rs +++ b/src/scratchpads/chat_generic.rs @@ -102,7 +102,7 @@ impl ScratchpadAbstract for GenericChatScratchpad { } else { (self.messages.clone(), self.messages.len(), false) }; - let limited_msgs: Vec = limit_messages_history(&self.t, &messages, undroppable_msg_n, self.post.parameters.max_new_tokens, n_ctx)?; + let limited_msgs: Vec = limit_messages_history(&self.t, &messages, undroppable_msg_n, self.post.parameters.max_new_tokens, n_ctx, "openai")?; // if self.supports_tools { // }; sampling_parameters_to_patch.stop = self.dd.stop_list.clone(); diff --git a/src/scratchpads/chat_llama2.rs b/src/scratchpads/chat_llama2.rs index b962c7626..0f6f04232 100644 --- a/src/scratchpads/chat_llama2.rs +++ b/src/scratchpads/chat_llama2.rs @@ -83,7 +83,7 @@ impl ScratchpadAbstract for ChatLlama2 { } else { (self.messages.clone(), self.messages.len(), false) }; - let limited_msgs: Vec = limit_messages_history(&self.t, &messages, undroppable_msg_n, sampling_parameters_to_patch.max_new_tokens, n_ctx)?; + let limited_msgs: Vec = limit_messages_history(&self.t, &messages, undroppable_msg_n, sampling_parameters_to_patch.max_new_tokens, n_ctx, "openai")?; sampling_parameters_to_patch.stop = self.dd.stop_list.clone(); // loosely adapted from https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/model.py#L24 let mut prompt = "".to_string(); diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs index 55bb395fc..d0f97d411 100644 --- a/src/scratchpads/chat_passthrough.rs +++ b/src/scratchpads/chat_passthrough.rs @@ -13,7 +13,7 @@ use crate::scratchpad_abstract::{FinishReason, HasTokenizerAndEot, ScratchpadAbs use crate::scratchpads::chat_utils_limit_history::limit_messages_history; use crate::scratchpads::scratchpad_utils::HasRagResults; use crate::scratchpads::chat_utils_prompts::prepend_the_right_system_prompt_and_maybe_more_initial_messages; -use crate::scratchpads::passthrough_convert_messages::convert_messages_to_openai_format; +use crate::scratchpads::passthrough_convert_messages::{convert_messages_to_openai_format, format_messages_anthropic}; use crate::tools::tools_description::{tool_description_list_from_yaml, tools_merged_and_filtered}; use crate::tools::tools_execute::{run_tools_locally, run_tools_remotely}; @@ -123,7 +123,7 @@ impl ScratchpadAbstract for ChatPassthrough { run_tools_locally(ccx.clone(), at_tools.clone(), self.t.tokenizer.clone(), sampling_parameters_to_patch.max_new_tokens, &messages, &mut self.has_rag_results, &style).await? 
} }; - let limited_msgs = limit_messages_history(&self.t, &messages, undroppable_msg_n, sampling_parameters_to_patch.max_new_tokens, n_ctx).unwrap_or_else(|e| { + let limited_msgs = limit_messages_history(&self.t, &messages, undroppable_msg_n, sampling_parameters_to_patch.max_new_tokens, n_ctx, &style).unwrap_or_else(|e| { error!("error limiting messages: {}", e); vec![] }); @@ -235,28 +235,3 @@ impl ScratchpadAbstract for ChatPassthrough { })) } } - -// for anthropic: -// tool answers must be located in the same message.content (if tools executed in parallel) -fn format_messages_anthropic(messages: Vec) -> Vec { - let mut res: Vec = vec![]; - for m in messages { - match m.get("content") { - Some(Value::Array(cont)) => { - if let Some(prev_el) = res.last_mut() { - if let Some(Value::Array(prev_cont)) = prev_el.get_mut("content") { - if cont.iter().any(|c| c.get("type") == Some(&Value::String("tool_result".to_string()))) - && prev_cont.iter().any(|p| p.get("type") == Some(&Value::String("tool_result".to_string()))) - { - prev_cont.extend(cont.iter().cloned()); - continue; - } - } - } - res.push(m); - } - _ => res.push(m), - } - } - res -} diff --git a/src/scratchpads/chat_utils_limit_history.rs b/src/scratchpads/chat_utils_limit_history.rs index 853bbf74e..2e8574aab 100644 --- a/src/scratchpads/chat_utils_limit_history.rs +++ b/src/scratchpads/chat_utils_limit_history.rs @@ -9,6 +9,7 @@ pub fn limit_messages_history( last_user_msg_starts: usize, max_new_tokens: usize, context_size: usize, + style: &str, ) -> Result, String> { let tokens_limit: i32 = context_size as i32 - max_new_tokens as i32; @@ -17,7 +18,7 @@ pub fn limit_messages_history( let mut message_token_count: Vec = vec![0; messages.len()]; let mut message_take: Vec = vec![false; messages.len()]; for (i, msg) in messages.iter().enumerate() { - let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), "openai")?; + let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), style)?; message_token_count[i] = tcnt; if i==0 && msg.role == "system" { message_take[i] = true; diff --git a/src/scratchpads/passthrough_convert_messages.rs b/src/scratchpads/passthrough_convert_messages.rs index 0d6f2366b..f9124e80a 100644 --- a/src/scratchpads/passthrough_convert_messages.rs +++ b/src/scratchpads/passthrough_convert_messages.rs @@ -91,6 +91,30 @@ pub fn convert_messages_to_openai_format(messages: Vec, style: &str results } +// for anthropic: +// tool answers must be located in the same message.content (if tools executed in parallel) +pub fn format_messages_anthropic(messages: Vec) -> Vec { + let mut res: Vec = vec![]; + for m in messages { + match m.get("content") { + Some(Value::Array(cont)) => { + if let Some(prev_el) = res.last_mut() { + if let Some(Value::Array(prev_cont)) = prev_el.get_mut("content") { + if cont.iter().any(|c| c.get("type") == Some(&Value::String("tool_result".to_string()))) + && prev_cont.iter().any(|p| p.get("type") == Some(&Value::String("tool_result".to_string()))) + { + prev_cont.extend(cont.iter().cloned()); + continue; + } + } + } + res.push(m); + } + _ => res.push(m), + } + } + res +} #[cfg(test)] mod tests {