Skip to content

Commit

Permalink
Refactor subchat tool handling to use Option<Vec<String>>
Browse files Browse the repository at this point in the history
- Update tool subset handling in subchat functions to use `Option<Vec<String>>`.
- Change vector initializations from `vec![]` to `Some(vec![])` where applicable.
- Modify database schema to include new fields `cthread_max_new_tokens` and `cthread_n` with default values.
- Use `PathBuf` for database path configuration in `_make_connection`.
- Initialize `config_dir` with a clone to ensure proper usage context.
- Enhance database thread structure with default implementations for new fields.
- Simplify subchat interaction in `autonomy.rs`, removing unused logic.
  • Loading branch information
JegernOUTT committed Dec 20, 2024
1 parent 99a7721 commit f7c4aa3
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 162 deletions.
27 changes: 19 additions & 8 deletions src/agent_db/db_cthread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ pub fn cthreads_from_rows(
cthread_toolset: row.get("cthread_toolset").unwrap(),
cthread_model: row.get("cthread_model").unwrap(),
cthread_temperature: row.get("cthread_temperature").unwrap(),
cthread_max_new_tokens: row.get("cthread_max_new_tokens").unwrap(),
cthread_n: row.get("cthread_n").unwrap(),
cthread_error: row.get("cthread_error").unwrap(),
cthread_anything_new: row.get("cthread_anything_new").unwrap(),
cthread_created_ts: row.get("cthread_created_ts").unwrap(),
cthread_updated_ts: row.get("cthread_updated_ts").unwrap(),
cthread_archived_ts: row.get("cthread_archived_ts").unwrap(),
cthread_locked_by: row.get("cthread_locked_by").unwrap(),
cthread_locked_ts: row.get("cthread_locked_ts").unwrap(),
..Default::default()
});
}
cthreads
Expand All @@ -63,13 +66,15 @@ pub fn cthread_set_lowlevel(
cthread_toolset = ?4,
cthread_model = ?5,
cthread_temperature = ?6,
cthread_error = ?7,
cthread_anything_new = ?8,
cthread_created_ts = ?9,
cthread_updated_ts = ?10,
cthread_archived_ts = ?11,
cthread_locked_by = ?12,
cthread_locked_ts = ?13
cthread_max_new_tokens = ?7,
cthread_n = ?8,
cthread_error = ?9,
cthread_anything_new = ?10,
cthread_created_ts = ?11,
cthread_updated_ts = ?12,
cthread_archived_ts = ?13,
cthread_locked_by = ?14,
cthread_locked_ts = ?15
WHERE cthread_id = ?1",
rusqlite::params![
cthread.cthread_id,
Expand All @@ -78,6 +83,8 @@ pub fn cthread_set_lowlevel(
cthread.cthread_toolset,
cthread.cthread_model,
cthread.cthread_temperature,
cthread.cthread_max_new_tokens,
cthread.cthread_n,
cthread.cthread_error,
cthread.cthread_anything_new,
cthread.cthread_created_ts,
Expand All @@ -96,21 +103,25 @@ pub fn cthread_set_lowlevel(
cthread_toolset,
cthread_model,
cthread_temperature,
cthread_max_new_tokens,
cthread_n,
cthread_error,
cthread_anything_new,
cthread_created_ts,
cthread_updated_ts,
cthread_archived_ts,
cthread_locked_by,
cthread_locked_ts
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)",
rusqlite::params![
cthread.cthread_id,
cthread.cthread_belongs_to_chore_event_id,
cthread.cthread_title,
cthread.cthread_toolset,
cthread.cthread_model,
cthread.cthread_temperature,
cthread.cthread_max_new_tokens,
cthread.cthread_n,
cthread.cthread_error,
cthread.cthread_anything_new,
cthread.cthread_created_ts,
Expand Down
10 changes: 6 additions & 4 deletions src/agent_db/db_init.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Notify as ANotify;
use parking_lot::Mutex as ParkMutex;
Expand All @@ -7,10 +8,11 @@ use crate::agent_db::db_structs::ChoreDB;


fn _make_connection(
_chore_db_fn: String,
config_dir: &PathBuf,
) -> Result<Arc<ParkMutex<ChoreDB>>, String> {
let db_path = config_dir.join("chore_db.sqlite");
let db = Connection::open_with_flags(
"experimental_db.sqlite",
db_path,
rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE
| rusqlite::OpenFlags::SQLITE_OPEN_CREATE
| rusqlite::OpenFlags::SQLITE_OPEN_FULL_MUTEX
Expand All @@ -30,10 +32,10 @@ fn _make_connection(
}

pub async fn chore_db_init(
chore_db_fn: String,
config_dir: &PathBuf,
reset_memory: bool,
) -> Arc<ParkMutex<ChoreDB>> {
let db = match _make_connection(chore_db_fn) {
let db = match _make_connection(config_dir) {
Ok(db) => db,
Err(err) => panic!("Failed to initialize chore database: {}", err),
};
Expand Down
2 changes: 2 additions & 0 deletions src/agent_db/db_schema_20241102.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pub fn create_tables_20241102(conn: &Connection, reset_memory: bool) -> Result<(
cthread_toolset TEXT NOT NULL,
cthread_model TEXT NOT NULL,
cthread_temperature REAL NOT NULL,
cthread_max_new_tokens INT NOT NULL DEFAULT 2048,
cthread_n INT NOT NULL DEFAULT 1,
cthread_error TEXT NOT NULL,
cthread_anything_new BOOLEAN NOT NULL,
cthread_created_ts REAL NOT NULL,
Expand Down
28 changes: 27 additions & 1 deletion src/agent_db/db_structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ pub struct ChoreEvent {
pub chore_event_cthread_id: Option<String>,
}

#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize)]
pub struct CThread {
pub cthread_id: String,
pub cthread_belongs_to_chore_event_id: Option<String>,
pub cthread_title: String,
pub cthread_toolset: String, // quick/explore/agent
pub cthread_model: String,
pub cthread_temperature: f64,
pub cthread_n_ctx: usize,
pub cthread_max_new_tokens: usize,
pub cthread_n: usize,
pub cthread_error: String, // assign to special value "pause" to avoid auto repost to the model
pub cthread_anything_new: bool, // the ⚪
pub cthread_created_ts: f64,
Expand All @@ -40,6 +43,29 @@ pub struct CThread {
pub cthread_locked_ts: f64,
}

impl Default for CThread {
fn default() -> Self {
CThread {
cthread_id: String::new(),
cthread_belongs_to_chore_event_id: None,
cthread_title: String::new(),
cthread_toolset: String::new(),
cthread_model: String::new(),
cthread_temperature: f64::default(),
cthread_n_ctx: 65536,
cthread_max_new_tokens: 2048,
cthread_n: 1,
cthread_error: String::new(),
cthread_anything_new: false,
cthread_created_ts: f64::default(),
cthread_updated_ts: f64::default(),
cthread_archived_ts: f64::default(),
cthread_locked_by: String::new(),
cthread_locked_ts: f64::default()
}
}
}

#[derive(Serialize, Deserialize, Default)]
pub struct CMessage {
// primary key starts here
Expand Down
2 changes: 1 addition & 1 deletion src/agentic/generate_commit_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ pub async fn generate_commit_message_by_diff(
ccx.clone(),
model_name.as_str(),
messages,
vec![],
Some(vec![]),
None,
false,
Some(TEMPERATURE),
Expand Down
2 changes: 1 addition & 1 deletion src/agentic/generate_follow_up_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub async fn generate_follow_up_message(
ccx.clone(),
model_name,
messages.clone(),
vec![],
Some(vec![]),
None,
false,
Some(0.5),
Expand Down
156 changes: 23 additions & 133 deletions src/autonomy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use crate::global_context::GlobalContext;
use crate::agent_db::db_structs::{CThread, CMessage};
use crate::agent_db::chore_pubsub_sleeping_procedure;
use crate::agent_db::db_cthread::CThreadSubscription;
use crate::call_validation::{ChatContent, ChatMessage};
use crate::call_validation::{ChatContent, ChatMessage, ChatUsage};
use crate::at_commands::at_commands::AtCommandsContext;
use crate::subchat::subchat_single;

const SLEEP_IF_NO_WORK_SEC: u64 = 10;
const LOCK_TOO_OLD_SEC: f64 = 600.0;
Expand Down Expand Up @@ -184,61 +185,38 @@ async fn do_the_job(
let message_info_str = message_info.join(", ");
tracing::info!("{} started work on {}\n[{}]", worker_name, cthread_rec.cthread_id, message_info_str);

// TODO: make something similar to the `subchat` with chat `wrapping` logic
// wrap_up_depth: usize,
// wrap_up_tokens_cnt: usize,
// wrap_up_prompt: &str,
// wrap_up_n: usize,
let tools_turned_on_by_cmdline = crate::tools::tools_description::tools_merged_and_filtered(gcx.clone()).await?;
let allow_experimental = gcx.read().await.cmdline.experimental;
let tools_desclist = crate::tools::tools_description::tool_description_list_from_yaml(
tools_turned_on_by_cmdline,
None,
allow_experimental
).await?;
let tools = tools_desclist.into_iter().filter_map(|tool_desc| {
let good =
(cthread_rec.cthread_toolset == "explore" && !tool_desc.agentic) ||
(cthread_rec.cthread_toolset == "agent");
if good {
Some(tool_desc.into_openai_style())
} else {
None
}
}).collect::<Vec<_>>();

let max_new_tokens = 2048;
let n = 1;
let only_deterministic_messages = false;
let (mut chat_post, spad) = crate::subchat::create_chat_post_and_scratchpad(
gcx.clone(),
&cthread_rec.cthread_model,
messages.iter().collect::<Vec<_>>(),
Some(cthread_rec.cthread_temperature as f32),
max_new_tokens,
n,
Some(tools),
None,
only_deterministic_messages,
).await?;
let n_ctx = chat_post.max_tokens; // create_chat_post_and_scratchpad saves n_ctx here :/

let mut usage = ChatUsage { ..Default::default() };
let ccx: Arc<AMutex<AtCommandsContext>> = Arc::new(AMutex::new(AtCommandsContext::new(
gcx.clone(),
n_ctx,
7,
cthread_rec.cthread_n_ctx,
10,
false,
messages.clone(),
cthread_rec.cthread_id.clone(),
false,
).await));

// XXX at commands
tracing::info!("{} start chat_interaction()", worker_name);
let chat_response_msgs = crate::subchat::chat_interaction(ccx.clone(), spad, &mut chat_post).await?;
if chat_response_msgs.len() == 0 {
return Err("Oops strange, chat_interaction() returned no choices".to_string());
}
let log_prefix = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string();
let chat_response_msgs = subchat_single(
ccx.clone(),
cthread_rec.cthread_model.as_str(),
messages,
None,
None,
false,
Some(cthread_rec.cthread_temperature as f32),
Some(cthread_rec.cthread_max_new_tokens),
cthread_rec.cthread_n,
Some(&mut usage),
Some(cthread_rec.cthread_id.clone()),
Some(format!("{log_prefix}-chore-job")),
).await.map_err(|e| format!("Error: {}", e))?;

let choice0: Vec<ChatMessage> = chat_response_msgs[0].clone();

{
let mut lite_locked = lite.lock();
let tx = lite_locked.transaction().map_err(|e| e.to_string())?;
Expand Down Expand Up @@ -266,94 +244,6 @@ async fn do_the_job(
tx.commit().map_err(|e| e.to_string())?;
}
chore_sleeping_point.notify_waiters();


// let old_messages = messages.clone();
// let results = chat_response_msgs.iter().map(|new_msgs| {
// let mut extended_msgs = old_messages.clone();
// extended_msgs.extend(new_msgs.clone());
// extended_msgs
// }).collect::<Vec<Vec<ChatMessage>>>();

// if let Some(usage_collector) = usage_collector_mb {
// crate::subchat::update_usage_from_messages(usage_collector, &results);
// }

// {
// // keep session
// let mut step_n = 0;
// loop {
// let last_message = messages.last().unwrap();
// // if last_message.role == "assistant" && last_message.tool_calls.is_none() {
// // don't have tool calls, exit the loop unconditionally, model thinks it has finished the work
// break;
// }
// if last_message.role == "assistant" && last_message.tool_calls.is_some() {
// // have tool calls, let's see if we need to wrap up or not
// if step_n >= wrap_up_depth {
// break;
// }
// if let Some(usage) = &last_message.usage {
// if usage.prompt_tokens + usage.completion_tokens > wrap_up_tokens_cnt {
// break;
// }
// }
// }
// messages = subchat_single(
// ccx.clone(),
// model_name,
// messages.clone(),
// tools_subset.clone(),
// Some("auto".to_string()),
// false,
// temperature,
// None,
// 1,
// Some(&mut usage_collector),
// tx_toolid_mb.clone(),
// tx_chatid_mb.clone(),
// ).await?[0].clone();
// step_n += 1;
// }
// // result => session
// }
// let last_message = messages.last().unwrap();
// if let Some(tool_calls) = &last_message.tool_calls {
// if !tool_calls.is_empty() {
// messages = subchat_single(
// ccx.clone(),
// model_name,
// messages,
// vec![],
// Some("none".to_string()),
// true, // <-- only runs tool calls
// temperature,
// None,
// 1,
// Some(&mut usage_collector),
// tx_toolid_mb.clone(),
// tx_chatid_mb.clone(),
// ).await?[0].clone();
// }
// }
// messages.push(ChatMessage::new("user".to_string(), wrap_up_prompt.to_string()));
// let choices = subchat_single(
// ccx.clone(),
// model_name,
// messages,
// vec![],
// Some("none".to_string()),
// false,
// temperature,
// None,
// wrap_up_n,
// Some(&mut usage_collector),
// tx_toolid_mb.clone(),
// tx_chatid_mb.clone(),
// ).await?;
// if let Some(last_message) = messages.last_mut() {
// last_message.usage = Some(usage_collector);
// }
Ok(serde_json::json!({}))
}

Expand Down
4 changes: 2 additions & 2 deletions src/global_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ pub async fn create_global_context(
http_client,
http_client_slowdown: Arc::new(Semaphore::new(2)),
cache_dir,
config_dir,
config_dir: config_dir.clone(),
caps: None,
caps_reading_lock: Arc::new(AMutex::<bool>::new(false)),
caps_last_error: String::new(),
Expand All @@ -373,7 +373,7 @@ pub async fn create_global_context(
integration_sessions: HashMap::new(),
codelens_cache: Arc::new(AMutex::new(crate::http::routers::v1::code_lens::CodeLensCache::default())),
docker_ssh_tunnel: Arc::new(AMutex::new(None)),
chore_db: crate::agent_db::db_init::chore_db_init("/tmp/chore.db".to_string(), cmdline.reset_memory).await,
chore_db: crate::agent_db::db_init::chore_db_init(&config_dir, cmdline.reset_memory).await,
};
let gcx = Arc::new(ARwLock::new(cx));
{
Expand Down
2 changes: 1 addition & 1 deletion src/http/routers/v1/subchat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub async fn handle_v1_subchat_single(
ccx.clone(),
post.model_name.as_str(),
messages,
post.tools_turn_on,
Some(post.tools_turn_on),
post.tool_choice,
post.only_deterministic_messages,
post.temperature,
Expand Down
Loading

0 comments on commit f7c4aa3

Please sign in to comment.