From 885062510da04a0e9f49a1f51f9102c964db0635 Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Wed, 26 Jun 2024 15:40:47 +0800 Subject: [PATCH] support tts --- Cargo.lock | 10 +++ Cargo.toml | 1 + src/command.rs | 1 + src/command/chat.rs | 19 +++--- src/command/speak.rs | 57 +++++++++++++++++ src/gcloud.rs | 12 +++- src/gcloud/{vertex.rs => gemini.rs} | 47 +++++++------- src/gcloud/{api.rs => gemini_api.rs} | 2 +- src/gcloud/synthesize.rs | 88 +++++++++++++++++++++++++++ src/gcloud/synthesize_api.rs | 41 +++++++++++++ src/{bot.rs => llm.rs} | 18 +++--- src/{bot => llm}/config.rs | 56 ++++++++--------- src/{bot => llm}/function.rs | 2 +- src/main.rs | 13 ++-- src/openai.rs | 2 +- src/openai/chatgpt.rs | 22 +++---- src/openai/{api.rs => chatgpt_api.rs} | 4 +- src/tts.rs | 17 ++++++ src/tts/config.rs | 14 +++++ src/util/exception.rs | 52 +++++++--------- src/util/http_client.rs | 2 +- src/util/json.rs | 7 ++- 22 files changed, 362 insertions(+), 125 deletions(-) create mode 100644 src/command/speak.rs rename src/gcloud/{vertex.rs => gemini.rs} (88%) rename src/gcloud/{api.rs => gemini_api.rs} (99%) create mode 100644 src/gcloud/synthesize.rs create mode 100644 src/gcloud/synthesize_api.rs rename src/{bot.rs => llm.rs} (75%) rename src/{bot => llm}/config.rs (84%) rename src/{bot => llm}/function.rs (95%) rename src/openai/{api.rs => chatgpt_api.rs} (97%) create mode 100644 src/tts.rs create mode 100644 src/tts/config.rs diff --git a/Cargo.lock b/Cargo.lock index 43d23dd..ae547da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -989,6 +989,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -1555,6 +1556,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9782222..eb41969 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ reqwest-eventsource = "0" futures = "0" rand = "0" base64 = "0" +uuid = { version = "1", features = ["v4"] } diff --git a/src/command.rs b/src/command.rs index 598efc3..1fdbb17 100644 --- a/src/command.rs +++ b/src/command.rs @@ -1,2 +1,3 @@ pub mod chat; pub mod generate_zsh_completion; +pub mod speak; diff --git a/src/command/chat.rs b/src/command/chat.rs index 505f555..93d0b77 100644 --- a/src/command/chat.rs +++ b/src/command/chat.rs @@ -1,6 +1,7 @@ use std::io; use std::io::Write; use std::path::Path; +use std::path::PathBuf; use clap::Args; use tokio::io::stdin; @@ -8,17 +9,17 @@ use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tracing::info; -use crate::bot; -use crate::bot::ChatEvent; -use crate::bot::ChatHandler; +use crate::llm; +use crate::llm::ChatEvent; +use crate::llm::ChatHandler; use crate::util::exception::Exception; #[derive(Args)] pub struct Chat { #[arg(long, help = "conf path")] - conf: String, + conf: PathBuf, - #[arg(long, help = "bot name")] + #[arg(long, help = "model name")] name: String, } @@ -46,8 +47,8 @@ impl ChatHandler for ConsoleHandler { impl Chat { pub async fn execute(&self) -> Result<(), Exception> { - let config = bot::load(Path::new(&self.conf)).await?; - let mut bot = config.create(&self.name)?; + let config = llm::load(&self.conf).await?; + let mut model = config.create(&self.name)?; let handler = ConsoleHandler {}; let reader = BufReader::new(stdin()); @@ -60,9 +61,9 @@ impl Chat { break; } if line.starts_with("/file ") { - bot.file(Path::new(line.strip_prefix("/file ").unwrap()))?; + model.file(Path::new(line.strip_prefix("/file ").unwrap()))?; } else { - bot.chat(line, &handler).await?; + model.chat(line, &handler).await?; } } diff --git a/src/command/speak.rs b/src/command/speak.rs new file mode 100644 index 0000000..8e74f9c --- /dev/null +++ b/src/command/speak.rs @@ -0,0 +1,57 @@ +use std::path::PathBuf; + +use clap::arg; +use clap::Args; +use tokio::io::stdin; +use tokio::io::AsyncReadExt; + +use crate::gcloud::synthesize; +use crate::tts; +use crate::util::exception::Exception; + +#[derive(Args)] +pub struct Speak { + #[arg(long, help = "conf path")] + conf: PathBuf, + + #[arg(long, help = "model name")] + name: String, + + #[arg(long, help = "text")] + text: Option, + + #[arg(long, help = "stdin", default_value_t = false)] + stdin: bool, +} + +impl Speak { + pub async fn execute(&self) -> Result<(), Exception> { + if !self.stdin && self.text.is_none() { + return Err(Exception::ValidationError("must specify --stdin or --text".to_string())); + } + + let config = tts::load(&self.conf).await?; + let model = config + .models + .get(&self.name) + .ok_or_else(|| Exception::ValidationError(format!("can not find model, name={}", self.name)))?; + + let mut buffer = String::new(); + let text = if self.stdin { + stdin().read_to_string(&mut buffer).await?; + &buffer + } else { + self.text.as_ref().unwrap() + }; + + let gcloud = synthesize::GCloud { + endpoint: model.endpoint.to_string(), + project: model.params.get("project").unwrap().to_string(), + voice: model.params.get("voice").unwrap().to_string(), + }; + + gcloud.synthesize(text).await?; + + Ok(()) + } +} diff --git a/src/gcloud.rs b/src/gcloud.rs index f9a4625..7d02bd9 100644 --- a/src/gcloud.rs +++ b/src/gcloud.rs @@ -1,2 +1,10 @@ -mod api; -pub mod vertex; +use std::env; + +pub mod gemini; +mod gemini_api; +pub mod synthesize; +mod synthesize_api; + +pub fn token() -> String { + env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env") +} diff --git a/src/gcloud/vertex.rs b/src/gcloud/gemini.rs similarity index 88% rename from src/gcloud/vertex.rs rename to src/gcloud/gemini.rs index 499019f..82a00aa 100644 --- a/src/gcloud/vertex.rs +++ b/src/gcloud/gemini.rs @@ -1,4 +1,3 @@ -use std::env; use std::fs; use std::mem; use std::ops::Not; @@ -14,23 +13,24 @@ use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; use tracing::info; -use super::api::Content; -use super::api::FunctionCall; -use super::api::GenerationConfig; -use super::api::InlineData; -use super::api::Role; -use super::api::StreamGenerateContent; -use super::api::Tool; -use crate::bot::function::FunctionStore; -use crate::bot::ChatEvent; -use crate::bot::ChatHandler; -use crate::bot::Usage; -use crate::gcloud::api::GenerateContentResponse; +use super::gemini_api::Content; +use super::gemini_api::FunctionCall; +use super::gemini_api::GenerationConfig; +use super::gemini_api::InlineData; +use super::gemini_api::Role; +use super::gemini_api::StreamGenerateContent; +use super::gemini_api::Tool; +use super::token; +use crate::gcloud::gemini_api::GenerateContentResponse; +use crate::llm::function::FunctionStore; +use crate::llm::ChatEvent; +use crate::llm::ChatHandler; +use crate::llm::Usage; use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; -pub struct Vertex { +pub struct Gemini { url: String, messages: Rc>, system_message: Option>, @@ -40,7 +40,7 @@ pub struct Vertex { usage: Usage, } -impl Vertex { +impl Gemini { pub fn new( endpoint: String, project: String, @@ -50,7 +50,7 @@ impl Vertex { function_store: FunctionStore, ) -> Self { let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent"); - Vertex { + Gemini { url, messages: Rc::new(vec![]), system_message: system_message.map(|message| Rc::new(Content::new_text(Role::Model, message))), @@ -78,7 +78,7 @@ impl Vertex { pub fn file(&mut self, path: &Path) -> Result<(), Exception> { let extension = path .extension() - .ok_or_else(|| Exception::new(format!("file must have extension, path={}", path.to_string_lossy())))? + .ok_or_else(|| Exception::ValidationError(format!("file must have extension, path={}", path.to_string_lossy())))? .to_str() .unwrap(); let content = fs::read(path)?; @@ -86,7 +86,10 @@ impl Vertex { "jpg" => Ok("image/jpeg".to_string()), "png" => Ok("image/png".to_string()), "pdf" => Ok("application/pdf".to_string()), - _ => Err(Exception::new(format!("not supported extension, path={}", path.to_string_lossy()))), + _ => Err(Exception::ValidationError(format!( + "not supported extension, path={}", + path.to_string_lossy() + ))), }?; info!( "file added, will submit with next message, mime_type={mime_type}, path={}", @@ -182,7 +185,7 @@ impl Vertex { let status = response.status(); if status != 200 { let response_text = response.text().await?; - return Err(Exception::new(format!( + return Err(Exception::ExternalError(format!( "failed to call gcloud api, status={status}, response={response_text}" ))); } @@ -210,7 +213,7 @@ async fn read_response_stream(response: Response, tx: Sender { - return Err(Exception::new(err.to_string())); + return Err(Exception::unexpected(err)); } } } @@ -221,7 +224,3 @@ fn is_valid_json(content: &str) -> bool { let result: serde_json::Result = serde_json::from_str(content); result.is_ok() } - -fn token() -> String { - env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env") -} diff --git a/src/gcloud/api.rs b/src/gcloud/gemini_api.rs similarity index 99% rename from src/gcloud/api.rs rename to src/gcloud/gemini_api.rs index 7809866..c233325 100644 --- a/src/gcloud/api.rs +++ b/src/gcloud/gemini_api.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use serde::Deserialize; use serde::Serialize; -use crate::bot::function::Function; +use crate::llm::function::Function; #[derive(Debug, Serialize)] pub struct StreamGenerateContent { diff --git a/src/gcloud/synthesize.rs b/src/gcloud/synthesize.rs new file mode 100644 index 0000000..0dfed97 --- /dev/null +++ b/src/gcloud/synthesize.rs @@ -0,0 +1,88 @@ +use std::borrow::Cow; +use std::env::temp_dir; + +use base64::prelude::BASE64_STANDARD; +use base64::DecodeError; +use base64::Engine; +use tokio::fs; +use tokio::process::Command; +use tracing::info; +use uuid::Uuid; + +use super::token; +use crate::gcloud::synthesize_api::AudioConfig; +use crate::gcloud::synthesize_api::Input; +use crate::gcloud::synthesize_api::SynthesizeRequest; +use crate::gcloud::synthesize_api::SynthesizeResponse; +use crate::gcloud::synthesize_api::Voice; +use crate::util::exception::Exception; +use crate::util::http_client; +use crate::util::json; + +pub struct GCloud { + pub endpoint: String, + pub project: String, + pub voice: String, +} + +impl GCloud { + pub async fn synthesize(&self, text: &str) -> Result<(), Exception> { + info!("call gcloud synthesize api, endpoint={}", self.endpoint); + let request = SynthesizeRequest { + audio_config: AudioConfig { + audio_encoding: "LINEAR16".to_string(), + effects_profile_id: vec!["headphone-class-device".to_string()], + pitch: 0, + speaking_rate: 1, + }, + input: Input { text: Cow::from(text) }, + voice: Voice { + language_code: "en-US".to_string(), + name: Cow::from(&self.voice), + }, + }; + + let body = json::to_json(&request)?; + let response = http_client::http_client() + .post(&self.endpoint) + .bearer_auth(token()) + .header("x-goog-user-project", &self.project) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(body) + .send() + .await?; + + let status = response.status(); + if status != 200 { + let response_text = response.text().await?; + return Err(Exception::ExternalError(format!( + "failed to call gcloud api, status={status}, response={response_text}" + ))); + } + + let response_body = response.text_with_charset("utf-8").await?; + let response: SynthesizeResponse = json::from_json(&response_body)?; + let content = BASE64_STANDARD.decode(response.audio_content)?; + + play(content).await?; + + Ok(()) + } +} + +async fn play(audio: Vec) -> Result<(), Exception> { + let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4())); + fs::write(&temp_file, &audio).await?; + info!("play audio file, file={}", temp_file.to_string_lossy()); + let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?; + let _ = command.wait().await; + fs::remove_file(temp_file).await?; + Ok(()) +} + +impl From for Exception { + fn from(err: DecodeError) -> Self { + Exception::unexpected(err) + } +} diff --git a/src/gcloud/synthesize_api.rs b/src/gcloud/synthesize_api.rs new file mode 100644 index 0000000..941f7f3 --- /dev/null +++ b/src/gcloud/synthesize_api.rs @@ -0,0 +1,41 @@ +use std::borrow::Cow; + +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct SynthesizeRequest<'a> { + #[serde(rename = "audioConfig")] + pub audio_config: AudioConfig, + pub input: Input<'a>, + pub voice: Voice<'a>, +} + +#[derive(Debug, Serialize)] +pub struct AudioConfig { + #[serde(rename = "audioEncoding")] + pub audio_encoding: String, + #[serde(rename = "effectsProfileId")] + pub effects_profile_id: Vec, + pub pitch: i64, + #[serde(rename = "speakingRate")] + pub speaking_rate: i64, +} + +#[derive(Debug, Serialize)] +pub struct Input<'a> { + pub text: Cow<'a, str>, +} + +#[derive(Debug, Serialize)] +pub struct Voice<'a> { + #[serde(rename = "languageCode")] + pub language_code: String, + pub name: Cow<'a, str>, +} + +#[derive(Debug, Deserialize)] +pub struct SynthesizeResponse { + #[serde(rename = "audioContent")] + pub audio_content: String, +} diff --git a/src/bot.rs b/src/llm.rs similarity index 75% rename from src/bot.rs rename to src/llm.rs index b3b18ad..5a37088 100644 --- a/src/bot.rs +++ b/src/llm.rs @@ -4,8 +4,8 @@ use tokio::fs; use tracing::info; use tracing::warn; -use crate::bot::config::Config; -use crate::gcloud::vertex::Vertex; +use crate::gcloud::gemini::Gemini; +use crate::llm::config::Config; use crate::openai::chatgpt::ChatGPT; use crate::util::exception::Exception; use crate::util::json; @@ -29,26 +29,26 @@ pub struct Usage { pub response_tokens: i32, } -pub enum Bot { +pub enum Model { ChatGPT(ChatGPT), - Vertex(Vertex), + Gemini(Gemini), } -impl Bot { +impl Model { pub async fn chat(&mut self, message: String, handler: &impl ChatHandler) -> Result<(), Exception> { match self { - Bot::ChatGPT(bot) => bot.chat(message, handler).await, - Bot::Vertex(bot) => bot.chat(message, handler).await, + Model::ChatGPT(model) => model.chat(message, handler).await, + Model::Gemini(model) => model.chat(message, handler).await, } } pub fn file(&mut self, path: &Path) -> Result<(), Exception> { match self { - Bot::ChatGPT(_bot) => { + Model::ChatGPT(_model) => { warn!("ChatGPT does not support uploading file"); Ok(()) } - Bot::Vertex(bot) => bot.file(path), + Model::Gemini(model) => model.file(path), } } } diff --git a/src/bot/config.rs b/src/llm/config.rs similarity index 84% rename from src/bot/config.rs rename to src/llm/config.rs index 7f48dcf..fd7701d 100644 --- a/src/bot/config.rs +++ b/src/llm/config.rs @@ -4,36 +4,51 @@ use rand::Rng; use serde::Deserialize; use serde_json::json; -use crate::bot::function::Function; -use crate::bot::function::FunctionStore; -use crate::bot::Bot; -use crate::gcloud::vertex::Vertex; +use crate::gcloud::gemini::Gemini; +use crate::llm::function::Function; +use crate::llm::function::FunctionStore; +use crate::llm::Model; use crate::openai::chatgpt::ChatGPT; use crate::util::exception::Exception; #[derive(Deserialize, Debug)] pub struct Config { - pub bots: HashMap, + pub models: HashMap, +} + +#[derive(Deserialize, Debug)] +pub struct ModelConfig { + pub endpoint: String, + pub provider: Provider, + pub system_message: Option, + pub params: HashMap, + pub functions: Vec, +} + +#[derive(Deserialize, Debug)] +pub enum Provider { + Azure, + GCloud, } impl Config { - pub fn create(&self, name: &str) -> Result { + pub fn create(&self, name: &str) -> Result { let config = self - .bots + .models .get(name) - .ok_or_else(|| Exception::new(format!("can not find bot, name={name}")))?; + .ok_or_else(|| Exception::ValidationError(format!("can not find model, name={name}")))?; let function_store = load_function_store(config); - let bot = match config.r#type { - BotType::Azure => Bot::ChatGPT(ChatGPT::new( + let model = match config.provider { + Provider::Azure => Model::ChatGPT(ChatGPT::new( config.endpoint.to_string(), config.params.get("model").unwrap().to_string(), config.params.get("api_key").unwrap().to_string(), config.system_message.clone(), function_store, )), - BotType::GCloud => Bot::Vertex(Vertex::new( + Provider::GCloud => Model::Gemini(Gemini::new( config.endpoint.to_string(), config.params.get("project").unwrap().to_string(), config.params.get("location").unwrap().to_string(), @@ -43,26 +58,11 @@ impl Config { )), }; - Ok(bot) + Ok(model) } } -#[derive(Deserialize, Debug)] -pub struct BotConfig { - pub endpoint: String, - pub r#type: BotType, - pub system_message: Option, - pub params: HashMap, - pub functions: Vec, -} - -#[derive(Deserialize, Debug)] -pub enum BotType { - Azure, - GCloud, -} - -fn load_function_store(config: &BotConfig) -> FunctionStore { +fn load_function_store(config: &ModelConfig) -> FunctionStore { let mut function_store = FunctionStore::new(); for function in &config.functions { if let "get_random_number" = function.as_str() { diff --git a/src/bot/function.rs b/src/llm/function.rs similarity index 95% rename from src/bot/function.rs rename to src/llm/function.rs index c6f4acb..f0f7292 100644 --- a/src/bot/function.rs +++ b/src/llm/function.rs @@ -64,7 +64,7 @@ impl FunctionStore { let function = self .implementations .get(name) - .ok_or_else(|| Exception::new(format!("function not found, name={name}")))?; + .ok_or_else(|| Exception::ValidationError(format!("function not found, name={name}")))?; Ok(Arc::clone(function)) } } diff --git a/src/main.rs b/src/main.rs index 80b4632..63a3f30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,17 +2,19 @@ use clap::Parser; use clap::Subcommand; use command::chat::Chat; use command::generate_zsh_completion::GenerateZshCompletion; +use command::speak::Speak; use util::exception::Exception; -mod bot; mod command; mod gcloud; +mod llm; mod openai; +mod tts; mod util; #[derive(Parser)] #[command(author, version)] -#[command(about = "Puppet AI")] +#[command(about = "puppet ai")] pub struct Cli { #[command(subcommand)] command: Option, @@ -21,9 +23,11 @@ pub struct Cli { #[derive(Subcommand)] #[command(arg_required_else_help(true))] pub enum Command { - #[command(about = "Chat")] + #[command(about = "chat")] Chat(Chat), - #[command(about = "Generate zsh completion")] + #[command(about = "speak")] + Speech(Speak), + #[command(about = "generate zsh completion")] GenerateZshCompletion(GenerateZshCompletion), } @@ -33,6 +37,7 @@ async fn main() -> Result<(), Exception> { let cli = Cli::parse(); match cli.command { Some(Command::Chat(command)) => command.execute().await, + Some(Command::Speech(command)) => command.execute().await, Some(Command::GenerateZshCompletion(command)) => command.execute(), None => panic!("not implemented"), } diff --git a/src/openai.rs b/src/openai.rs index e4ed5ac..1696ef7 100644 --- a/src/openai.rs +++ b/src/openai.rs @@ -1,2 +1,2 @@ -pub mod api; pub mod chatgpt; +pub mod chatgpt_api; diff --git a/src/openai/chatgpt.rs b/src/openai/chatgpt.rs index 018e0e4..38b3a22 100644 --- a/src/openai/chatgpt.rs +++ b/src/openai/chatgpt.rs @@ -10,15 +10,15 @@ use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; -use crate::bot::function::FunctionStore; -use crate::bot::ChatEvent; -use crate::bot::ChatHandler; -use crate::bot::Usage; -use crate::openai::api::ChatRequest; -use crate::openai::api::ChatRequestMessage; -use crate::openai::api::ChatResponse; -use crate::openai::api::Role; -use crate::openai::api::Tool; +use crate::llm::function::FunctionStore; +use crate::llm::ChatEvent; +use crate::llm::ChatHandler; +use crate::llm::Usage; +use crate::openai::chatgpt_api::ChatRequest; +use crate::openai::chatgpt_api::ChatRequestMessage; +use crate::openai::chatgpt_api::ChatResponse; +use crate::openai::chatgpt_api::Role; +use crate::openai::chatgpt_api::Tool; use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; @@ -149,7 +149,7 @@ impl ChatGPT { impl From for Exception { fn from(err: CannotCloneRequestError) -> Self { - Exception::new(err.to_string()) + Exception::unexpected(err) } } @@ -184,7 +184,7 @@ async fn read_event_source(mut source: EventSource, tx: Sender) - } Err(err) => { source.close(); - return Err(Exception::new(err.to_string())); + return Err(Exception::unexpected(err)); } } } diff --git a/src/openai/api.rs b/src/openai/chatgpt_api.rs similarity index 97% rename from src/openai/api.rs rename to src/openai/chatgpt_api.rs index 4074928..80fa388 100644 --- a/src/openai/api.rs +++ b/src/openai/chatgpt_api.rs @@ -4,7 +4,7 @@ use std::rc::Rc; use serde::Deserialize; use serde::Serialize; -use crate::bot::function::Function; +use crate::llm::function::Function; #[derive(Debug, Serialize)] pub struct ChatRequest { @@ -93,6 +93,7 @@ pub enum Role { Tool, } +#[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct ChatResponse { pub id: String, @@ -102,6 +103,7 @@ pub struct ChatResponse { pub choices: Vec, } +#[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct ChatCompletionChoice { pub index: i64, diff --git a/src/tts.rs b/src/tts.rs new file mode 100644 index 0000000..e3643ad --- /dev/null +++ b/src/tts.rs @@ -0,0 +1,17 @@ +use std::path::Path; + +use config::Config; +use tokio::fs; +use tracing::info; + +use crate::util::exception::Exception; +use crate::util::json; + +pub mod config; + +pub async fn load(path: &Path) -> Result { + info!("load config, path={}", path.to_string_lossy()); + let content = fs::read_to_string(path).await?; + let config: Config = json::from_json(&content)?; + Ok(config) +} diff --git a/src/tts/config.rs b/src/tts/config.rs new file mode 100644 index 0000000..0ebe85a --- /dev/null +++ b/src/tts/config.rs @@ -0,0 +1,14 @@ +use std::collections::HashMap; + +use serde::Deserialize; + +#[derive(Deserialize, Debug)] +pub struct Config { + pub models: HashMap, +} + +#[derive(Deserialize, Debug)] +pub struct ModelConfig { + pub endpoint: String, + pub params: HashMap, +} diff --git a/src/util/exception.rs b/src/util/exception.rs index 882c8ce..b48db21 100644 --- a/src/util/exception.rs +++ b/src/util/exception.rs @@ -6,35 +6,29 @@ use std::io; use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; -pub struct Exception { - message: String, - context: Option, - trace: String, +pub enum Exception { + ValidationError(String), + ExternalError(String), + Unexpected { message: String, trace: String }, } impl Exception { - pub fn new(message: String) -> Self { - Exception::create(message, None) - } - - pub fn from(error: T) -> Self + pub fn unexpected(error: T) -> Self where - T: Error + 'static, + T: std::error::Error, { - Exception::create(error.to_string(), None) + Self::Unexpected { + message: error.to_string(), + trace: Backtrace::force_capture().to_string(), + } } - pub fn from_with_context(error: T, context: String) -> Self + pub fn unexpected_with_context(error: T, context: &str) -> Self where - T: Error + 'static, + T: Error, { - Exception::create(error.to_string(), Some(context)) - } - - fn create(message: String, context: Option) -> Self { - Self { - message, - context, + Self::Unexpected { + message: format!("error={}, context={}", error, context), trace: Backtrace::force_capture().to_string(), } } @@ -42,13 +36,11 @@ impl Exception { impl fmt::Debug for Exception { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Exception: {}\nContext: {}\nTrace:\n{}", - self.message, - self.context.as_ref().unwrap_or(&"".to_string()), - self.trace - ) + match self { + Exception::ValidationError(message) => write!(f, "{}", message), + Exception::ExternalError(message) => write!(f, "{}", message), + Exception::Unexpected { message, trace } => write!(f, "{}\ntrace:\n{}", message, trace), + } } } @@ -62,18 +54,18 @@ impl Error for Exception {} impl From for Exception { fn from(err: io::Error) -> Self { - Exception::new(err.to_string()) + Exception::unexpected(err) } } impl From for Exception { fn from(err: JoinError) -> Self { - Exception::new(err.to_string()) + Exception::unexpected(err) } } impl From> for Exception { fn from(err: SendError) -> Self { - Exception::new(err.to_string()) + Exception::unexpected(err) } } diff --git a/src/util/http_client.rs b/src/util/http_client.rs index 4edbac4..d17d109 100644 --- a/src/util/http_client.rs +++ b/src/util/http_client.rs @@ -10,6 +10,6 @@ pub fn http_client() -> &'static reqwest::Client { impl From for Exception { fn from(err: reqwest::Error) -> Self { let url = err.url().map_or("", |url| url.as_str()).to_string(); - Exception::from_with_context(err, format!("url={url}")) + Exception::unexpected_with_context(err, &format!("url={url}")) } } diff --git a/src/util/json.rs b/src/util/json.rs index 04a6323..a82dfc6 100644 --- a/src/util/json.rs +++ b/src/util/json.rs @@ -1,19 +1,20 @@ use std::fmt; -use crate::util::exception::Exception; use serde::de; use serde::Serialize; +use crate::util::exception::Exception; + pub fn from_json<'a, T>(json: &'a str) -> Result where T: de::Deserialize<'a>, { - serde_json::from_str(json).map_err(|err| Exception::from_with_context(err, format!("json={json}"))) + serde_json::from_str(json).map_err(|err| Exception::unexpected_with_context(err, &format!("json={json}"))) } pub fn to_json(object: &T) -> Result where T: Serialize + fmt::Debug, { - serde_json::to_string(object).map_err(|err| Exception::from_with_context(err, format!("object={object:?}"))) + serde_json::to_string(object).map_err(|err| Exception::unexpected_with_context(err, &format!("object={object:?}"))) }