Skip to content

Commit

Permalink
support tts
Browse files Browse the repository at this point in the history
  • Loading branch information
neowu committed Jun 26, 2024
1 parent 3883559 commit 8850625
Show file tree
Hide file tree
Showing 22 changed files with 362 additions and 125 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ reqwest-eventsource = "0"
futures = "0"
rand = "0"
base64 = "0"
uuid = { version = "1", features = ["v4"] }
1 change: 1 addition & 0 deletions src/command.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod chat;
pub mod generate_zsh_completion;
pub mod speak;
19 changes: 10 additions & 9 deletions src/command/chat.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
use std::io;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;

use clap::Args;
use tokio::io::stdin;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tracing::info;

use crate::bot;
use crate::bot::ChatEvent;
use crate::bot::ChatHandler;
use crate::llm;
use crate::llm::ChatEvent;
use crate::llm::ChatHandler;
use crate::util::exception::Exception;

#[derive(Args)]
pub struct Chat {
#[arg(long, help = "conf path")]
conf: String,
conf: PathBuf,

#[arg(long, help = "bot name")]
#[arg(long, help = "model name")]
name: String,
}

Expand Down Expand Up @@ -46,8 +47,8 @@ impl ChatHandler for ConsoleHandler {

impl Chat {
pub async fn execute(&self) -> Result<(), Exception> {
let config = bot::load(Path::new(&self.conf)).await?;
let mut bot = config.create(&self.name)?;
let config = llm::load(&self.conf).await?;
let mut model = config.create(&self.name)?;
let handler = ConsoleHandler {};

let reader = BufReader::new(stdin());
Expand All @@ -60,9 +61,9 @@ impl Chat {
break;
}
if line.starts_with("/file ") {
bot.file(Path::new(line.strip_prefix("/file ").unwrap()))?;
model.file(Path::new(line.strip_prefix("/file ").unwrap()))?;
} else {
bot.chat(line, &handler).await?;
model.chat(line, &handler).await?;
}
}

Expand Down
57 changes: 57 additions & 0 deletions src/command/speak.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::path::PathBuf;

use clap::arg;
use clap::Args;
use tokio::io::stdin;
use tokio::io::AsyncReadExt;

use crate::gcloud::synthesize;
use crate::tts;
use crate::util::exception::Exception;

#[derive(Args)]
pub struct Speak {
#[arg(long, help = "conf path")]
conf: PathBuf,

#[arg(long, help = "model name")]
name: String,

#[arg(long, help = "text")]
text: Option<String>,

#[arg(long, help = "stdin", default_value_t = false)]
stdin: bool,
}

impl Speak {
pub async fn execute(&self) -> Result<(), Exception> {
if !self.stdin && self.text.is_none() {
return Err(Exception::ValidationError("must specify --stdin or --text".to_string()));
}

let config = tts::load(&self.conf).await?;
let model = config
.models
.get(&self.name)
.ok_or_else(|| Exception::ValidationError(format!("can not find model, name={}", self.name)))?;

let mut buffer = String::new();
let text = if self.stdin {
stdin().read_to_string(&mut buffer).await?;
&buffer
} else {
self.text.as_ref().unwrap()
};

let gcloud = synthesize::GCloud {
endpoint: model.endpoint.to_string(),
project: model.params.get("project").unwrap().to_string(),
voice: model.params.get("voice").unwrap().to_string(),
};

gcloud.synthesize(text).await?;

Ok(())
}
}
12 changes: 10 additions & 2 deletions src/gcloud.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
mod api;
pub mod vertex;
use std::env;

pub mod gemini;
mod gemini_api;
pub mod synthesize;
mod synthesize_api;

pub fn token() -> String {
env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env")
}
47 changes: 23 additions & 24 deletions src/gcloud/vertex.rs → src/gcloud/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::env;
use std::fs;
use std::mem;
use std::ops::Not;
Expand All @@ -14,23 +13,24 @@ use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tracing::info;

use super::api::Content;
use super::api::FunctionCall;
use super::api::GenerationConfig;
use super::api::InlineData;
use super::api::Role;
use super::api::StreamGenerateContent;
use super::api::Tool;
use crate::bot::function::FunctionStore;
use crate::bot::ChatEvent;
use crate::bot::ChatHandler;
use crate::bot::Usage;
use crate::gcloud::api::GenerateContentResponse;
use super::gemini_api::Content;
use super::gemini_api::FunctionCall;
use super::gemini_api::GenerationConfig;
use super::gemini_api::InlineData;
use super::gemini_api::Role;
use super::gemini_api::StreamGenerateContent;
use super::gemini_api::Tool;
use super::token;
use crate::gcloud::gemini_api::GenerateContentResponse;
use crate::llm::function::FunctionStore;
use crate::llm::ChatEvent;
use crate::llm::ChatHandler;
use crate::llm::Usage;
use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;

pub struct Vertex {
pub struct Gemini {
url: String,
messages: Rc<Vec<Content>>,
system_message: Option<Rc<Content>>,
Expand All @@ -40,7 +40,7 @@ pub struct Vertex {
usage: Usage,
}

impl Vertex {
impl Gemini {
pub fn new(
endpoint: String,
project: String,
Expand All @@ -50,7 +50,7 @@ impl Vertex {
function_store: FunctionStore,
) -> Self {
let url = format!("{endpoint}/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:streamGenerateContent");
Vertex {
Gemini {
url,
messages: Rc::new(vec![]),
system_message: system_message.map(|message| Rc::new(Content::new_text(Role::Model, message))),
Expand Down Expand Up @@ -78,15 +78,18 @@ impl Vertex {
pub fn file(&mut self, path: &Path) -> Result<(), Exception> {
let extension = path
.extension()
.ok_or_else(|| Exception::new(format!("file must have extension, path={}", path.to_string_lossy())))?
.ok_or_else(|| Exception::ValidationError(format!("file must have extension, path={}", path.to_string_lossy())))?
.to_str()
.unwrap();
let content = fs::read(path)?;
let mime_type = match extension {
"jpg" => Ok("image/jpeg".to_string()),
"png" => Ok("image/png".to_string()),
"pdf" => Ok("application/pdf".to_string()),
_ => Err(Exception::new(format!("not supported extension, path={}", path.to_string_lossy()))),
_ => Err(Exception::ValidationError(format!(
"not supported extension, path={}",
path.to_string_lossy()
))),
}?;
info!(
"file added, will submit with next message, mime_type={mime_type}, path={}",
Expand Down Expand Up @@ -182,7 +185,7 @@ impl Vertex {
let status = response.status();
if status != 200 {
let response_text = response.text().await?;
return Err(Exception::new(format!(
return Err(Exception::ExternalError(format!(
"failed to call gcloud api, status={status}, response={response_text}"
)));
}
Expand Down Expand Up @@ -210,7 +213,7 @@ async fn read_response_stream(response: Response, tx: Sender<GenerateContentResp
buffer.clear();
}
Err(err) => {
return Err(Exception::new(err.to_string()));
return Err(Exception::unexpected(err));
}
}
}
Expand All @@ -221,7 +224,3 @@ fn is_valid_json(content: &str) -> bool {
let result: serde_json::Result<serde::de::IgnoredAny> = serde_json::from_str(content);
result.is_ok()
}

fn token() -> String {
env::var("GCLOUD_AUTH_TOKEN").expect("please set GCLOUD_AUTH_TOKEN env")
}
2 changes: 1 addition & 1 deletion src/gcloud/api.rs → src/gcloud/gemini_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::rc::Rc;
use serde::Deserialize;
use serde::Serialize;

use crate::bot::function::Function;
use crate::llm::function::Function;

#[derive(Debug, Serialize)]
pub struct StreamGenerateContent {
Expand Down
88 changes: 88 additions & 0 deletions src/gcloud/synthesize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::borrow::Cow;
use std::env::temp_dir;

use base64::prelude::BASE64_STANDARD;
use base64::DecodeError;
use base64::Engine;
use tokio::fs;
use tokio::process::Command;
use tracing::info;
use uuid::Uuid;

use super::token;
use crate::gcloud::synthesize_api::AudioConfig;
use crate::gcloud::synthesize_api::Input;
use crate::gcloud::synthesize_api::SynthesizeRequest;
use crate::gcloud::synthesize_api::SynthesizeResponse;
use crate::gcloud::synthesize_api::Voice;
use crate::util::exception::Exception;
use crate::util::http_client;
use crate::util::json;

pub struct GCloud {
pub endpoint: String,
pub project: String,
pub voice: String,
}

impl GCloud {
pub async fn synthesize(&self, text: &str) -> Result<(), Exception> {
info!("call gcloud synthesize api, endpoint={}", self.endpoint);
let request = SynthesizeRequest {
audio_config: AudioConfig {
audio_encoding: "LINEAR16".to_string(),
effects_profile_id: vec!["headphone-class-device".to_string()],
pitch: 0,
speaking_rate: 1,
},
input: Input { text: Cow::from(text) },
voice: Voice {
language_code: "en-US".to_string(),
name: Cow::from(&self.voice),
},
};

let body = json::to_json(&request)?;
let response = http_client::http_client()
.post(&self.endpoint)
.bearer_auth(token())
.header("x-goog-user-project", &self.project)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(body)
.send()
.await?;

let status = response.status();
if status != 200 {
let response_text = response.text().await?;
return Err(Exception::ExternalError(format!(
"failed to call gcloud api, status={status}, response={response_text}"
)));
}

let response_body = response.text_with_charset("utf-8").await?;
let response: SynthesizeResponse = json::from_json(&response_body)?;
let content = BASE64_STANDARD.decode(response.audio_content)?;

play(content).await?;

Ok(())
}
}

async fn play(audio: Vec<u8>) -> Result<(), Exception> {
let temp_file = temp_dir().join(format!("{}.wav", Uuid::new_v4()));
fs::write(&temp_file, &audio).await?;
info!("play audio file, file={}", temp_file.to_string_lossy());
let mut command = Command::new("afplay").args([temp_file.to_string_lossy().to_string()]).spawn()?;
let _ = command.wait().await;
fs::remove_file(temp_file).await?;
Ok(())
}

impl From<DecodeError> for Exception {
fn from(err: DecodeError) -> Self {
Exception::unexpected(err)
}
}
Loading

0 comments on commit 8850625

Please sign in to comment.