Skip to content

Commit

Permalink
support azure tts
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jun 27, 2024
1 parent 8850625 commit c8a51df
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 112 deletions.
1 change: 1 addition & 0 deletions src/openai.rs → src/azure.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod chatgpt;
pub mod chatgpt_api;
pub mod tts;
14 changes: 7 additions & 7 deletions src/openai/chatgpt.rs → src/azure/chatgpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;

use crate::azure::chatgpt_api::ChatRequest;
use crate::azure::chatgpt_api::ChatRequestMessage;
use crate::azure::chatgpt_api::ChatResponse;
use crate::azure::chatgpt_api::Role;
use crate::azure::chatgpt_api::Tool;
use crate::llm::function::FunctionStore;
use crate::llm::ChatEvent;
use crate::llm::ChatHandler;
use crate::llm::Usage;
use crate::openai::chatgpt_api::ChatRequest;
use crate::openai::chatgpt_api::ChatRequestMessage;
use crate::openai::chatgpt_api::ChatResponse;
use crate::openai::chatgpt_api::Role;
use crate::openai::chatgpt_api::Tool;
use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;
Expand Down Expand Up @@ -125,8 +125,8 @@ impl ChatGPT {
async fn call_api(&mut self) -> Result<EventSource, Exception> {
let request = ChatRequest {
messages: Rc::clone(&self.messages),
temperature: 0.8,
top_p: 0.8,
temperature: 0.7,
top_p: 0.95,
stream: true,
stop: None,
max_tokens: 800,
Expand Down
File renamed without changes.
41 changes: 41 additions & 0 deletions src/azure/tts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use crate::util::exception::Exception;
use crate::util::http_client;

pub struct AzureTTS {
pub endpoint: String,
pub resource: String,
pub api_key: String,
pub voice: String,
}

impl AzureTTS {
pub async fn synthesize(&self, text: &str) -> Result<Vec<u8>, Exception> {
let body = format!(
r#"<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xmlns:mstts="https://www.w3.org/2001/mstts" xml:lang="en-US">
<voice name="{}"><mstts:express-as style="narration-relaxed"><![CDATA[
{text}
]]></mstts:express-as></voice></speak>"#,
self.voice
);

let response = http_client::http_client()
.post(&self.endpoint)
.header("Ocp-Apim-Subscription-Key", &self.api_key)
.header("User-Agent", &self.resource)
.header("X-Microsoft-OutputFormat", "riff-44100hz-16bit-mono-pcm")
.header("Content-Type", "application/ssml+xml")
.body(body)
.send()
.await?;

let status = response.status();
if status != 200 {
let response_text = response.text().await?;
return Err(Exception::ExternalError(format!(
"failed to call azure api, status={status}, response={response_text}"
)));
}

Ok(response.bytes().await?.to_vec())
}
}
31 changes: 19 additions & 12 deletions src/command/speak.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::env::temp_dir;
use std::path::PathBuf;

use clap::arg;
use clap::Args;
use tokio::fs;
use tokio::io::stdin;
use tokio::io::AsyncReadExt;
use tokio::process::Command;
use tracing::info;
use uuid::Uuid;

use crate::gcloud::synthesize;
use crate::tts;
use crate::util::exception::Exception;

Expand All @@ -30,28 +34,31 @@ impl Speak {
return Err(Exception::ValidationError("must specify --stdin or --text".to_string()));
}

let config = tts::load(&self.conf).await?;
let model = config
.models
.get(&self.name)
.ok_or_else(|| Exception::ValidationError(format!("can not find model, name={}", self.name)))?;
let speech = tts::load(&self.conf, &self.name).await?;

let mut buffer = String::new();
let text = if self.stdin {
stdin().read_to_string(&mut buffer).await?;
info!("text={}", buffer);
&buffer
} else {
self.text.as_ref().unwrap()
};

let gcloud = synthesize::GCloud {
endpoint: model.endpoint.to_string(),
project: model.params.get("project").unwrap().to_string(),
voice: model.params.get("voice").unwrap().to_string(),
};
let audio = speech.synthesize(text).await?;

gcloud.synthesize(text).await?;
play(audio).await?;

Ok(())
}
}

async fn play(audio: Vec<u8>) -> Result<(), Exception> {
let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4()));
fs::write(&temp_file, &audio).await?;
info!("play audio file, file={}", temp_file.to_string_lossy());
let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?;
let _ = command.wait().await;
fs::remove_file(temp_file).await?;
Ok(())
}
3 changes: 1 addition & 2 deletions src/gcloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::env;

pub mod gemini;
mod gemini_api;
pub mod synthesize;
mod synthesize_api;
pub mod tts;

pub fn token() -> String {
env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env")
Expand Down
41 changes: 0 additions & 41 deletions src/gcloud/synthesize_api.rs

This file was deleted.

64 changes: 41 additions & 23 deletions src/gcloud/synthesize.rs → src/gcloud/tts.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
use std::borrow::Cow;
use std::env::temp_dir;

use base64::prelude::BASE64_STANDARD;
use base64::DecodeError;
use base64::Engine;
use tokio::fs;
use tokio::process::Command;
use serde::Deserialize;
use serde::Serialize;
use tracing::info;
use uuid::Uuid;

use super::token;
use crate::gcloud::synthesize_api::AudioConfig;
use crate::gcloud::synthesize_api::Input;
use crate::gcloud::synthesize_api::SynthesizeRequest;
use crate::gcloud::synthesize_api::SynthesizeResponse;
use crate::gcloud::synthesize_api::Voice;
use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;

pub struct GCloud {
pub struct GCloudTTS {
pub endpoint: String,
pub project: String,
pub voice: String,
}

impl GCloud {
pub async fn synthesize(&self, text: &str) -> Result<(), Exception> {
impl GCloudTTS {
pub async fn synthesize(&self, text: &str) -> Result<Vec<u8>, Exception> {
info!("call gcloud synthesize api, endpoint={}", self.endpoint);
let request = SynthesizeRequest {
audio_config: AudioConfig {
Expand Down Expand Up @@ -65,20 +58,45 @@ impl GCloud {
let response: SynthesizeResponse = json::from_json(&response_body)?;
let content = BASE64_STANDARD.decode(response.audio_content)?;

play(content).await?;

Ok(())
Ok(content)
}
}

async fn play(audio: Vec<u8>) -> Result<(), Exception> {
let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4()));
fs::write(&temp_file, &audio).await?;
info!("play audio file, file={}", temp_file.to_string_lossy());
let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?;
let _ = command.wait().await;
fs::remove_file(temp_file).await?;
Ok(())
#[derive(Debug, Serialize)]
struct SynthesizeRequest<'a> {
#[serde(rename = "audioConfig")]
audio_config: AudioConfig,
input: Input<'a>,
voice: Voice<'a>,
}

#[derive(Debug, Serialize)]
struct AudioConfig {
#[serde(rename = "audioEncoding")]
audio_encoding: String,
#[serde(rename = "effectsProfileId")]
effects_profile_id: Vec<String>,
pitch: i64,
#[serde(rename = "speakingRate")]
speaking_rate: i64,
}

#[derive(Debug, Serialize)]
struct Input<'a> {
text: Cow<'a, str>,
}

#[derive(Debug, Serialize)]
struct Voice<'a> {
#[serde(rename = "languageCode")]
language_code: String,
name: Cow<'a, str>,
}

#[derive(Debug, Deserialize)]
struct SynthesizeResponse {
#[serde(rename = "audioContent")]
audio_content: String,
}

impl From<DecodeError> for Exception {
Expand Down
2 changes: 1 addition & 1 deletion src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use tokio::fs;
use tracing::info;
use tracing::warn;

use crate::azure::chatgpt::ChatGPT;
use crate::gcloud::gemini::Gemini;
use crate::llm::config::Config;
use crate::openai::chatgpt::ChatGPT;
use crate::util::exception::Exception;
use crate::util::json;

Expand Down
9 changes: 2 additions & 7 deletions src/llm/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use rand::Rng;
use serde::Deserialize;
use serde_json::json;

use crate::azure::chatgpt::ChatGPT;
use crate::gcloud::gemini::Gemini;
use crate::llm::function::Function;
use crate::llm::function::FunctionStore;
use crate::llm::Model;
use crate::openai::chatgpt::ChatGPT;
use crate::provider::Provider;
use crate::util::exception::Exception;

#[derive(Deserialize, Debug)]
Expand All @@ -25,12 +26,6 @@ pub struct ModelConfig {
pub functions: Vec<String>,
}

#[derive(Deserialize, Debug)]
pub enum Provider {
Azure,
GCloud,
}

impl Config {
pub fn create(&self, name: &str) -> Result<Model, Exception> {
let config = self
Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use command::generate_zsh_completion::GenerateZshCompletion;
use command::speak::Speak;
use util::exception::Exception;

mod azure;
mod command;
mod gcloud;
mod llm;
mod openai;
mod provider;
mod tts;
mod util;

Expand Down
9 changes: 9 additions & 0 deletions src/provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use serde::Deserialize;

#[derive(Deserialize, Debug)]
pub enum Provider {
#[serde(rename = "azure")]
Azure,
#[serde(rename = "gcloud")]
GCloud,
}
Loading

0 comments on commit c8a51df

Please sign in to comment.