Skip to content

Commit

Permalink
feat: Update gpt4all, support multiple implementations in runtime (#472)
Browse files Browse the repository at this point in the history
Signed-off-by: mudler <[email protected]>
  • Loading branch information
mudler authored Jun 1, 2023
1 parent 42d7538 commit 78ad481
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 29 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ release/
# just in case
.DS_Store
.idea

# Generated during build
backend-assets/
34 changes: 16 additions & 18 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ BINARY_NAME=local-ai

GOLLAMA_VERSION?=10caf37d8b73386708b4373975b8917e6b212c0e
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
GPT4ALL_VERSION?=337c7fecacfa4ae6779046513ab090687a5b0ef6
GPT4ALL_VERSION?=022f1cabe7dd2c911936b37510582f279069ba1e
GOGGMLTRANSFORMERS_VERSION?=13ccc22621bb21afecd38675a2b043498e2e756c
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=ccb05c3e1c6efd098017d114dcb58ab3262b40b2
Expand Down Expand Up @@ -63,22 +63,13 @@ gpt4all:
git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
# This is hackish, but it is needed because both go-llama and go-gpt4allj ship their own version of ggml.
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.go" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} +
mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +


## BERT embeddings
go-bert:
Expand Down Expand Up @@ -124,6 +115,12 @@ bloomz/libbloomz.a: bloomz
go-bert/libgobert.a: go-bert
$(MAKE) -C go-bert libgobert.a

backend-assets/gpt4all: gpt4all/gpt4all-bindings/golang/libgpt4all.a
mkdir -p backend-assets/gpt4all
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.so backend-assets/gpt4all/ || true
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dylib backend-assets/gpt4all/ || true
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dll backend-assets/gpt4all/ || true

gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ libgpt4all.a

Expand Down Expand Up @@ -188,14 +185,15 @@ rebuild: ## Rebuilds the project
$(MAKE) -C bloomz clean
$(MAKE) build

prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building

clean: ## Remove build related file
rm -fr ./go-llama
rm -rf ./gpt4all
rm -rf ./go-gpt2
rm -rf ./go-stable-diffusion
rm -rf ./go-ggml-transformers
rm -rf ./backend-assets
rm -rf ./go-rwkv
rm -rf ./go-bert
rm -rf ./bloomz
Expand Down
7 changes: 7 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}

if options.assetsDestination != "" {
if err := PrepareBackendAssets(options.backendAssets, options.assetsDestination); err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}

// Default middleware config
app.Use(recover.New())

Expand Down
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ var _ = Describe("API test", func() {
It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
})
It("transcribes audio", func() {
if runtime.GOOS != "linux" {
Expand Down
27 changes: 27 additions & 0 deletions api/backend_assets.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package api

import (
"embed"
"os"
"path/filepath"

"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/rs/zerolog/log"
)

// PrepareBackendAssets extracts the embedded backend assets into dst and
// exports the location of the extracted gpt4all libraries to the current
// process through the GPT4ALL_IMPLEMENTATIONS_PATH environment variable.
//
// The variable is set to <dst>/backend-assets/gpt4all. An error is returned
// when the extraction fails or when the environment variable cannot be set;
// in the latter case the files have already been written to dst.
func PrepareBackendAssets(backendAssets embed.FS, dst string) error {
	// Extract files from the embedded FS
	if err := assets.ExtractFiles(backendAssets, dst); err != nil {
		return err
	}

	// Set GPT4ALL libs where we extracted the files
	// https://github.com/nomic-ai/gpt4all/commit/27e80e1d10985490c9fd4214e4bf458cfcf70896
	gpt4alldir := filepath.Join(dst, "backend-assets", "gpt4all")
	// Setenv can fail (e.g. on a value the OS rejects); surface that instead
	// of reporting success while the variable is left unset.
	if err := os.Setenv("GPT4ALL_IMPLEMENTATIONS_PATH", gpt4alldir); err != nil {
		return err
	}
	log.Debug().Msgf("GPT4ALL_IMPLEMENTATIONS_PATH: %s", gpt4alldir)

	return nil
}
16 changes: 16 additions & 0 deletions api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"embed"

model "github.com/go-skynet/LocalAI/pkg/model"
)
Expand All @@ -18,6 +19,9 @@ type Option struct {
preloadJSONModels string
preloadModelsFromPath string
corsAllowOrigins string

backendAssets embed.FS
assetsDestination string
}

type AppOption func(*Option)
Expand Down Expand Up @@ -49,6 +53,18 @@ func WithCorsAllowOrigins(b string) AppOption {
}
}

// WithBackendAssetsOutput returns an AppOption that records out as the
// directory into which the backend assets are to be extracted.
func WithBackendAssetsOutput(out string) AppOption {
	setDestination := func(opts *Option) {
		opts.assetsDestination = out
	}
	return setDestination
}

// WithBackendAssets returns an AppOption that stores the embedded filesystem
// f as the source of the backend assets.
func WithBackendAssets(f embed.FS) AppOption {
	setAssets := func(opts *Option) {
		opts.backendAssets = f
	}
	return setAssets
}

func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
o.context = ctx
Expand Down
6 changes: 6 additions & 0 deletions assets.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package main

import "embed"

// backendAssets holds the contents of the backend-assets directory, embedded
// into the binary at compile time by the directive below.
// NOTE(review): the glob presumably needs backend-assets/ to exist and match
// at least one file when building — confirm the build always generates it.
//go:embed backend-assets/*
var backendAssets embed.FS
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/imdario/mergo v0.3.16
github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c
github.com/onsi/ginkgo/v2 v2.9.7
github.com/onsi/gomega v1.27.7
github.com/otiai10/openaigo v1.1.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81c
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 h1:99cF+V5wk7IInDAEM9HAlSHdLf/xoJR529Wr8lAG5KQ=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c h1:KXYqUH6bdYbxnF67l8wayctaCZ4BQJQOsUyNke7HC0A=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
Expand Down
8 changes: 8 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ func main() {
EnvVars: []string{"IMAGE_PATH"},
Value: "",
},
&cli.StringFlag{
Name: "backend-assets-path",
DefaultText: "Path used to extract libraries that are required by some of the backends in runtime.",
EnvVars: []string{"BACKEND_ASSETS_PATH"},
Value: "/tmp/localai/backend_data",
},
&cli.IntFlag{
Name: "context-size",
DefaultText: "Default context size of the model",
Expand Down Expand Up @@ -124,6 +130,8 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
api.WithCors(ctx.Bool("cors")),
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
api.WithThreads(ctx.Int("threads")),
api.WithBackendAssets(backendAssets),
api.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
api.WithUploadLimitMB(ctx.Int("upload-limit")))
if err != nil {
return err
Expand Down
51 changes: 51 additions & 0 deletions pkg/assets/extract.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package assets

import (
"embed"
"fmt"
"io/fs"
"os"
"path/filepath"
)

// ExtractFiles copies every directory and file of the embedded filesystem
// content into extractDir, reproducing the same relative layout.
//
// extractDir and any directories found in content are created with mode 0755
// if they do not exist yet; files are written with mode 0644. The walk stops
// at the first failure. Returned errors wrap the underlying cause, so callers
// can inspect them with errors.Is / errors.As.
func ExtractFiles(content embed.FS, extractDir string) error {
	// Create the target directory if it doesn't exist
	if err := os.MkdirAll(extractDir, 0755); err != nil {
		return fmt.Errorf("failed to create directory %q: %w", extractDir, err)
	}

	// Walk through the embedded FS and extract files
	return fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return fmt.Errorf("failed to walk %q: %w", path, err)
		}

		// Embedded paths are always slash-separated; convert them before
		// rebuilding the directory structure below extractDir.
		targetFile := filepath.Join(extractDir, filepath.FromSlash(path))

		if d.IsDir() {
			// Create the directory in the target directory
			if err := os.MkdirAll(targetFile, 0755); err != nil {
				return fmt.Errorf("failed to create directory %q: %w", targetFile, err)
			}
			return nil
		}

		// Read the file from the embedded FS
		fileData, err := content.ReadFile(path)
		if err != nil {
			return fmt.Errorf("failed to read file %q: %w", path, err)
		}

		// Create the file in the target directory
		if err := os.WriteFile(targetFile, fileData, 0644); err != nil {
			return fmt.Errorf("failed to write file %q: %w", targetFile, err)
		}

		return nil
	})
}
13 changes: 4 additions & 9 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
Gpt4AllLlamaBackend = "gpt4all-llama"
Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j"
Gpt4All = "gpt4all"
BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv"
WhisperBackend = "whisper"
Expand All @@ -42,9 +43,7 @@ const (

var backends []string = []string{
LlamaBackend,
Gpt4AllLlamaBackend,
Gpt4AllMptBackend,
Gpt4AllJBackend,
Gpt4All,
RwkvBackend,
GPTNeoXBackend,
WhisperBackend,
Expand Down Expand Up @@ -153,12 +152,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, stableDiffusion)
case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType)))
case Gpt4AllMptBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType)))
case Gpt4AllJBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType)))
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads))))
case BertEmbeddingsBackend:
return ml.LoadModel(modelFile, bertEmbeddings)
case RwkvBackend:
Expand Down

0 comments on commit 78ad481

Please sign in to comment.