diff --git a/.github/workflows/go-demos.yml b/.github/workflows/go-demos.yml index 4c09fe7e..8e1ecf89 100644 --- a/.github/workflows/go-demos.yml +++ b/.github/workflows/go-demos.yml @@ -74,7 +74,7 @@ jobs: - name: Test filedemo run: ./leopard_file_demo -access_key ${{secrets.PV_VALID_ACCESS_KEY}} -input_audio_path ../../resources/audio_samples/test.wav - + build-grpc-demo: runs-on: ${{ matrix.os }} defaults: diff --git a/binding/go/embedded/lib/common/leopard_params.pv b/binding/go/embedded/lib/common/leopard_params.pv index d87cb150..2a490dfd 100644 Binary files a/binding/go/embedded/lib/common/leopard_params.pv and b/binding/go/embedded/lib/common/leopard_params.pv differ diff --git a/binding/go/embedded/lib/jetson/cortex-a57-aarch64/libpv_leopard.so b/binding/go/embedded/lib/jetson/cortex-a57-aarch64/libpv_leopard.so index eedbbab6..157800d4 100755 Binary files a/binding/go/embedded/lib/jetson/cortex-a57-aarch64/libpv_leopard.so and b/binding/go/embedded/lib/jetson/cortex-a57-aarch64/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/linux/x86_64/libpv_leopard.so b/binding/go/embedded/lib/linux/x86_64/libpv_leopard.so index fd6cce7d..d3f6cd6e 100755 Binary files a/binding/go/embedded/lib/linux/x86_64/libpv_leopard.so and b/binding/go/embedded/lib/linux/x86_64/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/mac/arm64/libpv_leopard.dylib b/binding/go/embedded/lib/mac/arm64/libpv_leopard.dylib index 02ab786b..e5def7ed 100755 Binary files a/binding/go/embedded/lib/mac/arm64/libpv_leopard.dylib and b/binding/go/embedded/lib/mac/arm64/libpv_leopard.dylib differ diff --git a/binding/go/embedded/lib/mac/x86_64/libpv_leopard.dylib b/binding/go/embedded/lib/mac/x86_64/libpv_leopard.dylib index 70307891..4f6b398a 100755 Binary files a/binding/go/embedded/lib/mac/x86_64/libpv_leopard.dylib and b/binding/go/embedded/lib/mac/x86_64/libpv_leopard.dylib differ diff --git a/binding/go/embedded/lib/raspberry-pi/cortex-a53-aarch64/libpv_leopard.so b/binding/go/embedded/lib/raspberry-pi/cortex-a53-aarch64/libpv_leopard.so index b2e2e082..7e14cfb4 100755 Binary files a/binding/go/embedded/lib/raspberry-pi/cortex-a53-aarch64/libpv_leopard.so and b/binding/go/embedded/lib/raspberry-pi/cortex-a53-aarch64/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/raspberry-pi/cortex-a53/libpv_leopard.so b/binding/go/embedded/lib/raspberry-pi/cortex-a53/libpv_leopard.so index 60feba42..0f2d84cf 100755 Binary files a/binding/go/embedded/lib/raspberry-pi/cortex-a53/libpv_leopard.so and b/binding/go/embedded/lib/raspberry-pi/cortex-a53/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/raspberry-pi/cortex-a72-aarch64/libpv_leopard.so b/binding/go/embedded/lib/raspberry-pi/cortex-a72-aarch64/libpv_leopard.so index 80d9804c..546e1c07 100755 Binary files a/binding/go/embedded/lib/raspberry-pi/cortex-a72-aarch64/libpv_leopard.so and b/binding/go/embedded/lib/raspberry-pi/cortex-a72-aarch64/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/raspberry-pi/cortex-a72/libpv_leopard.so b/binding/go/embedded/lib/raspberry-pi/cortex-a72/libpv_leopard.so index 7432d4df..bb3946b6 100755 Binary files a/binding/go/embedded/lib/raspberry-pi/cortex-a72/libpv_leopard.so and b/binding/go/embedded/lib/raspberry-pi/cortex-a72/libpv_leopard.so differ diff --git a/binding/go/embedded/lib/windows/amd64/libpv_leopard.dll b/binding/go/embedded/lib/windows/amd64/libpv_leopard.dll index be8a3dd1..b175080e 100644 Binary files a/binding/go/embedded/lib/windows/amd64/libpv_leopard.dll and b/binding/go/embedded/lib/windows/amd64/libpv_leopard.dll differ diff --git a/binding/go/go.mod b/binding/go/go.mod index d2e317e1..50b76338 100644 --- a/binding/go/go.mod +++ b/binding/go/go.mod @@ -1,3 +1,3 @@ -module github.com/Picovoice/leopard/binding/go +module github.com/Picovoice/leopard/binding/go/v2 go 1.16 diff --git a/binding/go/go_test.mod b/binding/go/go_test.mod index 281658f2..78e6be63 100644 --- a/binding/go/go_test.mod +++ b/binding/go/go_test.mod @@ -1,4 +1,4 @@ -module github.com/Picovoice/leopard/binding/go +module github.com/Picovoice/leopard/binding/go/v2 go 1.16 diff --git a/binding/go/leopard.go b/binding/go/leopard.go index 65b85d8f..601d37b6 100644 --- a/binding/go/leopard.go +++ b/binding/go/leopard.go @@ -1,4 +1,4 @@ -// Copyright 2022 Picovoice Inc. +// Copyright 2022-2023 Picovoice Inc. // // You may not use this file except in compliance with the license. A copy of the license is // located in the "LICENSE" file accompanying this source. @@ -52,8 +52,9 @@ const ( ) type LeopardError struct { - StatusCode PvStatus - Message string + StatusCode PvStatus + Message string + MessageStack []string } type leopardExts struct { @@ -61,7 +62,17 @@ type leopardExts struct { } func (e *LeopardError) Error() string { - return fmt.Sprintf("%s: %s", pvStatusToString(e.StatusCode), e.Message) + var message strings.Builder + message.WriteString(fmt.Sprintf("%s: %s", pvStatusToString(e.StatusCode), e.Message)) + + if len(e.MessageStack) > 0 { + message.WriteString(":") + } + + for i, value := range e.MessageStack { + message.WriteString(fmt.Sprintf("\n [%d] %s", i, value)) + } + return message.String() } // Leopard struct @@ -80,10 +91,14 @@ type Leopard struct { // Flag to enable automatic punctuation insertion. EnableAutomaticPunctuation bool + + // Flag to enable speaker diarization, which allows Leopard to differentiate speakers as part of the transcription process. + // Word metadata will include a `SpeakerTag` to identify unique speakers. + EnableDiarization bool } type LeopardWord struct { - // Transcribed word + // Transcribed word. Word string // Start of word in seconds. @@ -94,6 +109,10 @@ type LeopardWord struct { // Transcription confidence. It is a number within [0, 1]. Confidence float32 + + // Unique speaker identifier. It is `-1` if diarization is not enabled during initialization; otherwise, + // it's a non-negative integer identifying unique speakers, with `0` reserved for unknown speakers. + SpeakerTag int32 } // private vars @@ -123,6 +142,7 @@ func NewLeopard(accessKey string) Leopard { ModelPath: defaultModelFile, LibraryPath: defaultLibPath, EnableAutomaticPunctuation: false, + EnableDiarization: false, } } @@ -130,8 +150,8 @@ func NewLeopard(accessKey string) Leopard { func (leopard *Leopard) Init() error { if leopard.AccessKey == "" { return &LeopardError{ - INVALID_ARGUMENT, - "No AccessKey provided to Leopard"} + StatusCode: INVALID_ARGUMENT, + Message: "No AccessKey provided to Leopard"} } if leopard.ModelPath == "" { @@ -144,21 +164,31 @@ func (leopard *Leopard) Init() error { if _, err := os.Stat(leopard.LibraryPath); os.IsNotExist(err) { return &LeopardError{ - INVALID_ARGUMENT, - fmt.Sprintf("Specified library file could not be found at %s", leopard.LibraryPath)} + StatusCode: INVALID_ARGUMENT, + Message: fmt.Sprintf("Specified library file could not be found at %s", leopard.LibraryPath)} } if _, err := os.Stat(leopard.ModelPath); os.IsNotExist(err) { return &LeopardError{ - INVALID_ARGUMENT, - fmt.Sprintf("Specified model file could not be found at %s", leopard.ModelPath)} + StatusCode: INVALID_ARGUMENT, + Message: fmt.Sprintf("Specified model file could not be found at %s", leopard.ModelPath)} } ret := nativeLeopard.nativeInit(leopard) - if PvStatus(ret) != SUCCESS { + if ret != SUCCESS { + errorStatus, messageStack := nativeLeopard.nativeGetErrorStack() + if errorStatus != SUCCESS { + return &LeopardError{ + StatusCode: errorStatus, + Message: "Unable to get Leopard error state", + } + } + return &LeopardError{ - PvStatus(ret), - "Leopard init failed."} + StatusCode: ret, + Message: "Leopard init failed", + MessageStack: messageStack, + } } SampleRate = nativeLeopard.nativeSampleRate() @@ -171,8 +201,8 @@ func (leopard *Leopard) Init() error { func (leopard *Leopard) Delete() error { if leopard.handle == nil { return &LeopardError{ - INVALID_STATE, - "Leopard has not been initialized or has already been deleted"} + StatusCode: INVALID_STATE, + Message: "Leopard has not been initialized or has already been deleted"} } nativeLeopard.nativeDelete(leopard) @@ -187,21 +217,31 @@ func (leopard *Leopard) Delete() error { func (leopard *Leopard) Process(pcm []int16) (string, []LeopardWord, error) { if leopard.handle == nil { return "", nil, &LeopardError{ - INVALID_STATE, - "Leopard has not been initialized or has already been deleted"} + StatusCode: INVALID_STATE, + Message: "Leopard has not been initialized or has already been deleted"} } if len(pcm) == 0 { return "", nil, &LeopardError{ - INVALID_ARGUMENT, - "Audio data must not be empty"} + StatusCode: INVALID_ARGUMENT, + Message: "Audio data must not be empty"} } ret, transcript, words := nativeLeopard.nativeProcess(leopard, pcm) - if PvStatus(ret) != SUCCESS { + if ret != SUCCESS { + errorStatus, messageStack := nativeLeopard.nativeGetErrorStack() + if errorStatus != SUCCESS { + return "", nil, &LeopardError{ + StatusCode: errorStatus, + Message: "Unable to get Leopard error state", + } + } + return "", nil, &LeopardError{ - PvStatus(ret), - "Leopard process failed."} + StatusCode: ret, + Message: "Leopard process failed", + MessageStack: messageStack, + } } return transcript, words, nil @@ -213,29 +253,31 @@ func (leopard *Leopard) Process(pcm []int16) (string, []LeopardWord, error) { func (leopard *Leopard) ProcessFile(audioPath string) (string, []LeopardWord, error) { if leopard.handle == nil { return "", nil, &LeopardError{ - INVALID_STATE, - "Leopard has not been initialized or has already been deleted"} + StatusCode: INVALID_STATE, + Message: "Leopard has not been initialized or has already been deleted"} } if _, err := os.Stat(audioPath); os.IsNotExist(err) { return "", nil, &LeopardError{ - INVALID_ARGUMENT, - fmt.Sprintf("Specified file could not be found at '%s'", audioPath)} + StatusCode: INVALID_ARGUMENT, + Message: fmt.Sprintf("Specified file could not be found at '%s'", audioPath)} } ret, transcript, words := nativeLeopard.nativeProcessFile(leopard, audioPath) if ret != SUCCESS { - if ret == INVALID_ARGUMENT { - fileExtension := filepath.Ext(audioPath) - if !validExtensions.includes(fileExtension) { - return "", nil, &LeopardError{ - INVALID_ARGUMENT, - fmt.Sprintf("Specified file with extension '%s' is not supported", fileExtension)} + errorStatus, messageStack := nativeLeopard.nativeGetErrorStack() + if errorStatus != SUCCESS { + return "", nil, &LeopardError{ + StatusCode: errorStatus, + Message: "Unable to get Leopard error state", } } + return "", nil, &LeopardError{ - PvStatus(ret), - "Leopard process failed."} + StatusCode: ret, + Message: "Leopard process failed", + MessageStack: messageStack, + } } return transcript, words, nil diff --git a/binding/go/leopard_native.go b/binding/go/leopard_native.go index 61c9c1a8..a8b9fcfd 100644 --- a/binding/go/leopard_native.go +++ b/binding/go/leopard_native.go @@ -22,11 +22,11 @@ package leopard #if defined(_WIN32) || defined(_WIN64) - #include + #include #else - #include + #include #endif @@ -34,11 +34,11 @@ static void *open_dl(const char *dl_path) { #if defined(_WIN32) || defined(_WIN64) - return LoadLibrary((LPCSTR) dl_path); + return LoadLibrary((LPCSTR) dl_path); #else - return dlopen(dl_path, RTLD_NOW); + return dlopen(dl_path, RTLD_NOW); #endif @@ -48,105 +48,109 @@ static void *load_symbol(void *handle, const char *symbol) { #if defined(_WIN32) || defined(_WIN64) - return GetProcAddress((HMODULE) handle, symbol); + return GetProcAddress((HMODULE) handle, symbol); #else - return dlsym(handle, symbol); + return dlsym(handle, symbol); #endif } typedef struct { - const char *word; - float start_sec; - float end_sec; - float confidence; + const char *word; + float start_sec; + float end_sec; + float confidence; + int32_t speaker_tag; } pv_word_t; typedef int32_t (*pv_leopard_sample_rate_func)(); int32_t pv_leopard_sample_rate_wrapper(void *f) { - return ((pv_leopard_sample_rate_func) f)(); + return ((pv_leopard_sample_rate_func) f)(); } typedef char* (*pv_leopard_version_func)(); char* pv_leopard_version_wrapper(void* f) { - return ((pv_leopard_version_func) f)(); + return ((pv_leopard_version_func) f)(); } typedef int32_t (*pv_leopard_init_func)( - const char *access_key, - const char *model_path, - bool enable_punctuation_detection, - void **object); + const char *access_key, + const char *model_path, + bool enable_punctuation_detection, + bool enable_diarization, + void **object); int32_t pv_leopard_init_wrapper( - void *f, - const char *access_key, - const char *model_path, - bool enable_punctuation_detection, - void **object) { - return ((pv_leopard_init_func) f)( - access_key, - model_path, - enable_punctuation_detection, - object); + void *f, + const char *access_key, + const char *model_path, + bool enable_punctuation_detection, + bool enable_diarization, + void **object) { + return ((pv_leopard_init_func) f)( + access_key, + model_path, + enable_punctuation_detection, + enable_diarization, + object); } typedef int32_t (*pv_leopard_process_func)( - void *object, - const int16_t *pcm, - int32_t num_samples, - char **transcript, - int32_t *num_words, - pv_word_t **words); + void *object, + const int16_t *pcm, + int32_t num_samples, + char **transcript, + int32_t *num_words, + pv_word_t **words); int32_t pv_leopard_process_wrapper( - void *f, - void *object, - const int16_t *pcm, - int32_t num_samples, - char **transcript, - int32_t *num_words, - pv_word_t **words) { - return ((pv_leopard_process_func) f)( - object, - pcm, - num_samples, - transcript, - num_words, - words); + void *f, + void *object, + const int16_t *pcm, + int32_t num_samples, + char **transcript, + int32_t *num_words, + pv_word_t **words) { + return ((pv_leopard_process_func) f)( + object, + pcm, + num_samples, + transcript, + num_words, + words); } typedef int32_t (*pv_leopard_process_file_func)( - void *object, - const char *audio_path, - char **transcript, - int32_t *num_words, - pv_word_t **words); + void *object, + const char *audio_path, + char **transcript, + int32_t *num_words, + pv_word_t **words); int32_t pv_leopard_process_file_wrapper( - void *f, - void *object, - const char *audio_path, - char **transcript, - int32_t *num_words, - pv_word_t **words) { - return ((pv_leopard_process_file_func) f)( - object, - audio_path, - transcript, - num_words, - words); + void *f, + void *object, + const char *audio_path, + char **transcript, + int32_t *num_words, + pv_word_t **words) { + return ((pv_leopard_process_file_func) f)( + object, + audio_path, + transcript, + num_words, + words); } typedef void (*pv_leopard_delete_func)(void *); void pv_leopard_delete_wrapper(void *f, void *object) { - return ((pv_leopard_delete_func) f)(object); + return ((pv_leopard_delete_func) f)(object); } typedef void (*pv_leopard_transcript_delete_func)(char *); @@ -161,6 +165,29 @@ void pv_leopard_words_delete_wrapper(void *f, pv_word_t *words) { return ((pv_leopard_words_delete_func) f)(words); } +typedef void (*pv_set_sdk_func)(const char *); + +void pv_set_sdk_wrapper(void *f, const char *sdk) { + return ((pv_set_sdk_func) f)(sdk); +} + +typedef int32_t (*pv_get_error_stack_func)(char ***, int32_t *); + +int32_t pv_get_error_stack_wrapper( + void *f, + char ***message_stack, + int32_t *message_stack_depth) { + return ((pv_get_error_stack_func) f)(message_stack, message_stack_depth); +} + +typedef void (*pv_free_error_stack_func)(char **); + +void pv_free_error_stack_wrapper( + void *f, + char **message_stack) { + return ((pv_free_error_stack_func) f)(message_stack); +} + */ import "C" @@ -175,6 +202,7 @@ type nativeLeopardInterface interface { nativeDelete(*Leopard) nativeSampleRate() nativeVersion() + nativeGetErrorStack() } type nativeLeopardType struct { libraryHandle unsafe.Pointer @@ -186,6 +214,9 @@ type nativeLeopardType struct { pv_leopard_words_delete_ptr unsafe.Pointer pv_leopard_version_ptr unsafe.Pointer pv_sample_rate_ptr unsafe.Pointer + pv_set_sdk_ptr unsafe.Pointer + pv_get_error_stack_ptr unsafe.Pointer + pv_free_error_stack_ptr unsafe.Pointer } func (nl *nativeLeopardType) nativeInit(leopard *Leopard) (status PvStatus) { @@ -194,6 +225,7 @@ func (nl *nativeLeopardType) nativeInit(leopard *Leopard) (status PvStatus) { modelPathC = C.CString(leopard.ModelPath) libraryPathC = C.CString(leopard.LibraryPath) enableAutomaticPunctuationC = C.bool(leopard.EnableAutomaticPunctuation) + enableDiarizationC = C.bool(leopard.EnableDiarization) ) defer C.free(unsafe.Pointer(accessKeyC)) defer C.free(unsafe.Pointer(modelPathC)) @@ -208,12 +240,20 @@ func (nl *nativeLeopardType) nativeInit(leopard *Leopard) (status PvStatus) { nl.pv_leopard_words_delete_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_leopard_words_delete")) nl.pv_leopard_version_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_leopard_version")) nl.pv_sample_rate_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_sample_rate")) + nl.pv_set_sdk_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_set_sdk")) + nl.pv_get_error_stack_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_get_error_stack")) + nl.pv_free_error_stack_ptr = C.load_symbol(nl.libraryHandle, C.CString("pv_free_error_stack")) + + C.pv_set_sdk_wrapper( + nl.pv_set_sdk_ptr, + C.CString("go")) var ret = C.pv_leopard_init_wrapper( nl.pv_leopard_init_ptr, accessKeyC, modelPathC, enableAutomaticPunctuationC, + enableDiarizationC, &leopard.handle) return PvStatus(ret) @@ -251,6 +291,7 @@ func (nl *nativeLeopardType) nativeProcess(leopard *Leopard, pcm []int16) (statu StartSec: float32(cWords[i].start_sec), EndSec: float32(cWords[i].end_sec), Confidence: float32(cWords[i].confidence), + SpeakerTag: int32(cWords[i].speaker_tag), } words = append(words, n) } @@ -290,6 +331,7 @@ func (nl *nativeLeopardType) nativeProcessFile(leopard *Leopard, audioPath strin StartSec: float32(cWords[i].start_sec), EndSec: float32(cWords[i].end_sec), Confidence: float32(cWords[i].confidence), + SpeakerTag: int32(cWords[i].speaker_tag), } words = append(words, n) } @@ -308,3 +350,32 @@ func (nl *nativeLeopardType) nativeSampleRate() (sampleRate int) { func (nl *nativeLeopardType) nativeVersion() (version string) { return C.GoString(C.pv_leopard_version_wrapper(nl.pv_leopard_version_ptr)) } + +func (nl *nativeLeopardType) nativeGetErrorStack() (status PvStatus, messageStack []string) { + var messageStackDepthRef C.int32_t + var messageStackRef **C.char + + var ret = C.pv_get_error_stack_wrapper( + nl.pv_get_error_stack_ptr, + &messageStackRef, + &messageStackDepthRef) + + if PvStatus(ret) != SUCCESS { + return PvStatus(ret), []string{} + } + + defer C.pv_free_error_stack_wrapper( + nl.pv_free_error_stack_ptr, + messageStackRef) + + messageStackDepth := int(messageStackDepthRef) + messageStackSlice := (*[1 << 28]*C.char)(unsafe.Pointer(messageStackRef))[:messageStackDepth:messageStackDepth] + + messageStack = make([]string, messageStackDepth) + + for i := 0; i < messageStackDepth; i++ { + messageStack[i] = C.GoString(messageStackSlice[i]) + } + + return PvStatus(ret), messageStack +} diff --git a/binding/go/leopard_test.go b/binding/go/leopard_test.go index 8335d342..f5834ffc 100644 --- a/binding/go/leopard_test.go +++ b/binding/go/leopard_test.go @@ -9,8 +9,6 @@ // limitations under the License. // -// Go binding for Leopard Speech-to-Text engine. - package leopard import ( @@ -19,6 +17,7 @@ import ( "flag" "io/ioutil" "log" + "math" "os" "path/filepath" "reflect" @@ -28,18 +27,26 @@ import ( "github.com/agnivade/levenshtein" ) -type TestParameters struct { - language string - testAudioFile string - transcript string - errorRate float32 - enableAutomaticPunctuation bool +type LanguageTestParameters struct { + language string + testAudioFile string + transcript string + transcriptWithPunctuation string + errorRate float32 + words []LeopardWord +} + +type DiarizationTestParameters struct { + language string + testAudioFile string + words []LeopardWord } var ( - testAccessKey string - leopard Leopard - processTestParameters []TestParameters + testAccessKey string + leopard Leopard + languageTests []LanguageTestParameters + diarizationTests []DiarizationTestParameters ) func TestMain(m *testing.M) { @@ -47,10 +54,14 @@ func TestMain(m *testing.M) { flag.StringVar(&testAccessKey, "access_key", "", "AccessKey for testing") flag.Parse() - processTestParameters = loadTestData() + languageTests, diarizationTests = loadTestData() os.Exit(m.Run()) } +func isClose(value, expected, tolerance float32) bool { + return math.Abs(float64(value-expected)) <= float64(tolerance) +} + func appendLanguage(s string, language string) string { if language == "en" { return s @@ -58,8 +69,8 @@ func appendLanguage(s string, language string) string { return s + "_" + language } } -func loadTestData() []TestParameters { +func loadTestData() (languageTests []LanguageTestParameters, diarizationTests []DiarizationTestParameters) { content, err := ioutil.ReadFile("../../resources/.test/test_data.json") if err != nil { log.Fatalf("Could not read test data json: %v", err) @@ -67,13 +78,28 @@ func loadTestData() []TestParameters { var testData struct { Tests struct { - Parameters []struct { - Language string `json:"language"` - AudioFile string `json:"audio_file"` - Transcript string `json:"transcript"` - Punctuations []string `json:"punctuations"` - ErrorRate float32 `json:"error_rate"` - } `json:"parameters"` + LanguageTests []struct { + Language string `json:"language"` + AudioFile string `json:"audio_file"` + Transcript string `json:"transcript"` + TranscriptWithPunctuation string `json:"transcript_with_punctuation"` + ErrorRate float32 `json:"error_rate"` + Words []struct { + Word string `json:"word"` + StartSec float32 `json:"start_sec"` + EndSec float32 `json:"end_sec"` + Confidence float32 `json:"confidence"` + SpeakerTag int32 `json:"speaker_tag"` + } `json:"words"` + } `json:"language_tests"` + DiarizationTests []struct { + Language string `json:"language"` + AudioFile string `json:"audio_file"` + Words []struct { + Word string `json:"word"` + SpeakerTag int32 `json:"speaker_tag"` + } `json:"words"` + } `json:"diarization_tests"` } `json:"tests"` } err = json.Unmarshal(content, &testData) @@ -81,59 +107,78 @@ func loadTestData() []TestParameters { log.Fatalf("Could not decode test data json: %v", err) } - for _, x := range testData.Tests.Parameters { - testCaseWithPunctuation := TestParameters{ - language: x.Language, - testAudioFile: x.AudioFile, - transcript: x.Transcript, - enableAutomaticPunctuation: true, - errorRate: x.ErrorRate, + for _, x := range testData.Tests.LanguageTests { + languageTestParameters := LanguageTestParameters{ + language: x.Language, + testAudioFile: x.AudioFile, + transcript: x.Transcript, + transcriptWithPunctuation: x.TranscriptWithPunctuation, + errorRate: x.ErrorRate, } - processTestParameters = append(processTestParameters, testCaseWithPunctuation) - transcriptWithoutPunctuation := x.Transcript - for _, p := range x.Punctuations { - transcriptWithoutPunctuation = strings.ReplaceAll(transcriptWithoutPunctuation, p, "") + for _, y := range x.Words { + word := LeopardWord{ + Word: y.Word, + StartSec: y.StartSec, + EndSec: y.EndSec, + Confidence: y.Confidence, + SpeakerTag: y.SpeakerTag, + } + languageTestParameters.words = append(languageTestParameters.words, word) } - testCaseWithoutPunctuation := TestParameters{ - language: x.Language, - testAudioFile: x.AudioFile, - transcript: transcriptWithoutPunctuation, - enableAutomaticPunctuation: false, - errorRate: x.ErrorRate, + + languageTests = append(languageTests, languageTestParameters) + } + + for _, x := range testData.Tests.DiarizationTests { + diarizationTestParameters := DiarizationTestParameters{ + language: x.Language, + testAudioFile: x.AudioFile, } - processTestParameters = append(processTestParameters, testCaseWithoutPunctuation) + + for _, y := range x.Words { + word := LeopardWord{ + Word: y.Word, + SpeakerTag: y.SpeakerTag, + } + diarizationTestParameters.words = append(diarizationTestParameters.words, word) + } + + diarizationTests = append(diarizationTests, diarizationTestParameters) } - return processTestParameters + return languageTests, diarizationTests } -func validateMetadata(t *testing.T, transcript string, words []LeopardWord, audioLength float32) { - transcriptUpperCase := strings.ToUpper(transcript) +func validateMetadata(t *testing.T, referenceWords []LeopardWord, words []LeopardWord, enableDiarization bool) { + if len(words) != len(referenceWords) { + t.Fatalf("Word count `%d` did not match expected word count `%d`", len(words), len(referenceWords)) + } + for i := range words { - wordUpperCase := strings.ToUpper(words[i].Word) - if !strings.Contains(transcriptUpperCase, wordUpperCase) { - t.Fatalf("Word `%s` was not in transcript `%s`", wordUpperCase, transcriptUpperCase) + word := strings.ToUpper(words[i].Word) + referenceWord := strings.ToUpper(referenceWords[i].Word) + if word != referenceWord { + t.Fatalf("Word `%s` did not match expected word `%s`", word, referenceWord) } - if words[i].StartSec <= 0 { - t.Fatalf("Word %d started at %f", i, words[i].StartSec) + if !isClose(words[i].StartSec, referenceWords[i].StartSec, 0.1) { + t.Fatalf("Word %d started at %f, expected %f", i, words[i].StartSec, referenceWords[i].StartSec) } - if words[i].StartSec > words[i].EndSec { - t.Fatalf("Word %d had a start time of %f, but and end time of %f", i, words[i].StartSec, words[i].EndSec) + if !isClose(words[i].EndSec, referenceWords[i].EndSec, 0.1) { + t.Fatalf("Word %d ended at %f, expected %f", i, words[i].EndSec, referenceWords[i].EndSec) } - if i < len(words)-1 { - if words[i].EndSec > words[i+1].StartSec { - t.Fatalf("Word %d had an end time of %f, next word had a start time of %f", i, words[i].EndSec, words[i+1].StartSec) + if !isClose(words[i].Confidence, referenceWords[i].Confidence, 0.1) { + t.Fatalf("Word %d had a confidence of %f, expected %f", i, words[i].Confidence, referenceWords[i].Confidence) + } + if enableDiarization { + if words[i].SpeakerTag != referenceWords[i].SpeakerTag { + t.Fatalf("Word %d had speaker_tag of %d, expected %d", i, words[i].SpeakerTag, referenceWords[i].SpeakerTag) } } else { - if words[i].EndSec > audioLength { - t.Fatalf("Word %d had an end time of %f, audio length is %f", i, words[i].EndSec, audioLength) + if words[i].SpeakerTag != -1 { + t.Fatalf("Word %d had speaker_tag of %d, expected -1", i, words[i].SpeakerTag) } } - - if words[i].Confidence < 0 || words[i].Confidence > 1 { - t.Fatalf("Word %d had an invalid confidence value of %f", i, words[i].Confidence) - } } } @@ -160,10 +205,13 @@ func runProcessTestCase( testAudioFile string, referenceTranscript string, targetErrorRate float32, - enableAutomaticPunctuation bool) { + enableAutomaticPunctuation bool, + enableDiarization bool, + referenceWords []LeopardWord) { leopard = NewLeopard(testAccessKey) leopard.EnableAutomaticPunctuation = enableAutomaticPunctuation + leopard.EnableDiarization = enableDiarization modelPath, _ := filepath.Abs(filepath.Join("../../lib/common", appendLanguage("leopard_params", language)+".pv")) leopard.ModelPath = modelPath @@ -193,15 +241,12 @@ func runProcessTestCase( t.Fatalf("Failed to process pcm buffer: %v", err) } - t.Logf("%s", transcript) - t.Logf("%s", referenceTranscript) - errorRate := float32(levenshtein.ComputeDistance(transcript, referenceTranscript)) / float32(len(referenceTranscript)) if errorRate >= targetErrorRate { t.Fatalf("Expected '%f' got '%f'", targetErrorRate, errorRate) } - validateMetadata(t, transcript, words, float32(len(pcm))/float32(SampleRate)) + validateMetadata(t, referenceWords, words, enableDiarization) } func runProcessFileTestCase( @@ -210,10 +255,13 @@ func runProcessFileTestCase( testAudioFile string, referenceTranscript string, targetErrorRate float32, - enableAutomaticPunctuation bool) { + enableAutomaticPunctuation bool, + enableDiarization bool, + referenceWords []LeopardWord) { leopard = NewLeopard(testAccessKey) leopard.EnableAutomaticPunctuation = enableAutomaticPunctuation + leopard.EnableDiarization = enableDiarization modelPath, _ := filepath.Abs(filepath.Join("../../lib/common", appendLanguage("leopard_params", language)+".pv")) leopard.ModelPath = modelPath @@ -234,26 +282,116 @@ func runProcessFileTestCase( t.Fatalf("Expected '%f' got '%f'", targetErrorRate, errorRate) } - data, err := ioutil.ReadFile(testAudioPath) + validateMetadata(t, referenceWords, words, enableDiarization) +} + +func runDiarizationTestCase( + t *testing.T, + language string, + testAudioFile string, + referenceWords []LeopardWord) { + + leopard = NewLeopard(testAccessKey) + leopard.EnableDiarization = true + + modelPath, _ := filepath.Abs(filepath.Join("../../lib/common", appendLanguage("leopard_params", language)+".pv")) + leopard.ModelPath = modelPath + + err := leopard.Init() if err != nil { - t.Fatalf("Could not read test file: %v", err) + log.Fatalf("Failed to init leopard with: %v", err) } - data = data[44:] // skip header + defer leopard.Delete() - validateMetadata(t, transcript, words, (float32(len(data))/float32(2))/float32(SampleRate)) + testAudioPath, _ := filepath.Abs(filepath.Join("../../resources/audio_samples", testAudioFile)) + _, words, err := leopard.ProcessFile(testAudioPath) + if err != nil { + t.Fatalf("Failed to process pcm buffer: %v", err) + } + + if len(words) != len(referenceWords) { + t.Fatalf("Word count `%d` did not match expected word count `%d`", len(words), len(referenceWords)) + } + for i := range words { + word := strings.ToUpper(words[i].Word) + referenceWord := strings.ToUpper(referenceWords[i].Word) + if word != referenceWord { + t.Fatalf("Word `%s` did not match expected word `%s`", word, referenceWord) + } + if words[i].SpeakerTag != referenceWords[i].SpeakerTag { + t.Fatalf("Word %d had speaker_tag of %d, expected %d", i, words[i].SpeakerTag, referenceWords[i].SpeakerTag) + } + } } func TestProcess(t *testing.T) { - for _, test := range processTestParameters { + for _, test := range languageTests { t.Logf("Running process data test for `%s`", test.language) - runProcessTestCase(t, test.language, test.testAudioFile, test.transcript, test.errorRate, test.enableAutomaticPunctuation) + runProcessTestCase( + t, + test.language, + test.testAudioFile, + test.transcript, + test.errorRate, + false, + false, + test.words) } } func TestProcessFile(t *testing.T) { - for _, test := range processTestParameters { + for _, test := range languageTests { t.Logf("Running process file test for `%s`", test.language) - runProcessTestCase(t, test.language, test.testAudioFile, test.transcript, test.errorRate, test.enableAutomaticPunctuation) + runProcessFileTestCase( + t, + test.language, + test.testAudioFile, + test.transcript, + test.errorRate, + false, + false, + test.words) + } +} + +func TestProcessFileWithPunctuation(t *testing.T) { + for _, test := range languageTests { + t.Logf("Running process file with punctuation test for `%s`", test.language) + runProcessFileTestCase( + t, + test.language, + test.testAudioFile, + test.transcriptWithPunctuation, + test.errorRate, + true, + false, + test.words) + } +} + +func TestProcessFileWithDiarization(t *testing.T) { + for _, test := range languageTests { + t.Logf("Running process file with diarization test for `%s`", test.language) + runProcessFileTestCase( + t, + test.language, + test.testAudioFile, + test.transcript, + test.errorRate, + false, + true, + test.words) + } +} + +func TestDiarization(t *testing.T) { + for _, test := range diarizationTests { + t.Logf("Running diarization test for `%s`", test.language) + runDiarizationTestCase( + t, + test.language, + test.testAudioFile, + test.words) } } @@ -278,3 +416,41 @@ func TestProcessEmptyFile(t *testing.T) { t.Fatalf("Leopard returned %d words on empty file", len(words)) } } + +func TestMessageStack(t *testing.T) { + leopard = NewLeopard("invalid access key") + err := leopard.Init() + err2 := leopard.Init() + + if len(err.Error()) > 1024 { + t.Fatalf("length of error is full: '%d'", len(err.Error())) + } + + if len(err2.Error()) != len(err.Error()) { + t.Fatalf("length of 1st init '%d' does not match 2nd init '%d'", len(err.Error()), len(err2.Error())) + } +} + +func TestProcessMessageStack(t *testing.T) { + leopard = NewLeopard(testAccessKey) + err := leopard.Init() + if err != nil { + log.Fatalf("Failed to init leopard with: %v", err) + } + + address := leopard.handle + leopard.handle = nil + + testPcm := make([]int16, 1014) + + _, _, err = leopard.Process(testPcm) + leopard.handle = address + if err == nil { + t.Fatalf("Expected leopard process to fail") + } + + delErr := leopard.Delete() + if delErr != nil { + t.Fatalf("%v", delErr) + } +} diff --git a/demo/go-grpc/go.mod b/demo/go-grpc/go.mod index cca6ac35..0a2fc1f9 100644 --- a/demo/go-grpc/go.mod +++ b/demo/go-grpc/go.mod @@ -10,6 +10,7 @@ require ( ) require ( + github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/demo/go-grpc/go.sum b/demo/go-grpc/go.sum index 893bdf4f..7a14ee12 100644 --- a/demo/go-grpc/go.sum +++ b/demo/go-grpc/go.sum @@ -1,5 +1,7 @@ github.com/Picovoice/leopard/binding/go v1.2.0 h1:NbUW+Fni5UydvcFlMx8RZtk2pccFaRJFG14kGaUa4CA= github.com/Picovoice/leopard/binding/go v1.2.0/go.mod h1:5kaEg9ZcH2dLkrX/H1xMVF6QFM7l3vd9GKxeXSanA8s= +github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268 h1:X0V6i8wC6j4oJD7JeRE3/sqQsS9NqsB0684t7N1ZYs0= +github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268/go.mod h1:/rYUeRDH4xBgtwBe9D8BwHIauPJ+M7czqLfyeJQJu7c= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= diff --git a/demo/go/filedemo/leopard_file_demo.go b/demo/go/filedemo/leopard_file_demo.go index 7e3fef2e..0e1cd392 100644 --- a/demo/go/filedemo/leopard_file_demo.go +++ b/demo/go/filedemo/leopard_file_demo.go @@ -17,7 +17,7 @@ import ( "os" "path/filepath" - leopard "github.com/Picovoice/leopard/binding/go" + leopard "github.com/Picovoice/leopard/binding/go/v2" ) func main() { @@ -25,8 +25,9 @@ func main() { modelPathArg := flag.String("model_path", "", "Path to Leopard model file") libraryPathArg := flag.String("library_path", "", "Path to Leopard's dynamic library file") disableAutomaticPunctuationArg := flag.Bool("disable_automatic_punctuation", false, "Disable automatic punctuation") + disableSpeakerDiarizationArg := flag.Bool("disable_speaker_diarization", false, "Disable speaker diarization") verboseArg := flag.Bool("verbose", false, "Enable verbose logging") - inputAudioPathArg := flag.String("input_audio_path", "", "Path to input audio file (mono, valid: `3gp (AMR)`, `FLAC`, `MP3`, `MP4/m4a (AAC)`, `Ogg`, `WAV`, `WebM`, 16-bit)") + inputAudioPathArg := flag.String("input_audio_path", "", "Path to input audio file (mono, 16-bit, valid formats: 3gp (AMR), FLAC, MP3, MP4/m4a (AAC), Ogg, WAV, WebM)") flag.Parse() @@ -43,6 +44,8 @@ func main() { l := leopard.NewLeopard(*accessKeyArg) l.EnableAutomaticPunctuation = !*disableAutomaticPunctuationArg + l.EnableDiarization = !*disableSpeakerDiarizationArg + defer func() { err := l.Delete() if err != nil { @@ -84,9 +87,9 @@ func main() { fmt.Println(transcript) if *verboseArg { - fmt.Printf("|%10s | %15s | %15s | %10s|\n", "word", "Start in Sec", "End in Sec", "Confidence") + fmt.Printf("|%10s | %10s | %10s | %10s | %10s|\n", "Word", "Start (s)", "End (s)", "Confidence", "Speaker Tag") for _, word := range words { - fmt.Printf("|%10s | %15.2f | %15.2f | %10.2f|\n", word.Word, word.StartSec, word.EndSec, word.Confidence) + fmt.Printf("|%10s | %10.2f | %10.2f | %10.2f | %11d|\n", word.Word, word.StartSec, word.EndSec, word.Confidence, word.SpeakerTag) } } } diff --git a/demo/go/go.mod b/demo/go/go.mod index 972a87a7..5d97d33f 100644 --- a/demo/go/go.mod +++ b/demo/go/go.mod @@ -3,7 +3,7 @@ module leoparddemo go 1.16 require ( - github.com/Picovoice/leopard/binding/go v1.2.0 + github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268 github.com/Picovoice/pvrecorder/binding/go v1.2.1 github.com/agnivade/levenshtein v1.1.1 // indirect ) diff --git a/demo/go/go.sum b/demo/go/go.sum index 54fffd8b..abec509d 100644 --- a/demo/go/go.sum +++ b/demo/go/go.sum @@ -1,5 +1,5 @@ -github.com/Picovoice/leopard/binding/go v1.2.0 h1:NbUW+Fni5UydvcFlMx8RZtk2pccFaRJFG14kGaUa4CA= -github.com/Picovoice/leopard/binding/go v1.2.0/go.mod h1:5kaEg9ZcH2dLkrX/H1xMVF6QFM7l3vd9GKxeXSanA8s= +github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268 h1:X0V6i8wC6j4oJD7JeRE3/sqQsS9NqsB0684t7N1ZYs0= +github.com/Picovoice/leopard/binding/go/v2 v2.0.0-20231121002919-52ce3a2d8268/go.mod h1:/rYUeRDH4xBgtwBe9D8BwHIauPJ+M7czqLfyeJQJu7c= github.com/Picovoice/pvrecorder/binding/go v1.2.1 h1:p99fkYMFbTS4g4WwbhSPkT9PHvlEoVYGaNoqxCITiEo= github.com/Picovoice/pvrecorder/binding/go v1.2.1/go.mod h1:gQdvBAjoKmRxMFh8W9cVKWcqHsWvu+d13sCPVFm7dhg= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= diff --git a/demo/go/micdemo/leopard_mic_demo.go b/demo/go/micdemo/leopard_mic_demo.go index b12069a0..e3fd6d56 100644 --- a/demo/go/micdemo/leopard_mic_demo.go +++ b/demo/go/micdemo/leopard_mic_demo.go @@ -19,7 +19,7 @@ import ( "os/signal" "path/filepath" - leopard "github.com/Picovoice/leopard/binding/go" + leopard "github.com/Picovoice/leopard/binding/go/v2" pvrecorder "github.com/Picovoice/pvrecorder/binding/go" ) @@ -44,6 +44,7 @@ func main() { modelPathArg := flag.String("model_path", "", "Path to Leopard model file") libraryPathArg := flag.String("library_path", "", "Path to Leopard's dynamic library file") disableAutomaticPunctuationArg := flag.Bool("disable_automatic_punctuation", false, "Disable automatic punctuation") + disableSpeakerDiarizationArg := flag.Bool("disable_speaker_diarization", false, "Disable speaker diarization") verboseArg := flag.Bool("verbose", false, "Enable verbose logging") audioDeviceIndex := flag.Int("audio_device_index", -1, "Index of capture device to use.") showAudioDevices := flag.Bool("show_audio_devices", false, "Display all available capture devices") @@ -56,6 +57,7 @@ func main() { l := leopard.NewLeopard(*accessKeyArg) l.EnableAutomaticPunctuation = !*disableAutomaticPunctuationArg + l.EnableDiarization = !*disableSpeakerDiarizationArg // validate library path if *libraryPathArg != "" { @@ -139,9 +141,9 @@ func main() { fmt.Printf("%s\n\n", transcript) if *verboseArg { - fmt.Printf("|%10s | %15s | %15s | %10s|\n", "word", "Start in Sec", "End in Sec", "Confidence") + fmt.Printf("|%10s | %10s | %10s | %10s | %10s|\n", "Word", "Start (s)", "End (s)", "Confidence", "Speaker Tag") for _, word := range words { - fmt.Printf("|%10s | %15.2f | %15.2f | %10.2f|\n", word.Word, word.StartSec, word.EndSec, word.Confidence) + fmt.Printf("|%10s | %10.2f | %10.2f | %10.2f | %11d|\n", word.Word, word.StartSec, word.EndSec, word.Confidence, word.SpeakerTag) } } } else {