feat: use aichat as the chat backend

arcuru committed Mar 22, 2024
1 parent 065272d commit 490dc28

Showing 4 changed files with 138 additions and 81 deletions.
12 changes: 0 additions & 12 deletions Cargo.lock


1 change: 0 additions & 1 deletion Cargo.toml
@@ -18,7 +18,6 @@ anyhow = "1"
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = "0.3.15"
matrix-sdk = "0.7.1"
ollama-rs = "0.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9"
serde_json = "1.0"
47 changes: 47 additions & 0 deletions src/aichat.rs
@@ -0,0 +1,47 @@
use std::process::Command;

pub struct AiChat {
    binary_location: String,
}

impl Default for AiChat {
    fn default() -> Self {
        AiChat::new("aichat".to_string())
    }
}

impl AiChat {
    pub fn new(binary_location: String) -> Self {
        AiChat { binary_location }
    }

    pub fn list_models(&self) -> Vec<String> {
        // Run the binary with the `--list-models` argument
        let output = Command::new(&self.binary_location)
            .arg("--list-models")
            .output()
            .expect("Failed to execute command");

        // Split each line of stdout into its own String, dropping empty lines
        output
            .stdout
            .split(|c| *c == b'\n')
            .map(|s| String::from_utf8(s.to_vec()).unwrap())
            .filter(|s| !s.is_empty())
            .collect()
    }

    pub fn execute(&self, model: Option<String>, prompt: String) -> Result<String, ()> {
        let mut command = Command::new(&self.binary_location);
        if let Some(model) = model {
            command.arg("--model").arg(model);
        }
        command.arg("--").arg(prompt);
        eprintln!("Running command: {:?}", command);

        let output = command.output().expect("Failed to execute command");

        // Return stdout as a String
        String::from_utf8(output.stdout).map_err(|_| ())
    }
}
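
For reference, a minimal sketch of how this wrapper is meant to be driven (a hypothetical caller; it assumes the `aichat` CLI is on PATH with the `--list-models` and `--model` flags used above, and the model name is made up):

    use crate::aichat::AiChat;

    fn demo() -> Result<(), ()> {
        let chat = AiChat::default(); // shells out to the `aichat` binary

        // Enumerate the models the CLI reports, one per line of stdout.
        for model in chat.list_models() {
            eprintln!("available: {}", model);
        }

        // Run a one-shot prompt, optionally pinning a model.
        let reply = chat.execute(Some("openai:gpt-4".to_string()), "Say hello".to_string())?;
        println!("{}", reply);
        Ok(())
    }
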
159 changes: 91 additions & 68 deletions src/main.rs
@@ -1,3 +1,6 @@
mod aichat;
use aichat::AiChat;

use clap::Parser;
use lazy_static::lazy_static;
use matrix_sdk::{
@@ -13,12 +16,10 @@ use matrix_sdk::{
},
Client, Error, LoopCtrl, Room, RoomState,
};
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
fs::File,
io::{self, Read, Write},
path::{Path, PathBuf},
@@ -72,31 +73,16 @@ struct HeadJackArgs {
struct Config {
homeserver_url: String,
username: String,
/// Optionally specify the password, if not set it will be asked for on cmd line
password: Option<String>,
/// Allow list of which accounts we will respond to
allow_list: Option<String>,
ollama: Option<HashMap<String, OllamaConfig>>,
}

#[derive(Debug, Deserialize, Clone)]
struct OllamaConfig {
model: String,
endpoint: Option<EndpointConfig>,
}

#[derive(Debug, Deserialize, Clone)]
struct EndpointConfig {
host: String,
port: u16,
}

lazy_static! {
static ref GLOBAL_CONFIG: Mutex<Option<Config>> = Mutex::new(None);
}

/// This is the starting point of the app. `main` is called by Rust binaries to
/// run the program; in this case, we use tokio (a reactor) so that we can run
/// an `async` main function.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// set up some simple stderr logging. You can configure it by changing the env
@@ -428,58 +414,113 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
return;
}

// If we start with a single '!', interpret as a command
// If we start with a single '.', interpret as a command
let text = text_content.body.trim_start();
eprintln!("Received message: {}", text);
if is_command(text) {
let command = text.split_whitespace().next();
if let Some(command) = command {
// Match on the first word of the body
match &command[1..] {
"party" => {
let content =
RoomMessageEventContent::text_plain("!party\n🎉🎊🥳 let's PARTY!! 🥳🎊🎉");
RoomMessageEventContent::text_plain("🎉🎊🥳 let's PARTY!! 🥳🎊🎉");
// send our message to the room we found the "!party" command in
room.send(content).await.unwrap();
}
"ollama" => {
// Send just this 1 message to the ollama server
let input = text_content.body.trim_start_matches("!ollama").trim();
"send" => {
// Send just this message with no context
let input = text_content.body.trim_start_matches(".send").trim();

if let Ok(result) = send_to_ollama_server(input.to_string()).await {
// Add the prefix "!response:\n" to the result
// But we need to read the context to figure out the model to use
let (_, model) = get_context(&room).await.unwrap();

if let Ok(result) = AiChat::default().execute(model, input.to_string()) {
// Add the prefix ".response:\n" to the result
// That way we can identify our own responses and ignore them for context
let result = format!("!response:\n{}", result);
let result = format!(".response:\n{}", result);
let content = RoomMessageEventContent::text_plain(result);

room.send(content).await.unwrap();
}
}
"help" => {
let content = RoomMessageEventContent::text_plain(
"!help\n\nAvailable commands:\n- !party - Start a party!\n- !ollama <input> - Send <input> to the ollama server without context\n- !print - Print the full context of the conversation\n- !help - Print this message",
[
".help",
"",
"Available commands:",
"- .party - Start a party!",
"- .send <message> - Send this message without context",
"- .print - Print the full context of the conversation",
"- .help - Print this message",
"- .list - List available models",
"- .model <model> - Select a model to use",
]
.join("\n"),
);
room.send(content).await.unwrap();
}
"print" => {
// Prints the full context back to the room
let mut context = get_context(&room).await.unwrap();
context.insert_str(0, "!context\n");
let (mut context, _) = get_context(&room).await.unwrap();
context.insert_str(0, ".context\n");
let content = RoomMessageEventContent::text_plain(context);
room.send(content).await.unwrap();
}
"model" => {
// Verify the command is fine
// Get the second word in the command
let model = text.split_whitespace().nth(1);
if let Some(model) = model {
// Verify this model is available
let models = AiChat::new("aichat".to_string()).list_models();
if models.contains(&model.to_string()) {
// Set the model
let response = format!(".model set to {}", model);
room.send(RoomMessageEventContent::text_plain(response))
.await
.unwrap();
} else {
let response = format!(
".model {} not found. Available models:\n\n{}",
model,
models.join("\n")
);
room.send(RoomMessageEventContent::text_plain(response))
.await
.unwrap();
}
} else {
room.send(RoomMessageEventContent::text_plain(
".error - must choose a model",
))
.await
.unwrap();
}
}
"list" => {
let response = format!(
".models available:\n\n{}",
AiChat::new("aichat".to_string()).list_models().join("\n")
);
room.send(RoomMessageEventContent::text_plain(response))
.await
.unwrap();
}
_ => {
eprintln!("Unknown command");
eprintln!(".error - Unknown command");
}
}
}
} else {
eprintln!("Received message: {}", text_content.body);
// If it's not a command, send the full conversation context (without commands) to the model
if let Ok(mut context) = get_context(&room).await {
if let Ok((mut context, model)) = get_context(&room).await {
let prefix = format!("Here is the full text of our ongoing conversation. Your name is {}, and your messages are prefixed by {}:. My name is {}, and my messages are prefixed by {}:. Send the next response in this conversation. Do not prefix your response with your name or any other text. Do not greet me again if you've already done so. Send only the text of your response.\n",
room.client().user_id().unwrap(), room.client().user_id().unwrap(), event.sender, event.sender);
context.insert_str(0, &prefix);
if let Ok(result) = send_to_ollama_server(context).await {
if let Ok(result) = AiChat::default().execute(model, context) {
let content = RoomMessageEventContent::text_plain(result);
room.send(content).await.unwrap();
}
@@ -488,15 +529,17 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
}

fn is_command(text: &str) -> bool {
text.starts_with('!') && !text.starts_with("!!")
text.starts_with('.') && !text.starts_with("..")
}

/// Gets the context of the current conversation
async fn get_context(room: &Room) -> Result<String, ()> {
/// Returns a model if it was ever entered
async fn get_context(room: &Room) -> Result<(String, Option<String>), ()> {
// Read all the messages in the room, place them into a single string, and print them out
let mut messages = Vec::new();

let mut options = MessagesOptions::backward();
let mut model_response = None;

while let Ok(batch) = room.messages(options).await {
for message in batch.chunk {
@@ -512,6 +555,16 @@ async fn get_context(room: &Room) -> Result<String, ()> {
continue;
};
if is_command(&text_content.body) {
// if the message is a valid model command, set the model
if text_content.body.starts_with(".model") {
let model = text_content.body.split_whitespace().nth(1);
if let Some(model) = model {
let models = AiChat::new("aichat".to_string()).list_models();
if models.contains(&model.to_string()) {
model_response = Some(model.to_string());
}
}
}
continue;
}
// Push the sender and message to the front of the string
@@ -527,38 +580,8 @@ async fn get_context(room: &Room) -> Result<String, ()> {
}
}
// Append the messages into a string with newlines in between, in reverse order
Ok(messages.into_iter().rev().collect::<String>())
}

// Send the current conversation to the configured ollama server
async fn send_to_ollama_server(input: String) -> Result<String, ()> {
let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
if config.ollama.is_none() {
return Err(());
}
let ollama = config.ollama.unwrap();
if ollama.is_empty() {
return Err(());
}

let server = ollama.values().next().unwrap();

// Just pull the first thing we see
let ollama_server = if let Some(endpoint) = &server.endpoint {
Ollama::new(endpoint.host.clone(), endpoint.port)
} else {
Ollama::default()
};

let prompt = input;

let res = ollama_server
.generate(GenerationRequest::new(server.model.clone(), prompt))
.await;

if let Ok(res) = res {
// Strip leading and trailing whitespace from res.response
return Ok(res.response.trim().to_string());
}
Err(())
Ok((
messages.into_iter().rev().collect::<String>(),
model_response,
))
}
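
A design note worth spelling out: the bot keeps no per-room state. `get_context` re-derives both the conversation and the active model on every message by walking the room history and skipping anything `is_command` matches. A sketch of how a transcript maps to context (usernames hypothetical):

    @alice: .model gpt-4          -> recovered as the active model; excluded from context
    @bot:   .model set to gpt-4   -> excluded (command-prefixed)
    @alice: .send one-off prompt  -> excluded
    @bot:   .response: ...        -> excluded, so `.send` replies never leak into context
    @alice: Hello there!          -> included
    @bot:   Hi! How can I help?   -> included (plain replies are part of the conversation)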
