Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.0 go #259

Merged
merged 11 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/go-codestyle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
working-directory: binding/go
# TODO: figure out why the linter complains about this??
args: --exclude="could not import C"

check-go-micdemo-codestyle:
runs-on: ubuntu-latest
Expand Down
8 changes: 7 additions & 1 deletion .github/workflows/go-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
go: [ '1.16', '1.17', '1.18' ]
go: [ '1.16', '1.17', '1.18', '1.19', '1.20', 'stable' ]

steps:
- uses: actions/checkout@v3

- name: Set up Mingw
uses: egor-tensin/setup-mingw@v2
if: ${{ (matrix.os == 'windows-latest') && (matrix.go != 'stable') && (matrix.go < 1.20) }}
with:
version: 11.2.0

- name: Setup go
uses: actions/setup-go@v3
with:
Expand Down
12 changes: 9 additions & 3 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
go: [ '1.16', '1.17', '1.18' ]
go: [ '1.16', '1.17', '1.18', '1.19', '1.20', 'stable' ]

steps:
- uses: actions/checkout@v3

- name: Set up Mingw
uses: egor-tensin/setup-mingw@v2
if: ${{ (matrix.os == 'windows-latest') && (matrix.go != 'stable') && (matrix.go < 1.20) }}
with:
version: 11.2.0

- name: Setup go
uses: actions/setup-go@v3
with:
Expand All @@ -58,7 +64,7 @@ jobs:
run: go build

- name: Test
run: go test -v -access_key ${{secrets.PV_VALID_ACCESS_KEY}}
run: go test -modfile="go_test.mod" -v -access_key ${{secrets.PV_VALID_ACCESS_KEY}}

build-self-hosted:
runs-on: ${{ matrix.machine }}
Expand All @@ -80,4 +86,4 @@ jobs:
run: go build

- name: Test
run: go test -v -access_key ${{secrets.PV_VALID_ACCESS_KEY}}
run: go test -modfile="go_test.mod" -v -access_key ${{secrets.PV_VALID_ACCESS_KEY}}
2 changes: 2 additions & 0 deletions binding/go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Cheetah is an on-device streaming speech-to-text engine. Cheetah is:

- go 1.16+
- Runs on Linux (x86_64), macOS (x86_64, arm64), Windows (x86_64), Raspberry Pi (4, 3), and NVIDIA Jetson Nano.
- **Windows**: The Go binding requires `cgo`, which means that you need to install a gcc compiler like [Mingw](http://mingw-w64.org/) to build it properly.
- Go versions less than `1.20` requires `gcc` version `11` or lower.

## Installation

Expand Down
95 changes: 68 additions & 27 deletions binding/go/cheetah.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,23 @@ const (
)

type CheetahError struct {
StatusCode PvStatus
Message string
StatusCode PvStatus
Message string
MessageStack []string
}

func (e *CheetahError) Error() string {
return fmt.Sprintf("%s: %s", pvStatusToString(e.StatusCode), e.Message)
var message strings.Builder
message.WriteString(fmt.Sprintf("%s: %s", pvStatusToString(e.StatusCode), e.Message))

if len(e.MessageStack) > 0 {
message.WriteString(":")
}

for i, value := range e.MessageStack {
message.WriteString(fmt.Sprintf("\n [%d] %s", i, value))
}
return message.String()
}

// Cheetah struct
Expand Down Expand Up @@ -120,8 +131,8 @@ func NewCheetah(accessKey string) Cheetah {
func (cheetah *Cheetah) Init() error {
if cheetah.AccessKey == "" {
return &CheetahError{
INVALID_ARGUMENT,
"No AccessKey provided to Cheetah"}
StatusCode: INVALID_ARGUMENT,
Message: "No AccessKey provided to Cheetah"}
}

if cheetah.ModelPath == "" {
Expand All @@ -134,27 +145,37 @@ func (cheetah *Cheetah) Init() error {

if _, err := os.Stat(cheetah.ModelPath); os.IsNotExist(err) {
return &CheetahError{
INVALID_ARGUMENT,
fmt.Sprintf("Specified model file could not be found at %s", cheetah.ModelPath)}
StatusCode: INVALID_ARGUMENT,
Message: fmt.Sprintf("Specified model file could not be found at %s", cheetah.ModelPath)}
}

if _, err := os.Stat(cheetah.LibraryPath); os.IsNotExist(err) {
return &CheetahError{
INVALID_ARGUMENT,
fmt.Sprintf("Specified library file could not be found at %s", cheetah.LibraryPath)}
StatusCode: INVALID_ARGUMENT,
Message: fmt.Sprintf("Specified library file could not be found at %s", cheetah.LibraryPath)}
}

if cheetah.EndpointDuration < 0 {
return &CheetahError{
INVALID_ARGUMENT,
"Endpoint duration must be non-negative"}
StatusCode: INVALID_ARGUMENT,
Message: "Endpoint duration must be non-negative"}
}

ret := nativeCheetah.nativeInit(cheetah)
if PvStatus(ret) != SUCCESS {
errorStatus, messageStack := nativeCheetah.nativeGetErrorStack()
if errorStatus != SUCCESS {
return &CheetahError{
StatusCode: errorStatus,
Message: "Unable to get Cheetah error state",
}
}

return &CheetahError{
PvStatus(ret),
"Cheetah init failed."}
StatusCode: ret,
Message: "Cheetah init failed",
MessageStack: messageStack,
}
}

FrameLength = nativeCheetah.nativeFrameLength()
Expand All @@ -168,8 +189,8 @@ func (cheetah *Cheetah) Init() error {
func (cheetah *Cheetah) Delete() error {
if cheetah.handle == nil {
return &CheetahError{
INVALID_STATE,
"Cheetah has not been initialized or has already been deleted"}
StatusCode: INVALID_STATE,
Message: "Cheetah has not been initialized or has already been deleted"}
}

nativeCheetah.nativeDelete(cheetah)
Expand All @@ -183,21 +204,31 @@ func (cheetah *Cheetah) Delete() error {
func (cheetah *Cheetah) Process(pcm []int16) (string, bool, error) {
if cheetah.handle == nil {
return "", false, &CheetahError{
INVALID_STATE,
"Cheetah has not been initialized or has already been deleted"}
StatusCode: INVALID_STATE,
Message: "Cheetah has not been initialized or has already been deleted"}
}

if len(pcm) != FrameLength {
return "", false, &CheetahError{
INVALID_ARGUMENT,
fmt.Sprintf("Input data frame size (%d) does not match required size of %d", len(pcm), FrameLength)}
StatusCode: INVALID_ARGUMENT,
Message: fmt.Sprintf("Input data frame size (%d) does not match required size of %d", len(pcm), FrameLength)}
}

ret, transcript, isEndpoint := nativeCheetah.nativeProcess(cheetah, pcm)
if PvStatus(ret) != SUCCESS {
return "", false, &CheetahError{
PvStatus(ret),
"Cheetah process failed."}
errorStatus, messageStack := nativeCheetah.nativeGetErrorStack()
if errorStatus != SUCCESS {
return "", false, &CheetahError{
StatusCode: errorStatus,
Message: "Unable to get Cheetah error state",
}
}

return"", false, &CheetahError{
ksyeo1010 marked this conversation as resolved.
Show resolved Hide resolved
StatusCode: ret,
Message: "Cheetah process failed",
MessageStack: messageStack,
}
}

return transcript, isEndpoint, nil
Expand All @@ -208,15 +239,25 @@ func (cheetah *Cheetah) Process(pcm []int16) (string, bool, error) {
func (cheetah *Cheetah) Flush() (string, error) {
if cheetah.handle == nil {
return "", &CheetahError{
INVALID_STATE,
"Cheetah has not been initialized or has already been deleted"}
StatusCode: INVALID_STATE,
Message: "Cheetah has not been initialized or has already been deleted"}
}

ret, transcript := nativeCheetah.nativeFlush(cheetah)
if PvStatus(ret) != SUCCESS {
return "", &CheetahError{
PvStatus(ret),
"Cheetah flush failed."}
errorStatus, messageStack := nativeCheetah.nativeGetErrorStack()
if errorStatus != SUCCESS {
return "", &CheetahError{
StatusCode: errorStatus,
Message: "Unable to get Cheetah error state",
}
}

return"", &CheetahError{
ksyeo1010 marked this conversation as resolved.
Show resolved Hide resolved
StatusCode: ret,
Message: "Cheetah flush failed",
MessageStack: messageStack,
}
}

return transcript, nil
Expand Down
90 changes: 80 additions & 10 deletions binding/go/cheetah_native.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2022 Picovoice Inc.
// Copyright 2022-2023 Picovoice Inc.
//
// You may not use this file except in compliance with the license. A copy of the license is
// located in the "LICENSE" file accompanying this source.
Expand Down Expand Up @@ -135,6 +135,35 @@ typedef void (*pv_cheetah_delete_func)(void *);
void pv_cheetah_delete_wrapper(void *f, void *object) {
return ((pv_cheetah_delete_func) f)(object);
}

typedef void (*pv_set_sdk_func)(const char *);

void pv_set_sdk_wrapper(void *f, const char *sdk) {
return ((pv_set_sdk_func) f)(sdk);
}

typedef int32_t (*pv_get_error_stack_func)(char ***, int32_t *);

int32_t pv_get_error_stack_wrapper(
void *f,
char ***message_stack,
int32_t *message_stack_depth) {
return ((pv_get_error_stack_func) f)(message_stack, message_stack_depth);
}

typedef void (*pv_free_error_stack_func)(char **);

void pv_free_error_stack_wrapper(
void *f,
char **message_stack) {
return ((pv_free_error_stack_func) f)(message_stack);
}

typedef void (*pv_cheetah_transcript_delete_func)(void *);

void pv_cheetah_transcript_delete_wrapper(void* f, char *transcript) {
((pv_cheetah_transcript_delete_func) f)(transcript);
}
*/
import "C"

Expand All @@ -149,18 +178,23 @@ type nativeCheetahInterface interface {
nativeDelete(*Cheetah)
nativeSampleRate()
nativeVersion()
nativeGetErrorStack()
}

type nativeCheetahType struct {
libraryPath unsafe.Pointer

pv_cheetah_init_ptr unsafe.Pointer
pv_cheetah_process_ptr unsafe.Pointer
pv_cheetah_flush_ptr unsafe.Pointer
pv_cheetah_delete_ptr unsafe.Pointer
pv_cheetah_version_ptr unsafe.Pointer
pv_cheetah_frame_length_ptr unsafe.Pointer
pv_sample_rate_ptr unsafe.Pointer
pv_cheetah_init_ptr unsafe.Pointer
pv_cheetah_process_ptr unsafe.Pointer
pv_cheetah_flush_ptr unsafe.Pointer
pv_cheetah_delete_ptr unsafe.Pointer
pv_cheetah_version_ptr unsafe.Pointer
pv_cheetah_frame_length_ptr unsafe.Pointer
pv_sample_rate_ptr unsafe.Pointer
pv_set_sdk_ptr unsafe.Pointer
pv_get_error_stack_ptr unsafe.Pointer
pv_free_error_stack_ptr unsafe.Pointer
pv_cheetah_transcript_delete_ptr unsafe.Pointer
}

func (nc *nativeCheetahType) nativeInit(cheetah *Cheetah) (status PvStatus) {
Expand All @@ -184,6 +218,14 @@ func (nc *nativeCheetahType) nativeInit(cheetah *Cheetah) (status PvStatus) {
nc.pv_cheetah_version_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_cheetah_version"))
nc.pv_cheetah_frame_length_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_cheetah_frame_length"))
nc.pv_sample_rate_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_sample_rate"))
nc.pv_set_sdk_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_set_sdk"))
nc.pv_get_error_stack_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_get_error_stack"))
nc.pv_free_error_stack_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_free_error_stack"))
nc.pv_cheetah_transcript_delete_ptr = C.load_symbol(nc.libraryPath, C.CString("pv_cheetah_transcript_delete"))

C.pv_set_sdk_wrapper(
nc.pv_set_sdk_ptr,
C.CString("go"))

var ret = C.pv_cheetah_init_wrapper(
nc.pv_cheetah_init_ptr,
Expand Down Expand Up @@ -215,7 +257,7 @@ func (nc *nativeCheetahType) nativeProcess(cheetah *Cheetah, pcm []int16) (statu
}

transcript = C.GoString((*C.char)(transcriptPtr))
C.free(transcriptPtr)
C.pv_cheetah_transcript_delete_wrapper(nc.pv_cheetah_transcript_delete_ptr, (*C.char)(transcriptPtr))

return PvStatus(ret), transcript, isEndpoint
}
Expand All @@ -232,7 +274,7 @@ func (nc *nativeCheetahType) nativeFlush(cheetah *Cheetah) (status PvStatus, tra
}

transcript = C.GoString((*C.char)(transcriptPtr))
C.free(transcriptPtr)
C.pv_cheetah_transcript_delete_wrapper(nc.pv_cheetah_transcript_delete_ptr, (*C.char)(transcriptPtr))

return PvStatus(ret), transcript
}
Expand All @@ -248,3 +290,31 @@ func (nc nativeCheetahType) nativeFrameLength() (frameLength int) {
func (nc nativeCheetahType) nativeVersion() (version string) {
return C.GoString(C.pv_cheetah_version_wrapper(nc.pv_cheetah_version_ptr))
}

func (nc *nativeCheetahType) nativeGetErrorStack() (status PvStatus, messageStack []string) {
var messageStackDepthRef C.int32_t
var messageStackRef **C.char

var ret = C.pv_get_error_stack_wrapper(nc.pv_get_error_stack_ptr,
&messageStackRef,
&messageStackDepthRef)

if PvStatus(ret) != SUCCESS {
return PvStatus(ret), []string{}
}

defer C.pv_free_error_stack_wrapper(
nc.pv_free_error_stack_ptr,
messageStackRef)

messageStackDepth := int(messageStackDepthRef)
messageStackSlice := (*[1 << 28]*C.char)(unsafe.Pointer(messageStackRef))[:messageStackDepth:messageStackDepth]

messageStack = make([]string, messageStackDepth)

for i := 0; i < messageStackDepth; i++ {
messageStack[i] = C.GoString(messageStackSlice[i])
}

return PvStatus(ret), messageStack
}
Loading