Skip to content

Commit

Permalink
feat: add common version mange
Browse files Browse the repository at this point in the history
  • Loading branch information
cubxxw committed Dec 25, 2024
1 parent e026d09 commit 9f0ee43
Showing 1 changed file with 52 additions and 42 deletions.
94 changes: 52 additions & 42 deletions internal/stt/assemblyai/assemblyai.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,50 +37,29 @@ func (s *STT) Recognize(audioData []byte, audioURL string) (string, error) {
// 使用提供的 audioURL 调用 AssemblyAI 的转录服务
return s.transcribeFromURL(audioURL)
}

// 原有的处理流程,直接使用音频数据
return s.transcribeFromData(audioData)
}

func (s *STT) transcribeFromURL(audioURL string) (string, error) {
ctx := context.Background()

// 第一次尝试:启用语言检测
params := &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(true),
LanguageConfidenceThreshold: aai.Float64(0.1), // 设置较低的初始阈值
Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate),
FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText),
SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold),
Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel),
}

// 第一次尝试:使用语言检测
params := s.buildParams()
transcript, err := s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params)
if err != nil {
// 检查是否是语言置信度错误
if s.isLanguageConfidenceError(err) && s.cfg.AssemblyAI.DefaultLanguageCode != "" {
logger.Infof("第一次尝试失败(语言置信度低), 使用默认语言 %s 重试",
// 使用默认语言重试
logger.Infof("语言置信度低于阈值 %.2f,使用默认语言 %s 重试",
s.cfg.AssemblyAI.LanguageConfidenceThreshold,
s.cfg.AssemblyAI.DefaultLanguageCode)

// 第二次尝试:禁用语言检测,使用固定语言
retryParams := &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(false), // 明确禁用语言检测
LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode),
// 基础参数
Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate),
FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText),
SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold),
Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel),
// 不再设置 LanguageConfidenceThreshold
}

// 记录重试请求参数
logger.Debugf("重试请求参数: %+v", retryParams)

transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, retryParams)
// 构建新的参数,使用默认语言(禁用自动检测、去掉 threshold)
params = s.buildParamsWithDefaultLanguage()
transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params)
if err != nil {
return "", fmt.Errorf("使用默认语言 %s 重试失败: %v",
s.cfg.AssemblyAI.DefaultLanguageCode, err)
return "", fmt.Errorf("使用默认语言重试失败: %v", err)
}
} else {
return "", fmt.Errorf("转录请求失败: %v", err)
Expand Down Expand Up @@ -145,25 +124,39 @@ func (s *STT) StreamRecognize(ctx context.Context, audioDataChan <-chan []byte,
return fmt.Errorf("AssemblyAI 不支持流式处理")
}

// buildParams 将 config.yaml 中的字段映射到 AssemblyAI 的 TranscriptOptionalParams
// buildParams 将 config.yaml 中的字段映射到 AssemblyAI 的 TranscriptOptionalParams(第一次请求用)
func (s *STT) buildParams() *aai.TranscriptOptionalParams {
aaiCfg := s.cfg.AssemblyAI

params := &aai.TranscriptOptionalParams{
SpeechModel: aai.SpeechModel(aaiCfg.Model),
Punctuate: aai.Bool(aaiCfg.Punctuate),
FormatText: aai.Bool(aaiCfg.FormatText),
SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold),
Multichannel: aai.Bool(aaiCfg.Multichannel),
// 将 string 转换为 SpeechModel 类型
SpeechModel: aai.SpeechModel(aaiCfg.Model),
LanguageDetection: aai.Bool(aaiCfg.LanguageDetection),
LanguageConfidenceThreshold: aai.Float64(aaiCfg.LanguageConfidenceThreshold),
Punctuate: aai.Bool(aaiCfg.Punctuate),
FormatText: aai.Bool(aaiCfg.FormatText),
Disfluencies: aai.Bool(aaiCfg.Disfluencies),
FilterProfanity: aai.Bool(aaiCfg.FilterProfanity),
AudioStartFrom: aai.Int64(aaiCfg.AudioStartFrom),
AudioEndAt: aai.Int64(aaiCfg.AudioEndAt),
SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold),
Multichannel: aai.Bool(aaiCfg.Multichannel),
}

// 如果设置了固定的 language_code,则禁用语言检测并指定语言
if aaiCfg.LanguageCode != "" {
params.LanguageDetection = aai.Bool(false)
params.LanguageCode = aai.TranscriptLanguageCode(aaiCfg.LanguageCode)
}

// 词汇增强设置
// 如果配置了词汇增强
if len(aaiCfg.WordBoost) > 0 {
params.WordBoost = aaiCfg.WordBoost
// 将 string 转换为 TranscriptBoostParam 类型
params.BoostParam = aai.TranscriptBoostParam(aaiCfg.BoostParam)
}

// 自定义拼写设置
// 如果配置了自定义拼写
if len(aaiCfg.CustomSpelling) > 0 {
var customSpellings []aai.TranscriptCustomSpelling
for _, cs := range aaiCfg.CustomSpelling {
Expand All @@ -178,9 +171,26 @@ func (s *STT) buildParams() *aai.TranscriptOptionalParams {
return params
}

// isLanguageConfidenceError 优化错误检测逻辑
// 新增:检查是否是语言置信度错误
func (s *STT) isLanguageConfidenceError(err error) bool {
errMsg := err.Error()
return strings.Contains(errMsg, "below the requested confidence threshold") ||
strings.Contains(errMsg, "confidence threshold value")
return strings.Contains(err.Error(), "below the requested confidence threshold value")
}

// **优化后的关键点**:使用默认语言构建参数(禁用自动检测,不再带 threshold)
func (s *STT) buildParamsWithDefaultLanguage() *aai.TranscriptOptionalParams {
// 直接手动指定,不再从 buildParams() 继承
return &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(false),
// 在这里写死你要使用的语言
LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode),

// 以下可按需打开/关闭
Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate),
FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText),
Disfluencies: aai.Bool(s.cfg.AssemblyAI.Disfluencies),
FilterProfanity: aai.Bool(s.cfg.AssemblyAI.FilterProfanity),

// 如果想让二次请求也支持别的功能(词汇增强、自定义拼写等),
// 也可以自行在这里加上。但注意不要再设 LanguageConfidenceThreshold。
}
}

0 comments on commit 9f0ee43

Please sign in to comment.