Skip to content

Commit

Permalink
feat: add tts websocket demo
Browse files Browse the repository at this point in the history
Signed-off-by: Xinwei Xiong <[email protected]>
  • Loading branch information
cubxxw committed Dec 17, 2024
1 parent b0e6d13 commit 32e5050
Show file tree
Hide file tree
Showing 12 changed files with 650 additions and 58 deletions.
27 changes: 22 additions & 5 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ VOICEFLOW_MINIO_SECRET_KEY='' # MinIO 密钥
# Azure 配置
VOICEFLOW_AZURE_STT_KEY='' # Azure STT 密钥
VOICEFLOW_AZURE_TTS_KEY='' # Azure TTS 密钥
SPEECH_KEY='' # Azure 语音密钥
VOICEFLOW_AZURE_SPEECH_KEY='' # Azure 语音密钥
VOICEFLOW_AZURE_REGION='japaneast' # Azure 区域

# AWS 配置
Expand All @@ -24,10 +24,27 @@ VOICEFLOW_OPENAI_BASE_URL='' # OpenAI 基础 URL
# AssemblyAI 配置
VOICEFLOW_ASSEMBLYAI_API_KEY='' # AssemblyAI API 密钥

# VOLCENGINE 配置
VOICEFLOW_VOLCENGINE_ACCESS_KEY='' # VOLCENGINE 访问密钥
VOICEFLOW_VOLCENGINE_APP_KEY='' # VOLCENGINE 应用密钥
VOICEFLOW_VOLCENGINE_WS_URL='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel' # VOLCENGINE WebSocket URL
# VOLCENGINE STT 配置
VOICEFLOW_VOLCENGINE_STT_WS_URL='wss://openspeech.bytedance.com/api/v3/sauc/bigmodel' # STT WebSocket URL
VOICEFLOW_VOLCENGINE_STT_UID='test' # STT 用户ID
VOICEFLOW_VOLCENGINE_STT_RATE='16000' # STT 采样率
VOICEFLOW_VOLCENGINE_STT_FORMAT='pcm' # STT 音频格式
VOICEFLOW_VOLCENGINE_STT_BITS='16' # STT 位深度
VOICEFLOW_VOLCENGINE_STT_CHANNEL='1' # STT 声道数
VOICEFLOW_VOLCENGINE_STT_CODEC='pcm' # STT 编码格式
VOICEFLOW_VOLCENGINE_STT_ACCESS_KEY='' # STT 访问密钥
VOICEFLOW_VOLCENGINE_STT_APP_KEY='' # STT 应用密钥
VOICEFLOW_VOLCENGINE_STT_RESOURCE_ID='volc.bigasr.sauc.duration' # STT 资源ID

# VOLCENGINE TTS 配置
VOICEFLOW_VOLCENGINE_TTS_WS_URL='wss://openspeech.bytedance.com/api/v1/tts/ws_binary' # TTS WebSocket URL
VOICEFLOW_VOLCENGINE_TTS_APP_ID='' # TTS 应用ID
VOICEFLOW_VOLCENGINE_TTS_TOKEN='' # TTS 令牌
VOICEFLOW_VOLCENGINE_TTS_VOICE_TYPE='zh_female_1' # TTS 音色类型
VOICEFLOW_VOLCENGINE_TTS_ENCODING='mp3' # TTS 音频编码
VOICEFLOW_VOLCENGINE_TTS_SPEED_RATIO='1.0' # TTS 语速比例
VOICEFLOW_VOLCENGINE_TTS_VOLUME_RATIO='1.0' # TTS 音量比例
VOICEFLOW_VOLCENGINE_TTS_PITCH_RATIO='1.0' # TTS 音调比例

# 语音服务端口配置
VOICEFLOW_SERVER_PORT=18080 # 语音服务端口
23 changes: 13 additions & 10 deletions cmd/voiceflow/web/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ ws.onopen = () => {

ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.partial_text) {
// 显示部分转录文本
updatePartialMessage('你', data.partial_text);
} else if (data.text) {
// 显示最终转录文本
if (data.text) {
// 显示文本消息
appendMessage('助手', data.text);
} else if (data.audio_url) {
appendAudioMessage('助手', data.audio_url);

// 如果有音频 URL,播放语音
if (data.audio_url) {
appendAudioMessage('助手', data.audio_url);
}
}
};

Expand Down Expand Up @@ -63,7 +63,7 @@ function startRecording() {

mediaRecorder = new MediaRecorder(stream);

// 设置 timeslice 控制音频数据可用的频率(例如每250毫秒)
// 设置 timeslice 控制音频数据可用的频率(例如每250毫秒)
const timeslice = 250; // 时间,单位为毫秒

mediaRecorder.start(timeslice);
Expand Down Expand Up @@ -147,8 +147,11 @@ uploadAudioInput.addEventListener('change', () => {
});

function sendTextMessage(text) {
// 通过 WebSocket 发送文字消息
ws.send(JSON.stringify({ text: text }));
// 通过 WebSocket 发送文字消息,并指明需要 TTS
ws.send(JSON.stringify({
text: text,
require_tts: true // 添加标志表明需要 TTS
}));
}

function sendAudioMessage(audioBlob) {
Expand Down
47 changes: 30 additions & 17 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ stt:
provider: volcengine

tts:
# 可选值:azure、google、local
provider: google
# 可选值:azure、google、local、volcengine
provider: volcengine

llm:
# 可选值:openai、local
provider: openai

azure:
stt_key: "your_azure_stt_key"
tts_key: "your_azure_tts_key"
stt_key: ""
tts_key: ""
region: "eastus"

google:
Expand All @@ -45,22 +45,35 @@ openai:
base_url: ""

volcengine:
access_key: ''
app_key: ''
ws_url: ''
uid: "test"
rate: 16000
format: "pcm"
bits: 16
channel: 1
codec: "pcm"
# 小时版:volc.bigasr.sauc.duration
# 并发版:volc.bigasr.sauc.concurrent
resource_id: 'volc.bigasr.sauc.duration'
# 语音识别(STT)配置
stt:
ws_url: ''
uid: "test"
rate: 16000
format: "pcm"
bits: 16
channel: 1
codec: "pcm"
access_key: ''
app_key: ''
# 小时版:volc.bigasr.sauc.duration
# 并发版:volc.bigasr.sauc.concurrent
resource_id: 'volc.bigasr.sauc.duration'

# 语音合成(TTS)配置
tts:
ws_url: "wss://openspeech.bytedance.com/api/v1/tts/ws_binary"
app_id: "your_app_id"
token: "your_token"
voice_type: "zh_female_sajiaonvyou_moon_bigtts"
encoding: "mp3"
speed_ratio: 1.0
volume_ratio: 1.0
pitch_ratio: 1.0

# 日志配置
logging:
# 日志级别(选项:debug调试)、info(信息)、warn(警告)、error(错误)、fatal(致命错误))
# 日志级别(选项:debug(调试)、info(信息)、warn(警告)、error(错误)、fatal(致命错误))
level: "info"
# 日志格式(选项:json(JSON 格式)、text(文本格式))
format: "text"
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ require (
github.com/rs/xid v1.6.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
Expand Down
28 changes: 18 additions & 10 deletions internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ func InitServices() {
storageService = storage.NewService()
}

// TextMessage is the JSON payload that handleConnections decodes from a
// WebSocket text frame.
type TextMessage struct {
// Text is the message text, read from the "text" JSON key.
Text string `json:"text"`
// RequireTTS is read from the "require_tts" JSON key.
// NOTE(review): presumably signals that the client wants synthesized audio;
// the visible handler code never reads this field — confirm intended use.
RequireTTS bool `json:"require_tts"`
}

func (s *Server) handleConnections(w http.ResponseWriter, r *http.Request) {

// 升级 WebSocket 连接
Expand Down Expand Up @@ -63,30 +69,32 @@ func (s *Server) handleConnections(w http.ResponseWriter, r *http.Request) {

if mt == websocket.TextMessage {
logger.Debug("Received text message")
// 处理文字消息
var msg map[string]string
var msg TextMessage
if err := json.Unmarshal(data, &msg); err != nil {
logger.Error("JSON parse error: %v", err)
continue
}
text := msg["text"]

// 调用 TTS 服务,将文字转换为语音
audioData, err := currentTTSService.Synthesize(text)

// 调用 TTS 服务
audioData, err := currentTTSService.Synthesize(msg.Text)
if err != nil {
logger.Error("TTS error: %v", err)
continue
}

// 存储音频并获取 URL
audioURL, err := currentStorageService.StoreAudio(audioData)
if err != nil {
logger.Error("Storage error: %v", err)
continue
}

// 返回音频 URL 给前端
ws.WriteJSON(map[string]string{"audio_url": audioURL})

// 返回文本和音频 URL
response := map[string]string{
"text": msg.Text,
"audio_url": audioURL,
}
ws.WriteJSON(response)
} else if mt == websocket.BinaryMessage {
logger.Debug("Received binary message")
// 处理音频消息
Expand Down
27 changes: 15 additions & 12 deletions internal/stt/volcengine/volcengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/telepace/voiceflow/pkg/config"
"github.com/telepace/voiceflow/pkg/logger"
"net/http"
"time"
)

type STT struct {
Expand All @@ -31,17 +32,19 @@ func NewVolcengineSTT() *STT {
if err != nil {
logger.Fatalf("配置初始化失败: %v", err)
}

sttCfg := cfg.Volcengine.STT
return &STT{
wsURL: cfg.Volcengine.WsURL,
uid: cfg.Volcengine.UID,
rate: cfg.Volcengine.Rate,
format: cfg.Volcengine.Format,
bits: cfg.Volcengine.Bits,
channel: cfg.Volcengine.Channel,
codec: cfg.Volcengine.Codec,
accessKey: cfg.Volcengine.AccessKey,
appKey: cfg.Volcengine.AppKey,
resourceID: cfg.Volcengine.ResourceID,
wsURL: sttCfg.WsURL,
uid: sttCfg.UID,
rate: sttCfg.Rate,
format: sttCfg.Format,
bits: sttCfg.Bits,
channel: sttCfg.Channel,
codec: sttCfg.Codec,
accessKey: sttCfg.AccessKey,
appKey: sttCfg.AppKey,
resourceID: sttCfg.ResourceID,
}
}

Expand Down
3 changes: 3 additions & 0 deletions internal/tts/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/telepace/voiceflow/internal/tts/azure"
"github.com/telepace/voiceflow/internal/tts/google"
"github.com/telepace/voiceflow/internal/tts/local"
"github.com/telepace/voiceflow/internal/tts/volcengine"
"github.com/telepace/voiceflow/pkg/logger"
)

Expand All @@ -22,6 +23,8 @@ func NewService(provider string) Service {
return azure.NewAzureTTS() // 调用 Azure TTS 实现
case "google":
return google.NewGoogleTTS() // 调用 Google TTS 实现
case "volcengine":
return volcengine.NewVolcengineTTS()
case "local":
return local.NewLocalTTS() // 调用本地 TTS 实现
default:
Expand Down
Loading

0 comments on commit 32e5050

Please sign in to comment.