From c8a51df1d97627d0b080be603443801888cbd534 Mon Sep 17 00:00:00 2001 From: neo <1100909+neowu@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:28:51 +0800 Subject: [PATCH] support azure tts --- src/{openai.rs => azure.rs} | 1 + src/{openai => azure}/chatgpt.rs | 14 +++--- src/{openai => azure}/chatgpt_api.rs | 0 src/azure/tts.rs | 41 ++++++++++++++++++ src/command/speak.rs | 31 ++++++++------ src/gcloud.rs | 3 +- src/gcloud/synthesize_api.rs | 41 ------------------ src/gcloud/{synthesize.rs => tts.rs} | 64 ++++++++++++++++++---------- src/llm.rs | 2 +- src/llm/config.rs | 9 +--- src/main.rs | 3 +- src/provider.rs | 9 ++++ src/tts.rs | 56 ++++++++++++++++++++++-- src/tts/config.rs | 14 ------ 14 files changed, 176 insertions(+), 112 deletions(-) rename src/{openai.rs => azure.rs} (74%) rename src/{openai => azure}/chatgpt.rs (96%) rename src/{openai => azure}/chatgpt_api.rs (100%) create mode 100644 src/azure/tts.rs delete mode 100644 src/gcloud/synthesize_api.rs rename src/gcloud/{synthesize.rs => tts.rs} (66%) create mode 100644 src/provider.rs delete mode 100644 src/tts/config.rs diff --git a/src/openai.rs b/src/azure.rs similarity index 74% rename from src/openai.rs rename to src/azure.rs index 1696ef7..375ab0e 100644 --- a/src/openai.rs +++ b/src/azure.rs @@ -1,2 +1,3 @@ pub mod chatgpt; pub mod chatgpt_api; +pub mod tts; diff --git a/src/openai/chatgpt.rs b/src/azure/chatgpt.rs similarity index 96% rename from src/openai/chatgpt.rs rename to src/azure/chatgpt.rs index 38b3a22..a739b02 100644 --- a/src/openai/chatgpt.rs +++ b/src/azure/chatgpt.rs @@ -10,15 +10,15 @@ use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; +use crate::azure::chatgpt_api::ChatRequest; +use crate::azure::chatgpt_api::ChatRequestMessage; +use crate::azure::chatgpt_api::ChatResponse; +use crate::azure::chatgpt_api::Role; +use crate::azure::chatgpt_api::Tool; use crate::llm::function::FunctionStore; use crate::llm::ChatEvent; use crate::llm::ChatHandler; use crate::llm::Usage; -use crate::openai::chatgpt_api::ChatRequest; -use crate::openai::chatgpt_api::ChatRequestMessage; -use crate::openai::chatgpt_api::ChatResponse; -use crate::openai::chatgpt_api::Role; -use crate::openai::chatgpt_api::Tool; use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; @@ -125,8 +125,8 @@ impl ChatGPT { async fn call_api(&mut self) -> Result { let request = ChatRequest { messages: Rc::clone(&self.messages), - temperature: 0.8, - top_p: 0.8, + temperature: 0.7, + top_p: 0.95, stream: true, stop: None, max_tokens: 800, diff --git a/src/openai/chatgpt_api.rs b/src/azure/chatgpt_api.rs similarity index 100% rename from src/openai/chatgpt_api.rs rename to src/azure/chatgpt_api.rs diff --git a/src/azure/tts.rs b/src/azure/tts.rs new file mode 100644 index 0000000..155d1a2 --- /dev/null +++ b/src/azure/tts.rs @@ -0,0 +1,41 @@ +use crate::util::exception::Exception; +use crate::util::http_client; + +pub struct AzureTTS { + pub endpoint: String, + pub resource: String, + pub api_key: String, + pub voice: String, +} + +impl AzureTTS { + pub async fn synthesize(&self, text: &str) -> Result, Exception> { + let body = format!( + r#" + "#, + self.voice + ); + + let response = http_client::http_client() + .post(&self.endpoint) + .header("Ocp-Apim-Subscription-Key", &self.api_key) + .header("User-Agent", &self.resource) + .header("X-Microsoft-OutputFormat", "riff-44100hz-16bit-mono-pcm") + .header("Content-Type", "application/ssml+xml") + .body(body) + .send() + .await?; + + let status = response.status(); + if status != 200 { + let response_text = response.text().await?; + return Err(Exception::ExternalError(format!( + "failed to call azure api, status={status}, response={response_text}" + ))); + } + + Ok(response.bytes().await?.to_vec()) + } +} diff --git a/src/command/speak.rs b/src/command/speak.rs index 8e74f9c..39c81a4 100644 --- a/src/command/speak.rs +++ b/src/command/speak.rs @@ -1,11 +1,15 @@ +use std::env::temp_dir; use std::path::PathBuf; use clap::arg; use clap::Args; +use tokio::fs; use tokio::io::stdin; use tokio::io::AsyncReadExt; +use tokio::process::Command; +use tracing::info; +use uuid::Uuid; -use crate::gcloud::synthesize; use crate::tts; use crate::util::exception::Exception; @@ -30,28 +34,31 @@ impl Speak { return Err(Exception::ValidationError("must specify --stdin or --text".to_string())); } - let config = tts::load(&self.conf).await?; - let model = config - .models - .get(&self.name) - .ok_or_else(|| Exception::ValidationError(format!("can not find model, name={}", self.name)))?; + let speech = tts::load(&self.conf, &self.name).await?; let mut buffer = String::new(); let text = if self.stdin { stdin().read_to_string(&mut buffer).await?; + info!("text={}", buffer); &buffer } else { self.text.as_ref().unwrap() }; - let gcloud = synthesize::GCloud { - endpoint: model.endpoint.to_string(), - project: model.params.get("project").unwrap().to_string(), - voice: model.params.get("voice").unwrap().to_string(), - }; + let audio = speech.synthesize(text).await?; - gcloud.synthesize(text).await?; + play(audio).await?; Ok(()) } } + +async fn play(audio: Vec) -> Result<(), Exception> { + let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4())); + fs::write(&temp_file, &audio).await?; + info!("play audio file, file={}", temp_file.to_string_lossy()); + let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?; + let _ = command.wait().await; + fs::remove_file(temp_file).await?; + Ok(()) +} diff --git a/src/gcloud.rs b/src/gcloud.rs index 7d02bd9..1b284c7 100644 --- a/src/gcloud.rs +++ b/src/gcloud.rs @@ -2,8 +2,7 @@ use std::env; pub mod gemini; mod gemini_api; -pub mod synthesize; -mod synthesize_api; +pub mod tts; pub fn token() -> String { env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env") diff --git a/src/gcloud/synthesize_api.rs b/src/gcloud/synthesize_api.rs deleted file mode 100644 index 941f7f3..0000000 --- a/src/gcloud/synthesize_api.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::borrow::Cow; - -use serde::Deserialize; -use serde::Serialize; - -#[derive(Debug, Serialize)] -pub struct SynthesizeRequest<'a> { - #[serde(rename = "audioConfig")] - pub audio_config: AudioConfig, - pub input: Input<'a>, - pub voice: Voice<'a>, -} - -#[derive(Debug, Serialize)] -pub struct AudioConfig { - #[serde(rename = "audioEncoding")] - pub audio_encoding: String, - #[serde(rename = "effectsProfileId")] - pub effects_profile_id: Vec, - pub pitch: i64, - #[serde(rename = "speakingRate")] - pub speaking_rate: i64, -} - -#[derive(Debug, Serialize)] -pub struct Input<'a> { - pub text: Cow<'a, str>, -} - -#[derive(Debug, Serialize)] -pub struct Voice<'a> { - #[serde(rename = "languageCode")] - pub language_code: String, - pub name: Cow<'a, str>, -} - -#[derive(Debug, Deserialize)] -pub struct SynthesizeResponse { - #[serde(rename = "audioContent")] - pub audio_content: String, -} diff --git a/src/gcloud/synthesize.rs b/src/gcloud/tts.rs similarity index 66% rename from src/gcloud/synthesize.rs rename to src/gcloud/tts.rs index 0dfed97..71172ca 100644 --- a/src/gcloud/synthesize.rs +++ b/src/gcloud/tts.rs @@ -1,32 +1,25 @@ use std::borrow::Cow; -use std::env::temp_dir; use base64::prelude::BASE64_STANDARD; use base64::DecodeError; use base64::Engine; -use tokio::fs; -use tokio::process::Command; +use serde::Deserialize; +use serde::Serialize; use tracing::info; -use uuid::Uuid; use super::token; -use crate::gcloud::synthesize_api::AudioConfig; -use crate::gcloud::synthesize_api::Input; -use crate::gcloud::synthesize_api::SynthesizeRequest; -use crate::gcloud::synthesize_api::SynthesizeResponse; -use crate::gcloud::synthesize_api::Voice; use crate::util::exception::Exception; use crate::util::http_client; use crate::util::json; -pub struct GCloud { +pub struct GCloudTTS { pub endpoint: String, pub project: String, pub voice: String, } -impl GCloud { - pub async fn synthesize(&self, text: &str) -> Result<(), Exception> { +impl GCloudTTS { + pub async fn synthesize(&self, text: &str) -> Result, Exception> { info!("call gcloud synthesize api, endpoint={}", self.endpoint); let request = SynthesizeRequest { audio_config: AudioConfig { @@ -65,20 +58,45 @@ impl GCloud { let response: SynthesizeResponse = json::from_json(&response_body)?; let content = BASE64_STANDARD.decode(response.audio_content)?; - play(content).await?; - - Ok(()) + Ok(content) } } -async fn play(audio: Vec) -> Result<(), Exception> { - let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4())); - fs::write(&temp_file, &audio).await?; - info!("play audio file, file={}", temp_file.to_string_lossy()); - let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?; - let _ = command.wait().await; - fs::remove_file(temp_file).await?; - Ok(()) +#[derive(Debug, Serialize)] +struct SynthesizeRequest<'a> { + #[serde(rename = "audioConfig")] + audio_config: AudioConfig, + input: Input<'a>, + voice: Voice<'a>, +} + +#[derive(Debug, Serialize)] +struct AudioConfig { + #[serde(rename = "audioEncoding")] + audio_encoding: String, + #[serde(rename = "effectsProfileId")] + effects_profile_id: Vec, + pitch: i64, + #[serde(rename = "speakingRate")] + speaking_rate: i64, +} + +#[derive(Debug, Serialize)] +struct Input<'a> { + text: Cow<'a, str>, +} + +#[derive(Debug, Serialize)] +struct Voice<'a> { + #[serde(rename = "languageCode")] + language_code: String, + name: Cow<'a, str>, +} + +#[derive(Debug, Deserialize)] +struct SynthesizeResponse { + #[serde(rename = "audioContent")] + audio_content: String, } impl From for Exception { diff --git a/src/llm.rs b/src/llm.rs index 5a37088..56a5ad7 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -4,9 +4,9 @@ use tokio::fs; use tracing::info; use tracing::warn; +use crate::azure::chatgpt::ChatGPT; use crate::gcloud::gemini::Gemini; use crate::llm::config::Config; -use crate::openai::chatgpt::ChatGPT; use crate::util::exception::Exception; use crate::util::json; diff --git a/src/llm/config.rs b/src/llm/config.rs index fd7701d..d958c88 100644 --- a/src/llm/config.rs +++ b/src/llm/config.rs @@ -4,11 +4,12 @@ use rand::Rng; use serde::Deserialize; use serde_json::json; +use crate::azure::chatgpt::ChatGPT; use crate::gcloud::gemini::Gemini; use crate::llm::function::Function; use crate::llm::function::FunctionStore; use crate::llm::Model; -use crate::openai::chatgpt::ChatGPT; +use crate::provider::Provider; use crate::util::exception::Exception; #[derive(Deserialize, Debug)] @@ -25,12 +26,6 @@ pub struct ModelConfig { pub functions: Vec, } -#[derive(Deserialize, Debug)] -pub enum Provider { - Azure, - GCloud, -} - impl Config { pub fn create(&self, name: &str) -> Result { let config = self diff --git a/src/main.rs b/src/main.rs index 63a3f30..d3e44ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,10 +5,11 @@ use command::generate_zsh_completion::GenerateZshCompletion; use command::speak::Speak; use util::exception::Exception; +mod azure; mod command; mod gcloud; mod llm; -mod openai; +mod provider; mod tts; mod util; diff --git a/src/provider.rs b/src/provider.rs new file mode 100644 index 0000000..ad1ca33 --- /dev/null +++ b/src/provider.rs @@ -0,0 +1,9 @@ +use serde::Deserialize; + +#[derive(Deserialize, Debug)] +pub enum Provider { + #[serde(rename = "azure")] + Azure, + #[serde(rename = "gcloud")] + GCloud, +} diff --git a/src/tts.rs b/src/tts.rs index e3643ad..f68b5da 100644 --- a/src/tts.rs +++ b/src/tts.rs @@ -1,17 +1,65 @@ +use std::collections::HashMap; use std::path::Path; -use config::Config; +use serde::Deserialize; use tokio::fs; use tracing::info; +use crate::azure::tts::AzureTTS; +use crate::gcloud::tts::GCloudTTS; +use crate::provider::Provider; use crate::util::exception::Exception; use crate::util::json; -pub mod config; +#[derive(Deserialize, Debug)] +struct Config { + models: HashMap, +} + +#[derive(Deserialize, Debug)] +struct ModelConfig { + endpoint: String, + provider: Provider, + params: HashMap, +} + +pub enum Speech { + Azure(AzureTTS), + GCloud(GCloudTTS), +} -pub async fn load(path: &Path) -> Result { +impl Speech { + pub async fn synthesize(&self, text: &str) -> Result, Exception> { + match self { + Speech::Azure(model) => model.synthesize(text).await, + Speech::GCloud(model) => model.synthesize(text).await, + } + } +} + +pub async fn load(path: &Path, name: &str) -> Result { info!("load config, path={}", path.to_string_lossy()); let content = fs::read_to_string(path).await?; let config: Config = json::from_json(&content)?; - Ok(config) + + let config = config + .models + .get(name) + .ok_or_else(|| Exception::ValidationError(format!("can not find model, name={name}")))?; + + let model = match config.provider { + Provider::Azure => Speech::Azure(AzureTTS { + endpoint: config.endpoint.to_string(), + resource: config.params.get("resource").unwrap().to_string(), + api_key: config.params.get("api_key").unwrap().to_string(), + voice: config.params.get("voice").unwrap().to_string(), + }), + Provider::GCloud => Speech::GCloud(GCloudTTS { + endpoint: config.endpoint.to_string(), + project: config.params.get("project").unwrap().to_string(), + voice: config.params.get("voice").unwrap().to_string(), + }), + }; + + Ok(model) } diff --git a/src/tts/config.rs b/src/tts/config.rs deleted file mode 100644 index 0ebe85a..0000000 --- a/src/tts/config.rs +++ /dev/null @@ -1,14 +0,0 @@ -use std::collections::HashMap; - -use serde::Deserialize; - -#[derive(Deserialize, Debug)] -pub struct Config { - pub models: HashMap, -} - -#[derive(Deserialize, Debug)] -pub struct ModelConfig { - pub endpoint: String, - pub params: HashMap, -}