diff --git a/Cargo.lock b/Cargo.lock
index adb800c..ee01d5f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1041,7 +1041,6 @@ dependencies = [
  "dirs",
  "lazy_static",
  "matrix-sdk",
- "ollama-rs",
  "rand",
  "regex",
  "serde",
@@ -1758,17 +1757,6 @@ dependencies = [
  "memchr",
 ]
 
-[[package]]
-name = "ollama-rs"
-version = "0.1.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6ee48e21359d1a897e180d612319de5e348a9247629aefc1f73147d78b4b859c"
-dependencies = [
- "reqwest",
- "serde",
- "serde_json",
-]
-
 [[package]]
 name = "once_cell"
 version = "1.19.0"
diff --git a/Cargo.toml b/Cargo.toml
index 815d54a..3ea93e9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,7 +18,6 @@ anyhow = "1"
 tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
 tracing-subscriber = "0.3.15"
 matrix-sdk = "0.7.1"
-ollama-rs = "0.1.0"
 serde = { version = "1.0", features = ["derive"] }
 serde_yaml = "0.9"
 serde_json = "1.0"
diff --git a/src/aichat.rs b/src/aichat.rs
new file mode 100644
index 0000000..3ff972e
--- /dev/null
+++ b/src/aichat.rs
@@ -0,0 +1,47 @@
+use std::process::Command;
+
+pub struct AiChat {
+    binary_location: String,
+}
+
+impl Default for AiChat {
+    fn default() -> Self {
+        AiChat::new("aichat".to_string())
+    }
+}
+
+impl AiChat {
+    pub fn new(binary_location: String) -> Self {
+        AiChat { binary_location }
+    }
+
+    pub fn list_models(&self) -> Vec<String> {
+        // Run the binary with the `--list-models` argument
+        let output = Command::new(&self.binary_location)
+            .arg("--list-models")
+            .output()
+            .expect("Failed to execute command");
+
+        // split each line of the output into its own string and return
+        output
+            .stdout
+            .split(|c| *c == b'\n')
+            .map(|s| String::from_utf8(s.to_vec()).unwrap())
+            .filter(|s| !s.is_empty())
+            .collect()
+    }
+
+    pub fn execute(&self, model: Option<String>, prompt: String) -> Result<String, ()> {
+        let mut command = Command::new(&self.binary_location);
+        if let Some(model) = model {
+            command.arg("--model").arg(model);
+        }
+        command.arg("--").arg(prompt);
+        eprintln!("Running command: {:?}", command);
+
+        let output = command.output().expect("Failed to execute command");
+
+        // return the output as a string
+        String::from_utf8(output.stdout).map_err(|_| ())
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 3de5983..fa7379f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,3 +1,6 @@
+mod aichat;
+use aichat::AiChat;
+
 use clap::Parser;
 use lazy_static::lazy_static;
 use matrix_sdk::{
@@ -13,12 +16,10 @@ use matrix_sdk::{
     },
     Client, Error, LoopCtrl, Room, RoomState,
 };
-use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
 use rand::{distributions::Alphanumeric, thread_rng, Rng};
 use regex::Regex;
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::HashMap,
     fs::File,
     io::{self, Read, Write},
     path::{Path, PathBuf},
@@ -72,31 +73,16 @@ struct HeadJackArgs {
 struct Config {
     homeserver_url: String,
     username: String,
+    /// Optionally specify the password; if not set, it will be asked for on the command line
     password: Option<String>,
     /// Allow list of which accounts we will respond to
     allow_list: Option<String>,
-    ollama: Option<HashMap<String, OllamaConfig>>,
-}
-
-#[derive(Debug, Deserialize, Clone)]
-struct OllamaConfig {
-    model: String,
-    endpoint: Option<EndpointConfig>,
-}
-
-#[derive(Debug, Deserialize, Clone)]
-struct EndpointConfig {
-    host: String,
-    port: u16,
 }
 
 lazy_static! {
     static ref GLOBAL_CONFIG: Mutex<Option<Config>> = Mutex::new(None);
 }
 
-/// This is the starting point of the app. `main` is called by rust binaries to
-/// run the program in this case, we use tokio (a reactor) to allow us to use
-/// an `async` function run.
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
     // set up some simple stderr logging. You can configure it by changing the env
@@ -428,8 +414,9 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
         return;
     }
 
-    // If we start with a single '!', interpret as a command
+    // If we start with a single '.', interpret as a command
     let text = text_content.body.trim_start();
+    eprintln!("Received message: {}", text);
     if is_command(text) {
         let command = text.split_whitespace().next();
         if let Some(command) = command {
@@ -437,18 +424,21 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
             match &command[1..] {
                 "party" => {
                     let content =
-                        RoomMessageEventContent::text_plain("!party\nšŸŽ‰šŸŽŠšŸ„³ let's PARTY!! šŸ„³šŸŽŠšŸŽ‰");
+                        RoomMessageEventContent::text_plain("šŸŽ‰šŸŽŠšŸ„³ let's PARTY!! šŸ„³šŸŽŠšŸŽ‰");
                     // send our message to the room we found the "!party" command in
                     room.send(content).await.unwrap();
                 }
-                "ollama" => {
-                    // Send just this 1 message to the ollama server
-                    let input = text_content.body.trim_start_matches("!ollama").trim();
+                "send" => {
+                    // Send just this message with no context
+                    let input = text_content.body.trim_start_matches(".send").trim();
 
-                    if let Ok(result) = send_to_ollama_server(input.to_string()).await {
-                        // Add the prefix "!response:\n" to the result
+                    // But we need to read the context to figure out the model to use
+                    let (_, model) = get_context(&room).await.unwrap();
+
+                    if let Ok(result) = AiChat::default().execute(model, input.to_string()) {
+                        // Add the prefix ".response:\n" to the result
                         // That way we can identify our own responses and ignore them for context
-                        let result = format!("!response:\n{}", result);
+                        let result = format!(".response:\n{}", result);
 
                         let content = RoomMessageEventContent::text_plain(result);
                         room.send(content).await.unwrap();
@@ -456,30 +446,81 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
                 }
                 "help" => {
                     let content = RoomMessageEventContent::text_plain(
-                        "!help\n\nAvailable commands:\n- !party - Start a party!\n- !ollama - Send to the ollama server without context\n- !print - Print the full context of the conversation\n- !help - Print this message",
+                        [
+                            ".help",
+                            "",
+                            "Available commands:",
+                            "- .party - Start a party!",
+                            "- .send - Send this message without context",
+                            "- .print - Print the full context of the conversation",
+                            "- .help - Print this message",
+                            "- .list - List available models",
+                            "- .model - Select a model to use",
+                        ]
+                        .join("\n"),
                     );
                     room.send(content).await.unwrap();
                 }
                 "print" => {
                     // Prints the full context back to the room
-                    let mut context = get_context(&room).await.unwrap();
-                    context.insert_str(0, "!context\n");
+                    let (mut context, _) = get_context(&room).await.unwrap();
+                    context.insert_str(0, ".context\n");
                     let content = RoomMessageEventContent::text_plain(context);
                     room.send(content).await.unwrap();
                 }
+                "model" => {
+                    // Verify the command is fine
+                    // Get the second word in the command
+                    let model = text.split_whitespace().nth(1);
+                    if let Some(model) = model {
+                        // Verify this model is available
+                        let models = AiChat::new("aichat".to_string()).list_models();
+                        if models.contains(&model.to_string()) {
+                            // Set the model
+                            let response = format!(".model set to {}", model);
+                            room.send(RoomMessageEventContent::text_plain(response))
+                                .await
+                                .unwrap();
+                        } else {
+                            let response = format!(
+                                ".model {} not found. Available models:\n\n{}",
+                                model,
+                                models.join("\n")
+                            );
+                            room.send(RoomMessageEventContent::text_plain(response))
+                                .await
+                                .unwrap();
+                        }
+                    } else {
+                        room.send(RoomMessageEventContent::text_plain(
+                            ".error - must choose a model",
+                        ))
+                        .await
+                        .unwrap();
+                    }
+                }
+                "list" => {
+                    let response = format!(
+                        ".models available:\n\n{}",
+                        AiChat::new("aichat".to_string()).list_models().join("\n")
+                    );
+                    room.send(RoomMessageEventContent::text_plain(response))
+                        .await
+                        .unwrap();
+                }
                 _ => {
-                    eprintln!("Unknown command");
+                    eprintln!(".error - Unknown command");
                 }
             }
         }
     } else {
         eprintln!("Received message: {}", text_content.body);
         // If it's not a command, we should send the full context without commands to the ollama server
-        if let Ok(mut context) = get_context(&room).await {
+        if let Ok((mut context, model)) = get_context(&room).await {
             let prefix = format!("Here is the full text of our ongoing conversation. Your name is {}, and your messages are prefixed by {}:. My name is {}, and my messages are prefixed by {}:. Send the next response in this conversation. Do not prefix your response with your name or any other text. Do not greet me again if you've already done so. Send only the text of your response.\n", room.client().user_id().unwrap(), room.client().user_id().unwrap(), event.sender, event.sender);
             context.insert_str(0, &prefix);
-            if let Ok(result) = send_to_ollama_server(context).await {
+            if let Ok(result) = AiChat::default().execute(model, context) {
                 let content = RoomMessageEventContent::text_plain(result);
                 room.send(content).await.unwrap();
             }
@@ -488,15 +529,17 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
 }
 
 fn is_command(text: &str) -> bool {
-    text.starts_with('!') && !text.starts_with("!!")
+    text.starts_with('.') && !text.starts_with("..")
 }
 
 /// Gets the context of the current conversation
-async fn get_context(room: &Room) -> Result<String, ()> {
+/// Also returns the model if one was ever selected
+async fn get_context(room: &Room) -> Result<(String, Option<String>), ()> {
     // Read all the messages in the room, place them into a single string, and print them out
     let mut messages = Vec::new();
 
     let mut options = MessagesOptions::backward();
+    let mut model_response = None;
 
     while let Ok(batch) = room.messages(options).await {
         for message in batch.chunk {
@@ -512,6 +555,16 @@ async fn get_context(room: &Room) -> Result<String, ()> {
                 continue;
             };
             if is_command(&text_content.body) {
+                // if the message is a valid model command, set the model
+                if text_content.body.starts_with(".model") {
+                    let model = text_content.body.split_whitespace().nth(1);
+                    if let Some(model) = model {
+                        let models = AiChat::new("aichat".to_string()).list_models();
+                        if models.contains(&model.to_string()) {
+                            model_response = Some(model.to_string());
+                        }
+                    }
+                }
                 continue;
             }
             // Push the sender and message to the front of the string
@@ -527,38 +580,8 @@ async fn get_context(room: &Room) -> Result<String, ()> {
         }
     }
     // Append the messages into a string with newlines in between, in reverse order
-    Ok(messages.into_iter().rev().collect::<String>())
-}
-
-// Send the current conversation to the configured ollama server
-async fn send_to_ollama_server(input: String) -> Result<String, ()> {
-    let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
-    if config.ollama.is_none() {
-        return Err(());
-    }
-    let ollama = config.ollama.unwrap();
-    if ollama.is_empty() {
-        return Err(());
-    }
-
-    let server = ollama.values().next().unwrap();
-
-    // Just pull the first thing we see
-    let ollama_server = if let Some(endpoint) = &server.endpoint {
-        Ollama::new(endpoint.host.clone(), endpoint.port)
-    } else {
-        Ollama::default()
-    };
-
-    let prompt = input;
-
-    let res = ollama_server
-        .generate(GenerationRequest::new(server.model.clone(), prompt))
-        .await;
-
-    if let Ok(res) = res {
-        // Strip leading and trailing whitespace from res.response
-        return Ok(res.response.trim().to_string());
-    }
-    Err(())
+    Ok((
+        messages.into_iter().rev().collect::<String>(),
+        model_response,
+    ))
 }
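
For reference, a minimal sketch (not part of the patch itself) of how the new AiChat wrapper can be exercised on its own, e.g. from a scratch binary in this repo. It assumes the `aichat` CLI is installed and on PATH, that `src/aichat.rs` sits next to this file as in the diff above, and that the prompt text is purely illustrative; which models come back depends entirely on the local aichat configuration.

// Hedged usage sketch for the AiChat wrapper introduced in src/aichat.rs.
// Assumptions: the `aichat` binary is on PATH; `mod aichat;` resolves to the file added above.
mod aichat;
use aichat::AiChat;

fn main() {
    let chat = AiChat::default(); // shells out to the `aichat` binary

    // This is the call behind the `.list` command.
    let models = chat.list_models();
    println!("available models:\n{}", models.join("\n"));

    // This is the call behind `.send` and plain chat messages; passing `None`
    // lets aichat fall back to its own default model.
    let model = models.first().cloned(); // illustrative: just take the first listed model
    match chat.execute(model, "Say hello in one short sentence.".to_string()) {
        Ok(reply) => println!("{}", reply.trim()),
        Err(()) => eprintln!("aichat returned non-UTF-8 output"),
    }
}

The wrapper deliberately stays a thin `std::process::Command` shim: model selection is just a `--model` flag, so the set of usable backends is whatever the local aichat install supports.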
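Similarly, a small sanity check of the new command-prefix rules, assuming only the `is_command` logic shown above: a single leading '.' marks a command, a doubled '..' escapes it, and because the bot prefixes its own replies with `.response:`, those replies also read as commands and are skipped when get_context rebuilds the conversation. The model name in the assertions is illustrative.

// Sketch of the '.'-prefix command rules from this patch (model name illustrative).
fn is_command(text: &str) -> bool {
    text.starts_with('.') && !text.starts_with("..")
}

fn main() {
    assert!(is_command(".model gpt-4")); // selects a model for later replies
    assert!(is_command(".send hello there")); // sent to aichat without room context
    assert!(is_command(".response:\nhi!")); // bot replies are skipped when building context
    assert!(!is_command("..literal leading dots")); // escaped, treated as normal chat
    assert!(!is_command("hello")); // normal chat, answered with full context
    println!("command-prefix rules hold");
}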