Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(silero): add Silero-vad backend #4204

Merged
merged 8 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057

ONNX_VERSION?=1.20.0
ONNX_ARCH?=x64
ONNX_OS?=linux

export BUILD_TYPE?=
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
export CMAKE_ARGS?=
Expand Down Expand Up @@ -89,7 +93,20 @@ ifeq ($(NATIVE),false)
CMAKE_ARGS+=-DGGML_NATIVE=OFF
endif

# Detect if we are running on arm64
ifneq (,$(findstring aarch64,$(shell uname -m)))
ONNX_ARCH=aarch64
endif

ifeq ($(OS),Darwin)
ONNX_OS=osx
ifneq (,$(findstring aarch64,$(shell uname -m)))
ONNX_ARCH=arm64
else ifneq (,$(findstring arm64,$(shell uname -m)))
ONNX_ARCH=arm64
else
ONNX_ARCH=x86_64
endif

ifeq ($(OSX_SIGNING_IDENTITY),)
OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
Expand Down Expand Up @@ -195,6 +212,7 @@ ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
# Use filter-out to remove the specified backends
ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS))
Expand Down Expand Up @@ -281,6 +299,20 @@ sources/go-stable-diffusion:
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a

sources/onnxruntime:
mkdir -p sources/onnxruntime
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./

backend-assets/lib/libonnxruntime.so.1: backend-assets/lib sources/onnxruntime
cp -rfv sources/onnxruntime/lib/* backend-assets/lib/
ifeq ($(OS),Darwin)
mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib
else
mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1
endif

## tiny-dream
sources/go-tiny-dream:
mkdir -p sources/go-tiny-dream
Expand Down Expand Up @@ -837,6 +869,13 @@ ifneq ($(UPX),)
$(UPX) backend-assets/grpc/stablediffusion
endif

backend-assets/grpc/silero-vad: backend-assets/grpc backend-assets/lib/libonnxruntime.so.1
CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURDIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURDIR)/backend-assets/lib \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/silero-vad ./backend/go/vad/silero
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/silero-vad
endif

backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream
Expand Down
15 changes: 15 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ service Backend {
rpc Rerank(RerankRequest) returns (RerankResult) {}

rpc GetMetrics(MetricsRequest) returns (MetricsResponse);

rpc VAD(VADRequest) returns (VADResponse) {}
}

// Define the empty request
Expand Down Expand Up @@ -293,6 +295,19 @@ message TTSRequest {
optional string language = 5;
}

message VADRequest {
repeated float audio = 1;
}

message VADSegment {
float start = 1;
float end = 2;
}

message VADResponse {
repeated VADSegment segments = 1;
}

message SoundGenerationRequest {
string text = 1;
string model = 2;
Expand Down
21 changes: 21 additions & 0 deletions backend/go/vad/silero/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
"flag"

grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
flag.Parse()

if err := grpc.StartServer(*addr, &VAD{}); err != nil {
panic(err)
}
}
54 changes: 54 additions & 0 deletions backend/go/vad/silero/vad.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package main

// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"

"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/streamer45/silero-vad-go/speech"
)

type VAD struct {
base.SingleThread
detector *speech.Detector
}

func (vad *VAD) Load(opts *pb.ModelOptions) error {
v, err := speech.NewDetector(speech.DetectorConfig{
ModelPath: opts.ModelFile,
SampleRate: 16000,
//WindowSize: 1024,
Threshold: 0.5,
MinSilenceDurationMs: 0,
SpeechPadMs: 0,
})
if err != nil {
return fmt.Errorf("create silero detector: %w", err)
}

vad.detector = v
return err
}

func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
audio := req.Audio

segments, err := vad.detector.Detect(audio)
if err != nil {
return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
}

vadSegments := []*pb.VADSegment{}
for _, s := range segments {
vadSegments = append(vadSegments, &pb.VADSegment{
Start: float32(s.SpeechStartAt),
End: float32(s.SpeechEndAt),
})
}

return pb.VADResponse{
Segments: vadSegments,
}, nil
}
68 changes: 68 additions & 0 deletions core/http/endpoints/localai/vad.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package localai

import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)

// VADEndpoint is Voice-Activation-Detection endpoint
// @Summary Detect voice fragments in an audio stream
// @Accept json
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

input := new(schema.VADRequest)

// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}

modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
if err != nil {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
}

cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
config.LoadOptionDebug(appConfig.Debug),
config.LoadOptionThreads(appConfig.Threads),
config.LoadOptionContextSize(appConfig.ContextSize),
config.LoadOptionF16(appConfig.F16),
)

if err != nil {
log.Err(err)
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
} else {
modelFile = cfg.Model
}
log.Debug().Msgf("Request for model: %s", modelFile)

opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))

vadModel, err := ml.Load(opts...)
if err != nil {
return err
}
req := proto.VADRequest{
Audio: input.Audio,
}
resp, err := vadModel.VAD(c.Context(), &req)
if err != nil {
return err
}

return c.JSON(resp)
}
}
1 change: 1 addition & 0 deletions core/http/routes/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
}

app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))

// Stores
sl := model.NewModelLoader("")
Expand Down
8 changes: 7 additions & 1 deletion core/schema/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@ type TTSRequest struct {
Input string `json:"input" yaml:"input"` // text input
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
Backend string `json:"backend" yaml:"backend"`
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
}

// @Description VAD request body
type VADRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
}

type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ require (
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pion/webrtc/v3 v3.3.0 // indirect
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
github.com/streamer45/silero-vad-go v0.2.1 // indirect
github.com/urfave/cli/v2 v2.27.4 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/wlynxg/anet v0.0.4 // indirect
go.uber.org/mock v0.4.0 // indirect
)
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
2 changes: 2 additions & 0 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,6 @@ type Backend interface {
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)

GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)

VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
}
4 changes: 4 additions & 0 deletions pkg/grpc/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
}

func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) {
return pb.VADResponse{}, fmt.Errorf("unimplemented")
}

func memoryUsage() *pb.MemoryUsageData {
mud := pb.MemoryUsageData{
Breakdown: make(map[string]uint64),
Expand Down
18 changes: 18 additions & 0 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,21 @@ func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opt
client := pb.NewBackendClient(conn)
return client.GetMetrics(ctx, in, opts...)
}

func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
}
c.setBusy(true)
defer c.setBusy(false)
c.wdMark()
defer c.wdUnMark()
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.VAD(ctx, in, opts...)
}
4 changes: 4 additions & 0 deletions pkg/grpc/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ..
return e.s.Rerank(ctx, in)
}

func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
return e.s.VAD(ctx, in)
}

func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
return e.s.GetMetrics(ctx, in)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/grpc/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type LLM interface {
StoresDelete(*pb.StoresDeleteOptions) error
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)

VAD(*pb.VADRequest) (pb.VADResponse, error)
}

func newReply(s string) *pb.Reply {
Expand Down
12 changes: 12 additions & 0 deletions pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.
return &res, nil
}

func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
if s.llm.Locking() {
s.llm.Lock()
defer s.llm.Unlock()
}
res, err := s.llm.VAD(in)
if err != nil {
return nil, err
}
return &res, nil
}

func StartServer(address string, model LLM) error {
lis, err := net.Listen("tcp", address)
if err != nil {
Expand Down
Loading