LlamaCpp and Llama Go backends #40

Draft · wants to merge 1 commit into base: main
53 changes: 53 additions & 0 deletions client/llamacpp/llamacppclient.go
@@ -0,0 +1,53 @@
package llamacpp

import (
	"fmt"
	"os/exec"

	"github.com/spandigitial/codeassistant/client"
	"github.com/spandigitial/codeassistant/model"
)

type Client struct {
	binaryPath        string
	modelPath         string
	promptContextSize int
	extraArguments    []string
}

type Option func(client *Client)

func New(binaryPath string, modelPath string, promptContextSize int, options ...Option) *Client {
	c := &Client{
		binaryPath:        binaryPath,
		modelPath:         modelPath,
		promptContextSize: promptContextSize,
	}

	for _, option := range options {
		option(c)
	}

	return c
}

func WithExtraArguments(arguments ...string) Option {
	return func(client *Client) {
		client.extraArguments = arguments
	}
}

func (c *Client) Models(models chan<- client.LanguageModel) error {
	close(models)
	return nil
}

func (c *Client) Completion(ci *model.CommandInstance, messageParts chan<- client.MessagePart) error {
	// Forward the prompt with -p so the binary actually receives it; -n mirrors promptContextSize.
	args := append([]string{"-m", c.modelPath, "-n", fmt.Sprintf("%d", c.promptContextSize),
		"-p", ci.JoinedPromptsContent("\n\n")}, c.extraArguments...)
	out, err := exec.Command(c.binaryPath, args...).Output()
	if err != nil {
		return err
	}
	messageParts <- client.MessagePart{Delta: "", Type: "Start"}
	messageParts <- client.MessagePart{Delta: string(out), Type: "Part"}
	messageParts <- client.MessagePart{Delta: "", Type: "Done"}
	close(messageParts)
	return nil
}
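A sketch of how a caller might drive this client, assuming it is handed a model.CommandInstance built elsewhere in codeassistant; the binary and model paths below are placeholders, not values from this PR:

package example

import (
	"fmt"

	"github.com/spandigitial/codeassistant/client"
	"github.com/spandigitial/codeassistant/client/llamacpp"
	"github.com/spandigitial/codeassistant/model"
)

// printCompletion streams a single llama.cpp completion to stdout.
func printCompletion(ci *model.CommandInstance) error {
	// Placeholder paths: point these at a local llama.cpp binary and model file.
	c := llamacpp.New("/usr/local/bin/llama", "/models/7B/ggml-model.bin", 512)

	parts := make(chan client.MessagePart)
	errs := make(chan error, 1)
	go func() {
		errs <- c.Completion(ci, parts) // Completion closes parts when it is done
	}()

	for part := range parts {
		if part.Type == "Part" {
			fmt.Print(part.Delta)
		}
	}
	return <-errs
}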
16 changes: 16 additions & 0 deletions client/llamagoremote/jobresponse.go
@@ -0,0 +1,16 @@
package llamagoremote

import (
	"time"

	"github.com/google/uuid"
)

type jobResponse struct {
	ID      uuid.UUID `json:"id"`
	Prompt  string    `json:"prompt"`
	Output  string    `json:"output"`
	Created time.Time `json:"created"`
	Started time.Time `json:"started"`
	Model   string    `json:"model"`
	Status  status    `json:"status"`
}
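Note: encoding/json only populates exported struct fields, so jobResponse's status field has to be exported (Status) for the server's status value to survive unmarshalling. A minimal, self-contained illustration of the difference:

package main

import (
	"encoding/json"
	"fmt"
)

type exportedStatus struct {
	Status string `json:"status"`
}

type unexportedStatus struct {
	status string `json:"status"`
}

func main() {
	data := []byte(`{"status":"finished"}`)
	var a exportedStatus
	var b unexportedStatus
	_ = json.Unmarshal(data, &a)
	_ = json.Unmarshal(data, &b)
	fmt.Printf("exported=%q unexported=%q\n", a.Status, b.status)
	// exported="finished" unexported="" (the unexported field is silently skipped)
}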
5 changes: 5 additions & 0 deletions client/llamagoremote/jobstatusresponse.go
@@ -0,0 +1,5 @@
package llamagoremote

type jobStatusResponse struct {
	Status status `json:"status"`
}
161 changes: 161 additions & 0 deletions client/llamagoremote/llamagoclient.go
@@ -0,0 +1,161 @@
package llamagoremote

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/google/uuid"
	"github.com/spandigitial/codeassistant/client"
	"github.com/spandigitial/codeassistant/client/debugger"
	"github.com/spandigitial/codeassistant/model"
	"github.com/spandigitial/codeassistant/transport"
)

type Client struct {
	protocol     Protocol
	host         string
	port         int
	debugger     *debugger.Debugger
	httpClient   *http.Client
	userAgent    *string
	pollDuration time.Duration
}

type Option func(client *Client)

func New(protocol Protocol, host string, port int, pollDuration time.Duration, debugger *debugger.Debugger, options ...Option) *Client {
	c := &Client{
		protocol:     protocol,
		host:         host,
		port:         port,
		debugger:     debugger,
		pollDuration: pollDuration,
	}

	for _, option := range options {
		option(c)
	}

	if c.httpClient == nil {
		c.httpClient = http.DefaultClient
	}

	c.httpClient.Transport = transport.New(c.httpClient.Transport, c.debugger)

	return c
}

func WithHttpClient(httpClient *http.Client) Option {
	return func(client *Client) {
		client.httpClient = httpClient
	}
}

func (c *Client) Models(models chan<- client.LanguageModel) error {
	close(models)
	return nil
}

func (c *Client) Completion(ci *model.CommandInstance, messageParts chan<- client.MessagePart) error {

	sendURL := fmt.Sprintf("%s://%s:%d/jobs", c.protocol, c.host, c.port)
	jobID := uuid.New()
	requestTime := time.Now()

	c.debugger.Message("request-time", fmt.Sprintf("%v", requestTime))

	request := promptRequest{
		ID:     jobID,
		Prompt: ci.JoinedPromptsContent("\n\n"),
	}

	c.debugger.Message("sent-prompt", request.Prompt)

	requestBytes, err := json.Marshal(request)
	if err != nil {
		return err
	}

	// Create the HTTP request that submits the job.
	req, err := http.NewRequest("POST", sendURL, bytes.NewBuffer(requestBytes))
	if err != nil {
		return err
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}

	// Read the response body and release the connection.
	responseBytes, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return err
	}

	var promptResp promptResponse
	err = json.Unmarshal(responseBytes, &promptResp)
	if err != nil {
		return err
	}

	if promptResp.Status == processing {
		statusURL := fmt.Sprintf("%s://%s:%d/jobs/status/%s", c.protocol, c.host, c.port, jobID.String())
		// Poll the status endpoint until the job has finished.
		for {
			req, err := http.NewRequest("GET", statusURL, nil)
			if err != nil {
				return err
			}
			resp, err := c.httpClient.Do(req)
			if err != nil {
				return err
			}
			responseBytes, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				return err
			}

			var statusResp jobStatusResponse
			err = json.Unmarshal(responseBytes, &statusResp)
			if err != nil {
				return err
			}

			if statusResp.Status == finished {
				break
			}
			time.Sleep(c.pollDuration)
		}
	}

	jobURL := fmt.Sprintf("%s://%s:%d/jobs/%s", c.protocol, c.host, c.port, jobID.String())
	req, err = http.NewRequest("GET", jobURL, nil)
	if err != nil {
		return err
	}
	resp, err = c.httpClient.Do(req)
	if err != nil {
		return err
	}

	responseBytes, err = io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return err
	}

	var jobResp jobResponse
	err = json.Unmarshal(responseBytes, &jobResp)
	if err != nil {
		return err
	}

	messageParts <- client.MessagePart{Delta: "", Type: "Start"}
	messageParts <- client.MessagePart{Delta: jobResp.Output, Type: "Part"}
	messageParts <- client.MessagePart{Delta: "", Type: "Done"}
	close(messageParts)

	return nil
}
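A similar usage sketch for the remote client; host, port and poll interval are placeholders, and the debugger is assumed to be constructed by the caller:

package example

import (
	"fmt"
	"time"

	"github.com/spandigitial/codeassistant/client"
	"github.com/spandigitial/codeassistant/client/debugger"
	"github.com/spandigitial/codeassistant/client/llamagoremote"
	"github.com/spandigitial/codeassistant/model"
)

// runRemote submits a prompt to a llama-go server and prints the generated text.
func runRemote(ci *model.CommandInstance, dbg *debugger.Debugger) error {
	c := llamagoremote.New(llamagoremote.HttpProtocol, "localhost", 8085,
		2*time.Second, // how often to poll /jobs/status/{id}
		dbg)

	parts := make(chan client.MessagePart)
	errs := make(chan error, 1)
	go func() {
		errs <- c.Completion(ci, parts) // Completion closes parts when the job finishes
	}()

	for part := range parts {
		if part.Type == "Part" {
			fmt.Print(part.Delta)
		}
	}
	return <-errs
}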
8 changes: 8 additions & 0 deletions client/llamagoremote/promptrequest.go
@@ -0,0 +1,8 @@
package llamagoremote

import "github.com/google/uuid"

type promptRequest struct {
	ID     uuid.UUID `json:"id"`
	Prompt string    `json:"prompt"`
}
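For reference, promptRequest marshals to the JSON body that POST /jobs receives. A small sketch (placed in the same package, since the type is unexported) with a made-up job ID:

package llamagoremote

import (
	"encoding/json"
	"fmt"

	"github.com/google/uuid"
)

func examplePayload() {
	req := promptRequest{
		ID:     uuid.MustParse("7b1c2a9e-1111-4222-8333-444455556666"), // made-up job ID
		Prompt: "Explain the llama-go job API in one sentence.",
	}
	body, _ := json.Marshal(req)
	fmt.Println(string(body))
	// {"id":"7b1c2a9e-1111-4222-8333-444455556666","prompt":"Explain the llama-go job API in one sentence."}
}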
13 changes: 13 additions & 0 deletions client/llamagoremote/promptresponse.go
@@ -0,0 +1,13 @@
package llamagoremote

import (
	"time"

	"github.com/google/uuid"
)

type promptResponse struct {
	ID      uuid.UUID `json:"id"`
	Prompt  string    `json:"prompt"`
	Created time.Time `json:"created"`
	Status  status    `json:"status"`
}
24 changes: 24 additions & 0 deletions client/llamagoremote/protocol.go
@@ -0,0 +1,24 @@
package llamagoremote

import "fmt"

type Protocol string

const (
	HttpProtocol  Protocol = "http"
	HttpsProtocol Protocol = "https"
)

var protocolMap = map[string]Protocol{
	"http":  HttpProtocol,
	"https": HttpsProtocol,
}

func ParseProtocol(protocolStr string) (Protocol, error) {
	protocol, found := protocolMap[protocolStr]
	if found {
		return protocol, nil
	}
	return "", fmt.Errorf("protocol: '%s' not found", protocolStr)
}
26 changes: 26 additions & 0 deletions client/llamagoremote/status.go
@@ -0,0 +1,26 @@
package llamagoremote

import (
	"fmt"
)

type status string

const (
	processing status = "processing"
	finished   status = "finished"
)

var statusMap = map[string]status{
	"processing": processing,
	"finished":   finished,
}

func parseStatus(statusStr string) (status, error) {
	status, found := statusMap[statusStr]
	if found {
		return status, nil
	}
	return "", fmt.Errorf("status: '%s' not found", statusStr)
}
21 changes: 12 additions & 9 deletions client/openai/openaiclient.go
@@ -9,6 +9,7 @@ import (
"github.com/spandigitial/codeassistant/client"
"github.com/spandigitial/codeassistant/client/debugger"
"github.com/spandigitial/codeassistant/model"
"github.com/spandigitial/codeassistant/transport"
"github.com/spf13/viper"
"golang.org/x/time/rate"
"io"
@@ -17,7 +18,7 @@ import (
"time"
)

type OpenAiClient struct {
type Client struct {
apiKey string
debugger *debugger.Debugger
rateLimiter *rate.Limiter
@@ -26,10 +27,10 @@ type OpenAiClient struct {
userAgent *string
}

type Option func(client *OpenAiClient)
type Option func(client *Client)

func New(apiKey string, debugger *debugger.Debugger, options ...Option) *OpenAiClient {
c := &OpenAiClient{
func New(apiKey string, debugger *debugger.Debugger, options ...Option) *Client {
c := &Client{
apiKey: apiKey,
debugger: debugger,
}
@@ -42,30 +43,32 @@ func New(apiKey string, debugger *debugger.Debugger, options ...Option) *OpenAiC
c.httpClient = http.DefaultClient
}

c.httpClient.Transport = transport.New(c.httpClient.Transport, c.debugger)

return c
}

func WithHttpClient(httpClient *http.Client) Option {
return func(client *OpenAiClient) {
return func(client *Client) {
client.httpClient = httpClient
}
}

func WithUser(user string) Option {
return func(client *OpenAiClient) {
return func(client *Client) {
client.user = &user
}
}

func WithUserAgent(userAgent string) Option {
return func(client *OpenAiClient) {
return func(client *Client) {
client.userAgent = &userAgent
}
}

var dataRegex = regexp.MustCompile("data: (\\{.+\\})\\w?")

func (c *OpenAiClient) Models(models chan<- client.LanguageModel) error {
func (c *Client) Models(models chan<- client.LanguageModel) error {
url := "https://api.openai.com/v1/models"
requestTime := time.Now()

@@ -112,7 +115,7 @@ func (c *OpenAiClient) Models(models chan<- client.LanguageModel) error {
return nil
}

func (c *OpenAiClient) Completion(commandInstance *model.CommandInstance, messageParts chan<- client.MessagePart) error {
func (c *Client) Completion(commandInstance *model.CommandInstance, messageParts chan<- client.MessagePart) error {
url := "https://api.openai.com/v1/chat/completions"

for _, prompt := range commandInstance.Prompts {
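Both new clients, and the updated OpenAI client, wrap their HTTP transport with transport.New(next, debugger). That package is not part of this diff, so the following is only a sketch of what such a debugging RoundTripper typically looks like; everything beyond the call shape transport.New(rt, dbg) and the debugger's Message(key, value) method is an assumption:

package transport

import "net/http"

// messager is the minimal behaviour this sketch needs from the debugger;
// the real *debugger.Debugger exposes Message(key, value string).
type messager interface {
	Message(key, value string)
}

type roundTripper struct {
	next     http.RoundTripper
	debugger messager
}

// New wraps next so each request and response is reported to the debugger.
func New(next http.RoundTripper, debugger messager) http.RoundTripper {
	if next == nil {
		next = http.DefaultTransport
	}
	return &roundTripper{next: next, debugger: debugger}
}

func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	rt.debugger.Message("request-url", req.URL.String())
	resp, err := rt.next.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	rt.debugger.Message("response-status", resp.Status)
	return resp, nil
}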