diff --git a/src/agent_db/db_cthread.rs b/src/agent_db/db_cthread.rs
index 2e9b71ebe..0cc868acd 100644
--- a/src/agent_db/db_cthread.rs
+++ b/src/agent_db/db_cthread.rs
@@ -40,6 +40,8 @@ pub fn cthreads_from_rows(
             cthread_toolset: row.get("cthread_toolset").unwrap(),
             cthread_model: row.get("cthread_model").unwrap(),
             cthread_temperature: row.get("cthread_temperature").unwrap(),
+            cthread_max_new_tokens: row.get("cthread_max_new_tokens").unwrap(),
+            cthread_n: row.get("cthread_n").unwrap(),
             cthread_error: row.get("cthread_error").unwrap(),
             cthread_anything_new: row.get("cthread_anything_new").unwrap(),
             cthread_created_ts: row.get("cthread_created_ts").unwrap(),
@@ -47,6 +49,7 @@ pub fn cthreads_from_rows(
             cthread_archived_ts: row.get("cthread_archived_ts").unwrap(),
             cthread_locked_by: row.get("cthread_locked_by").unwrap(),
             cthread_locked_ts: row.get("cthread_locked_ts").unwrap(),
+            ..Default::default()
         });
     }
     cthreads
@@ -63,13 +66,15 @@ pub fn cthread_set_lowlevel(
             cthread_toolset = ?4,
             cthread_model = ?5,
             cthread_temperature = ?6,
-            cthread_error = ?7,
-            cthread_anything_new = ?8,
-            cthread_created_ts = ?9,
-            cthread_updated_ts = ?10,
-            cthread_archived_ts = ?11,
-            cthread_locked_by = ?12,
-            cthread_locked_ts = ?13
+            cthread_max_new_tokens = ?7,
+            cthread_n = ?8,
+            cthread_error = ?9,
+            cthread_anything_new = ?10,
+            cthread_created_ts = ?11,
+            cthread_updated_ts = ?12,
+            cthread_archived_ts = ?13,
+            cthread_locked_by = ?14,
+            cthread_locked_ts = ?15
         WHERE cthread_id = ?1",
         rusqlite::params![
             cthread.cthread_id,
@@ -78,6 +83,8 @@ pub fn cthread_set_lowlevel(
             cthread.cthread_toolset,
             cthread.cthread_model,
             cthread.cthread_temperature,
+            cthread.cthread_max_new_tokens,
+            cthread.cthread_n,
             cthread.cthread_error,
             cthread.cthread_anything_new,
             cthread.cthread_created_ts,
@@ -96,6 +103,8 @@ pub fn cthread_set_lowlevel(
                 cthread_toolset,
                 cthread_model,
                 cthread_temperature,
+                cthread_max_new_tokens,
+                cthread_n,
                 cthread_error,
                 cthread_anything_new,
                 cthread_created_ts,
@@ -103,7 +112,7 @@ pub fn cthread_set_lowlevel(
                 cthread_archived_ts,
                 cthread_locked_by,
                 cthread_locked_ts
-            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
+            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
         rusqlite::params![
             cthread.cthread_id,
             cthread.cthread_belongs_to_chore_event_id,
@@ -111,6 +120,8 @@ pub fn cthread_set_lowlevel(
             cthread.cthread_toolset,
             cthread.cthread_model,
             cthread.cthread_temperature,
+            cthread.cthread_max_new_tokens,
+            cthread.cthread_n,
             cthread.cthread_error,
             cthread.cthread_anything_new,
             cthread.cthread_created_ts,
diff --git a/src/agent_db/db_init.rs b/src/agent_db/db_init.rs
index 4f3a73aec..c231c047c 100644
--- a/src/agent_db/db_init.rs
+++ b/src/agent_db/db_init.rs
@@ -1,3 +1,4 @@
+use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::Notify as ANotify;
 use parking_lot::Mutex as ParkMutex;
@@ -7,10 +8,11 @@ use crate::agent_db::db_structs::ChoreDB;
 
 
 fn _make_connection(
-    _chore_db_fn: String,
+    config_dir: &PathBuf,
 ) -> Result<Arc<ParkMutex<ChoreDB>>, String> {
+    let db_path = config_dir.join("chore_db.sqlite");
     let db = Connection::open_with_flags(
-        "experimental_db.sqlite",
+        db_path,
         rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
         | rusqlite::OpenFlags::SQLITE_OPEN_CREATE
         | rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX
@@ -30,10 +32,10 @@ fn _make_connection(
 }
 
 pub async fn chore_db_init(
-    chore_db_fn: String,
+    config_dir: &PathBuf,
     reset_memory: bool,
 ) -> Arc<ParkMutex<ChoreDB>> {
-    let db = match _make_connection(chore_db_fn) {
+    let db = match _make_connection(config_dir) {
        Ok(db) => db,
        Err(err) => panic!("Failed to initialize chore database: {}", err),
    };
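
Note (illustration, not part of the patch): the chore database moves from a hard-coded `experimental_db.sqlite` in the working directory to `chore_db.sqlite` under the config directory. A minimal sketch of that path handling, assuming only `std` and `rusqlite`; `open_chore_db` is an illustrative name, not a function from the codebase:

use std::path::PathBuf;
use rusqlite::{Connection, OpenFlags};

// Hypothetical helper mirroring _make_connection's path logic above:
// resolve the file under config_dir and open it with the same flags.
fn open_chore_db(config_dir: &PathBuf) -> Result<Connection, String> {
    let db_path = config_dir.join("chore_db.sqlite");
    Connection::open_with_flags(
        db_path,
        OpenFlags::SQLITE_OPEN_READ_WRITE
            | OpenFlags::SQLITE_OPEN_CREATE
            | OpenFlags::SQLITE_OPEN_FULL_MUTEX,
    ).map_err(|e| format!("failed to open chore db: {}", e))
}
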
diff --git a/src/agent_db/db_schema_20241102.rs b/src/agent_db/db_schema_20241102.rs
index d79f35604..47bfc0b5b 100644
--- a/src/agent_db/db_schema_20241102.rs
+++ b/src/agent_db/db_schema_20241102.rs
@@ -59,6 +59,8 @@ pub fn create_tables_20241102(conn: &Connection, reset_memory: bool) -> Result<(
             cthread_toolset TEXT NOT NULL,
             cthread_model TEXT NOT NULL,
             cthread_temperature REAL NOT NULL,
+            cthread_max_new_tokens INT NOT NULL DEFAULT 2048,
+            cthread_n INT NOT NULL DEFAULT 1,
             cthread_error TEXT NOT NULL,
             cthread_anything_new BOOLEAN NOT NULL,
             cthread_created_ts REAL NOT NULL,
diff --git a/src/agent_db/db_structs.rs b/src/agent_db/db_structs.rs
index ff93b8f8e..a78406924 100644
--- a/src/agent_db/db_structs.rs
+++ b/src/agent_db/db_structs.rs
@@ -23,7 +23,7 @@ pub struct ChoreEvent {
     pub chore_event_cthread_id: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Default)]
+#[derive(Serialize, Deserialize)]
 pub struct CThread {
     pub cthread_id: String,
     pub cthread_belongs_to_chore_event_id: Option<String>,
@@ -31,6 +31,9 @@ pub struct CThread {
     pub cthread_toolset: String,      // quick/explore/agent
     pub cthread_model: String,
     pub cthread_temperature: f64,
+    pub cthread_n_ctx: usize,
+    pub cthread_max_new_tokens: usize,
+    pub cthread_n: usize,
     pub cthread_error: String,        // assign to special value "pause" to avoid auto repost to the model
     pub cthread_anything_new: bool,   // the ⚪
     pub cthread_created_ts: f64,
@@ -40,6 +43,29 @@ pub struct CThread {
     pub cthread_locked_ts: f64,
 }
 
+impl Default for CThread {
+    fn default() -> Self {
+        CThread {
+            cthread_id: String::new(),
+            cthread_belongs_to_chore_event_id: None,
+            cthread_title: String::new(),
+            cthread_toolset: String::new(),
+            cthread_model: String::new(),
+            cthread_temperature: f64::default(),
+            cthread_n_ctx: 65536,
+            cthread_max_new_tokens: 2048,
+            cthread_n: 1,
+            cthread_error: String::new(),
+            cthread_anything_new: false,
+            cthread_created_ts: f64::default(),
+            cthread_updated_ts: f64::default(),
+            cthread_archived_ts: f64::default(),
+            cthread_locked_by: String::new(),
+            cthread_locked_ts: f64::default()
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Default)]
 pub struct CMessage {
     // primary key starts here
diff --git a/src/agentic/generate_commit_message.rs b/src/agentic/generate_commit_message.rs
index 74a041561..3b37ea564 100644
--- a/src/agentic/generate_commit_message.rs
+++ b/src/agentic/generate_commit_message.rs
@@ -227,7 +227,7 @@ pub async fn generate_commit_message_by_diff(
         ccx.clone(),
         model_name.as_str(),
         messages,
-        vec![],
+        Some(vec![]),
         None,
         false,
         Some(TEMPERATURE),
diff --git a/src/agentic/generate_follow_up_message.rs b/src/agentic/generate_follow_up_message.rs
index a5ff51bdc..bf39723ec 100644
--- a/src/agentic/generate_follow_up_message.rs
+++ b/src/agentic/generate_follow_up_message.rs
@@ -60,7 +60,7 @@ pub async fn generate_follow_up_message(
         ccx.clone(),
         model_name,
         messages.clone(),
-        vec![],
+        Some(vec![]),
         None,
         false,
         Some(0.5),
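
Note (illustration, not part of the patch): `derive(Default)` is dropped on `CThread` because the derived impl would zero the new numeric fields, while the hand-written `Default` supplies the intended 65536/2048/1; `cthreads_from_rows` then relies on `..Default::default()` for anything not stored in the table (currently `cthread_n_ctx`). A standalone sketch of that pattern with a made-up struct:

// Fields read from a row override the defaults; the rest fall back to Default.
#[derive(Debug, PartialEq)]
struct SettingsSketch {
    n_ctx: usize,
    max_new_tokens: usize,
    n: usize,
}

impl Default for SettingsSketch {
    fn default() -> Self {
        SettingsSketch { n_ctx: 65536, max_new_tokens: 2048, n: 1 }
    }
}

fn main() {
    let from_row = SettingsSketch { max_new_tokens: 4096, ..Default::default() };
    assert_eq!(from_row, SettingsSketch { n_ctx: 65536, max_new_tokens: 4096, n: 1 });
}
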
diff --git a/src/autonomy.rs b/src/autonomy.rs
index 0da44bb00..5609db96d 100644
--- a/src/autonomy.rs
+++ b/src/autonomy.rs
@@ -7,8 +7,9 @@ use crate::global_context::GlobalContext;
 use crate::agent_db::db_structs::{CThread, CMessage};
 use crate::agent_db::chore_pubsub_sleeping_procedure;
 use crate::agent_db::db_cthread::CThreadSubscription;
-use crate::call_validation::{ChatContent, ChatMessage};
+use crate::call_validation::{ChatContent, ChatMessage, ChatUsage};
 use crate::at_commands::at_commands::AtCommandsContext;
+use crate::subchat::subchat_single;
 
 const SLEEP_IF_NO_WORK_SEC: u64 = 10;
 const LOCK_TOO_OLD_SEC: f64 = 600.0;
@@ -184,61 +185,38 @@ async fn do_the_job(
     let message_info_str = message_info.join(", ");
     tracing::info!("{} started work on {}\n[{}]", worker_name, cthread_rec.cthread_id, message_info_str);
 
+    // TODO: make something similar to the `subchat` with chat `wrapping` logic
     // wrap_up_depth: usize,
     // wrap_up_tokens_cnt: usize,
     // wrap_up_prompt: &str,
     // wrap_up_n: usize,
 
-    let tools_turned_on_by_cmdline = crate::tools::tools_description::tools_merged_and_filtered(gcx.clone()).await?;
-    let allow_experimental = gcx.read().await.cmdline.experimental;
-    let tools_desclist = crate::tools::tools_description::tool_description_list_from_yaml(
-        tools_turned_on_by_cmdline,
-        None,
-        allow_experimental
-    ).await?;
-    let tools = tools_desclist.into_iter().filter_map(|tool_desc| {
-        let good =
-            (cthread_rec.cthread_toolset == "explore" && !tool_desc.agentic) ||
-            (cthread_rec.cthread_toolset == "agent");
-        if good {
-            Some(tool_desc.into_openai_style())
-        } else {
-            None
-        }
-    }).collect::<Vec<_>>();
-
-    let max_new_tokens = 2048;
-    let n = 1;
-    let only_deterministic_messages = false;
-    let (mut chat_post, spad) = crate::subchat::create_chat_post_and_scratchpad(
-        gcx.clone(),
-        &cthread_rec.cthread_model,
-        messages.iter().collect::<Vec<_>>(),
-        Some(cthread_rec.cthread_temperature as f32),
-        max_new_tokens,
-        n,
-        Some(tools),
-        None,
-        only_deterministic_messages,
-    ).await?;
-    let n_ctx = chat_post.max_tokens;  // create_chat_post_and_scratchpad saves n_ctx here :/
-
+    let mut usage = ChatUsage { ..Default::default() };
     let ccx: Arc<AMutex<AtCommandsContext>> = Arc::new(AMutex::new(AtCommandsContext::new(
         gcx.clone(),
-        n_ctx,
-        7,
+        cthread_rec.cthread_n_ctx,
+        10,
         false,
         messages.clone(),
         cthread_rec.cthread_id.clone(),
+        false,
     ).await));
-
-    // XXX at commands
-    tracing::info!("{} start chat_interaction()", worker_name);
-    let chat_response_msgs = crate::subchat::chat_interaction(ccx.clone(), spad, &mut chat_post).await?;
-    if chat_response_msgs.len() == 0 {
-        return Err("Oops strange, chat_interaction() returned no choices".to_string());
-    }
+    let log_prefix = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string();
+    let chat_response_msgs = subchat_single(
+        ccx.clone(),
+        cthread_rec.cthread_model.as_str(),
+        messages,
+        None,
+        None,
+        false,
+        Some(cthread_rec.cthread_temperature as f32),
+        Some(cthread_rec.cthread_max_new_tokens),
+        cthread_rec.cthread_n,
+        Some(&mut usage),
+        Some(cthread_rec.cthread_id.clone()),
+        Some(format!("{log_prefix}-chore-job")),
+    ).await.map_err(|e| format!("Error: {}", e))?;
+    let choice0: Vec<ChatMessage> = chat_response_msgs[0].clone();
-
     {
         let mut lite_locked = lite.lock();
         let tx = lite_locked.transaction().map_err(|e| e.to_string())?;
@@ -266,94 +244,6 @@ async fn do_the_job(
         tx.commit().map_err(|e| e.to_string())?;
     }
     chore_sleeping_point.notify_waiters();
-
-
-    // let old_messages = messages.clone();
-    // let results = chat_response_msgs.iter().map(|new_msgs| {
-    //     let mut extended_msgs = old_messages.clone();
-    //     extended_msgs.extend(new_msgs.clone());
-    //     extended_msgs
-    // }).collect::<Vec<Vec<ChatMessage>>>();
-
-    // if let Some(usage_collector) = usage_collector_mb {
-    //     crate::subchat::update_usage_from_messages(usage_collector, &results);
-    // }
-
-    // {
-    //     // keep session
-    //     let mut step_n = 0;
-    //     loop {
-    //         let last_message = messages.last().unwrap();
-    //         // if last_message.role == "assistant" && last_message.tool_calls.is_none() {
-    //             // don't have tool calls, exit the loop unconditionally, model thinks it has finished the work
-    //             break;
-    //         }
-    //         if last_message.role == "assistant" && last_message.tool_calls.is_some() {
-    //             // have tool calls, let's see if we need to wrap up or not
-    //             if step_n >= wrap_up_depth {
-    //                 break;
-    //             }
-    //             if let Some(usage) = &last_message.usage {
-    //                 if usage.prompt_tokens + usage.completion_tokens > wrap_up_tokens_cnt {
-    //                     break;
-    //                 }
-    //             }
-    //         }
-    //         messages = subchat_single(
-    //             ccx.clone(),
-    //             model_name,
-    //             messages.clone(),
-    //             tools_subset.clone(),
-    //             Some("auto".to_string()),
-    //             false,
-    //             temperature,
-    //             None,
-    //             1,
-    //             Some(&mut usage_collector),
-    //             tx_toolid_mb.clone(),
-    //             tx_chatid_mb.clone(),
-    //         ).await?[0].clone();
-    //         step_n += 1;
-    //     }
-    //     // result => session
-    // }
-    // let last_message = messages.last().unwrap();
-    // if let Some(tool_calls) = &last_message.tool_calls {
-    //     if !tool_calls.is_empty() {
-    //         messages = subchat_single(
-    //             ccx.clone(),
-    //             model_name,
-    //             messages,
-    //             vec![],
-    //             Some("none".to_string()),
-    //             true,  // <-- only runs tool calls
-    //             temperature,
-    //             None,
-    //             1,
-    //             Some(&mut usage_collector),
-    //             tx_toolid_mb.clone(),
-    //             tx_chatid_mb.clone(),
-    //         ).await?[0].clone();
-    //     }
-    // }
-    // messages.push(ChatMessage::new("user".to_string(), wrap_up_prompt.to_string()));
-    // let choices = subchat_single(
-    //     ccx.clone(),
-    //     model_name,
-    //     messages,
-    //     vec![],
-    //     Some("none".to_string()),
-    //     false,
-    //     temperature,
-    //     None,
-    //     wrap_up_n,
-    //     Some(&mut usage_collector),
-    //     tx_toolid_mb.clone(),
-    //     tx_chatid_mb.clone(),
-    // ).await?;
-    // if let Some(last_message) = messages.last_mut() {
-    //     last_message.usage = Some(usage_collector);
-    // }
 
     Ok(serde_json::json!({}))
 }
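
Note (illustration, not part of the patch): the `// TODO` above replaces the commented-out wrap-up loop deleted here; its stopping rule was a round limit plus a token budget computed from accumulated usage (`prompt_tokens + completion_tokens` appear in the removed comments). A self-contained sketch of that check, with every name hypothetical except those two fields:

// Stand-in for the usage counters accumulated via ChatUsage.
#[derive(Default, Clone, Copy)]
struct UsageSketch {
    prompt_tokens: usize,
    completion_tokens: usize,
}

// Wrap up either after too many tool-call rounds or once the token budget is spent,
// the same two conditions the removed loop checked before re-posting to the model.
fn should_wrap_up(step_n: usize, wrap_up_depth: usize, usage: UsageSketch, wrap_up_tokens_cnt: usize) -> bool {
    step_n >= wrap_up_depth
        || usage.prompt_tokens + usage.completion_tokens > wrap_up_tokens_cnt
}

fn main() {
    let spent = UsageSketch { prompt_tokens: 9_000, completion_tokens: 1_500 };
    assert!(should_wrap_up(2, 8, spent, 10_000));                    // budget exceeded
    assert!(!should_wrap_up(2, 8, UsageSketch::default(), 10_000));  // keep going
}
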
diff --git a/src/global_context.rs b/src/global_context.rs
index 0629c553f..aff2429f0 100644
--- a/src/global_context.rs
+++ b/src/global_context.rs
@@ -351,7 +351,7 @@ pub async fn create_global_context(
         http_client,
         http_client_slowdown: Arc::new(Semaphore::new(2)),
         cache_dir,
-        config_dir,
+        config_dir: config_dir.clone(),
         caps: None,
         caps_reading_lock: Arc::new(AMutex::<bool>::new(false)),
         caps_last_error: String::new(),
@@ -373,7 +373,7 @@ pub async fn create_global_context(
         integration_sessions: HashMap::new(),
         codelens_cache: Arc::new(AMutex::new(crate::http::routers::v1::code_lens::CodeLensCache::default())),
         docker_ssh_tunnel: Arc::new(AMutex::new(None)),
-        chore_db: crate::agent_db::db_init::chore_db_init("/tmp/chore.db".to_string(), cmdline.reset_memory).await,
+        chore_db: crate::agent_db::db_init::chore_db_init(&config_dir, cmdline.reset_memory).await,
     };
     let gcx = Arc::new(ARwLock::new(cx));
     {
diff --git a/src/http/routers/v1/subchat.rs b/src/http/routers/v1/subchat.rs
index 44785e4d7..e746dede5 100644
--- a/src/http/routers/v1/subchat.rs
+++ b/src/http/routers/v1/subchat.rs
@@ -95,7 +95,7 @@ pub async fn handle_v1_subchat_single(
         ccx.clone(),
         post.model_name.as_str(),
         messages,
-        post.tools_turn_on,
+        Some(post.tools_turn_on),
         post.tool_choice,
         post.only_deterministic_messages,
         post.temperature,
diff --git a/src/restream.rs b/src/restream.rs
index 543a00829..75ea6f5b6 100644
--- a/src/restream.rs
+++ b/src/restream.rs
@@ -175,6 +175,7 @@ pub async fn scratchpad_interaction_not_stream_json(
     if let Ok(det_msgs) = scratchpad.response_spontaneous() {
         model_says["deterministic_messages"] = json!(det_msgs);
     }
+    info!("{:?}", oai_choices);
     let choices = oai_choices.clone().as_array().unwrap().iter().map(|x| {
         match (x.get("message"), x.get("message").and_then(|msg| msg.get("content")), x.get("message").and_then(|msg| msg.get("content")).and_then(|content| content.as_str())) {
             (Some(_), Some(_), Some(content)) => content.to_string(),
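
Note (illustration, not part of the patch): the new `info!` line only logs the raw choices; the extraction below it keeps the same defensive triple match on `message`, `message.content`, and string content. A standalone sketch of that pattern with `serde_json`:

use serde_json::{json, Value};

// Treat a choice as plain text only when "message", "message.content" and a string
// value are all present; anything else is left for other handling, as in restream.rs.
fn choice_text(choice: &Value) -> Option<String> {
    match (
        choice.get("message"),
        choice.get("message").and_then(|m| m.get("content")),
        choice.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_str()),
    ) {
        (Some(_), Some(_), Some(content)) => Some(content.to_string()),
        _ => None,
    }
}

fn main() {
    assert_eq!(choice_text(&json!({"message": {"content": "hello"}})), Some("hello".to_string()));
    assert_eq!(choice_text(&json!({"message": {"content": null}})), None);
}
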
diff --git a/src/scratchpads/chat_passthrough.rs b/src/scratchpads/chat_passthrough.rs
index ede7fa4a1..2ba5d82fc 100644
--- a/src/scratchpads/chat_passthrough.rs
+++ b/src/scratchpads/chat_passthrough.rs
@@ -150,7 +150,7 @@ impl ScratchpadAbstract for ChatPassthrough {
             }).collect::<Vec<String>>();
             let allow_experimental = gcx.read().await.cmdline.experimental;
             // and take descriptions of tools from the official source
-            let tool_descriptions = tool_description_list_from_yaml(at_tools, &turned_on, allow_experimental).await?;
+            let tool_descriptions = tool_description_list_from_yaml(at_tools, Some(&turned_on), allow_experimental).await?;
             Some(tool_descriptions.into_iter().map(|x|x.into_openai_style()).collect::<Vec<_>>())
         } else {
             None
diff --git a/src/subchat.rs b/src/subchat.rs
index 819a38440..a72bab195 100644
--- a/src/subchat.rs
+++ b/src/subchat.rs
@@ -243,7 +243,7 @@ pub async fn subchat_single(
     ccx: Arc<AMutex<AtCommandsContext>>,
     model_name: &str,
     messages: Vec<ChatMessage>,
-    tools_subset: Vec<String>,
+    tools_subset: Option<Vec<String>>,
     tool_choice: Option<String>,
     only_deterministic_messages: bool,
     temperature: Option<f32>,
@@ -258,9 +258,13 @@ pub async fn subchat_single(
         (ccx_locked.global_context.clone(), ccx_locked.should_execute_remotely)
     };
     let tools_turned_on_by_cmdline = tools_merged_and_filtered(gcx.clone(), false).await?;
-    let tools_turn_on_set: HashSet<String> = tools_subset.iter().cloned().collect();
     let tools_turned_on_by_cmdline_set: HashSet<String> = tools_turned_on_by_cmdline.keys().cloned().collect();
-    let tools_on_intersection: Vec<String> = tools_turn_on_set.intersection(&tools_turned_on_by_cmdline_set).cloned().collect();
+    let tools_on_intersection: Vec<String> = if let Some(tools_s) = &tools_subset {
+        let tools_turn_on_set: HashSet<String> = tools_s.iter().cloned().collect();
+        tools_turn_on_set.intersection(&tools_turned_on_by_cmdline_set).cloned().collect()
+    } else {
+        tools_turned_on_by_cmdline_set.iter().cloned().collect()
+    };
     let allow_experimental = gcx.read().await.cmdline.experimental;
     let tools_desclist = tool_description_list_from_yaml(tools_turned_on_by_cmdline, Some(&tools_on_intersection), allow_experimental).await.unwrap_or_else(|e|{
         error!("Error loading compiled_in_tools: {:?}", e);
@@ -363,7 +367,7 @@ pub async fn subchat(
             ccx.clone(),
             model_name,
             messages.clone(),
-            tools_subset.clone(),
+            Some(tools_subset.clone()),
             Some("auto".to_string()),
             false,
             temperature,
@@ -384,7 +388,7 @@ pub async fn subchat(
                 ccx.clone(),
                 model_name,
                 messages,
-                vec![],
+                Some(vec![]),
                 Some("none".to_string()),
                 true,  // <-- only runs tool calls
                 temperature,
@@ -401,7 +405,7 @@ pub async fn subchat(
         ccx.clone(),
         model_name,
         messages,
-        vec![],
+        Some(vec![]),
         Some("none".to_string()),
         false,
         temperature,
diff --git a/src/tools/tool_deep_thinking.rs b/src/tools/tool_deep_thinking.rs
index af7c28915..670fe0cbb 100644
--- a/src/tools/tool_deep_thinking.rs
+++ b/src/tools/tool_deep_thinking.rs
@@ -93,7 +93,7 @@ impl Tool for ToolDeepThinking {
             ccx_subchat.clone(),
             subchat_params.subchat_model.as_str(),
             msgs,
-            vec![],
+            Some(vec![]),
             None,
             false,
             None,
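
Note (illustration, not part of the patch): after this change `tools_subset` distinguishes three cases: `None` keeps every tool enabled on the command line, `Some(vec![])` disables tools entirely (which is what the callers above switch to), and `Some(list)` narrows to the intersection. A standalone sketch of that resolution over plain tool names:

use std::collections::HashSet;

// Mirrors the new tools_subset handling in subchat_single, but over plain names
// instead of tool descriptions.
fn resolve_tools(tools_subset: &Option<Vec<String>>, turned_on_by_cmdline: &HashSet<String>) -> Vec<String> {
    match tools_subset {
        Some(subset) => {
            let wanted: HashSet<String> = subset.iter().cloned().collect();
            wanted.intersection(turned_on_by_cmdline).cloned().collect()
        }
        None => turned_on_by_cmdline.iter().cloned().collect(),
    }
}

fn main() {
    let cmdline: HashSet<String> = ["cat", "tree", "patch"].iter().map(|s| s.to_string()).collect();
    assert_eq!(resolve_tools(&None, &cmdline).len(), 3);           // None: everything stays on
    assert!(resolve_tools(&Some(vec![]), &cmdline).is_empty());    // Some(vec![]): no tools
    assert_eq!(resolve_tools(&Some(vec!["cat".to_string()]), &cmdline), vec!["cat".to_string()]);
}
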
diff --git a/src/tools/tool_patch_aux/model_based_edit/model_execution.rs b/src/tools/tool_patch_aux/model_based_edit/model_execution.rs
index f9007ef4f..7739a75a1 100644
--- a/src/tools/tool_patch_aux/model_based_edit/model_execution.rs
+++ b/src/tools/tool_patch_aux/model_based_edit/model_execution.rs
@@ -192,7 +192,7 @@ pub async fn execute_blocks_of_code_patch(
         ccx.clone(),
         model,
         messages.clone(),
-        vec![],
+        Some(vec![]),
         None,
         false,
         temperature,
@@ -244,7 +244,7 @@ pub async fn execute_blocks_of_code_patch(
         ccx.clone(),
         model,
         messages,
-        vec![],
+        Some(vec![]),
         None,
         false,
         Some(0.2),
@@ -309,7 +309,7 @@ pub async fn execute_whole_file_patch(
         ccx.clone(),
         model,
         messages.clone(),
-        vec![],
+        Some(vec![]),
         None,
         false,
         Some(0.1),
diff --git a/tests/test13_choredb.py b/tests/test13_choredb.py
index d38f3d6e9..a14784074 100644
--- a/tests/test13_choredb.py
+++ b/tests/test13_choredb.py
@@ -59,6 +59,8 @@ async def various_updates_generator(session, n, cthread_id):
         "cthread_created_ts": time.time(),
         "cthread_model": "gpt-4o-mini",
         "cthread_temperature": 0.8,
+        "cthread_max_new_tokens": 2048,
+        "cthread_n": 1,
         "cthread_error": ("pause" if n != 2 else ""),
     })
     assert r.status == 200, f"oops:\n{r}"
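
Note (illustration, not part of the patch): the SQLite defaults added in db_schema_20241102.rs (2048 and 1) line up with `CThread::default()`, so a row written without these columns reads back with the same values the Rust side uses as defaults. A quick standalone check of that column-default behaviour against an in-memory database; the table here is a cut-down stand-in, not the real schema:

use rusqlite::Connection;

fn main() -> Result<(), rusqlite::Error> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch(
        "CREATE TABLE cthreads (
            cthread_id TEXT PRIMARY KEY,
            cthread_max_new_tokens INT NOT NULL DEFAULT 2048,
            cthread_n INT NOT NULL DEFAULT 1
        );
        INSERT INTO cthreads (cthread_id) VALUES ('t1');",
    )?;
    // A row inserted without the new columns picks up the DEFAULTs.
    let (max_new_tokens, n): (i64, i64) = conn.query_row(
        "SELECT cthread_max_new_tokens, cthread_n FROM cthreads WHERE cthread_id = 't1'",
        [],
        |row| Ok((row.get(0)?, row.get(1)?)),
    )?;
    assert_eq!((max_new_tokens, n), (2048, 1));
    Ok(())
}
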