Skip to content

Commit

Permalink
enhance: sdk: list full model objects, instead of just names (gptscri…
Browse files Browse the repository at this point in the history
…pt-ai#919)

Signed-off-by: Grant Linville <[email protected]>
g-linville authored Dec 18, 2024
1 parent eb03680 commit badb126
Showing 6 changed files with 33 additions and 17 deletions.
5 changes: 4 additions & 1 deletion pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
@@ -276,7 +276,10 @@ func (r *GPTScript) listModels(ctx context.Context, gptScript *gptscript.GPTScri
if err != nil {
return err
}
fmt.Println(strings.Join(models, "\n"))

for _, model := range models {
fmt.Println(model.ID)
}
return nil
}

3 changes: 2 additions & 1 deletion pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ import (
"slices"
"strings"

openai2 "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/config"
@@ -275,7 +276,7 @@ func (g *GPTScript) ListTools(_ context.Context, prg types.Program) []types.Tool
return prg.TopLevelTools()
}

func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]string, error) {
func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]openai2.Model, error) {
return g.Registry.ListModels(ctx, providers...)
}

9 changes: 6 additions & 3 deletions pkg/llm/registry.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"sync"

"github.com/google/uuid"
openai2 "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/remote"
@@ -16,7 +17,7 @@ import (

type Client interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
ListModels(ctx context.Context, providers ...string) (result []string, _ error)
ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error)
Supports(ctx context.Context, modelName string) (bool, error)
}

@@ -38,15 +39,17 @@ func (r *Registry) AddClient(client Client) error {
return nil
}

func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) {
for _, v := range r.clients {
models, err := v.ListModels(ctx, providers...)
if err != nil {
return nil, err
}
result = append(result, models...)
}
sort.Strings(result)
sort.Slice(result, func(i, j int) bool {
return result[i].ID < result[j].ID
})
return result, nil
}

18 changes: 11 additions & 7 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
@@ -157,10 +157,15 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
return false, InvalidAuthError{}
}

return slices.Contains(models, modelName), nil
for _, model := range models {
if model.ID == modelName {
return true, nil
}
}
return false, nil
}

func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
func (c *Client) ListModels(ctx context.Context, providers ...string) ([]openai.Model, error) {
// Only serve if providers is empty or "" is in the list
if len(providers) != 0 && !slices.Contains(providers, "") {
return nil, nil
@@ -179,11 +184,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
if err != nil {
return nil, err
}
for _, model := range models.Models {
result = append(result, model.ID)
}
sort.Strings(result)
return result, nil
sort.Slice(models.Models, func(i, j int) bool {
return models.Models[i].ID < models.Models[j].ID
})
return models.Models, nil
}

func (c *Client) cacheKey(request openai.ChatCompletionRequest) any {
13 changes: 9 additions & 4 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ import (
"strings"
"sync"

openai2 "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -62,7 +63,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return client.Call(ctx, messageRequest, env, status)
}

func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []openai2.Model, _ error) {
for _, provider := range providers {
client, err := c.load(ctx, provider)
if err != nil {
@@ -72,12 +73,16 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
if err != nil {
return nil, err
}
for _, model := range models {
result = append(result, model+" from "+provider)
for i := range models {
models[i].ID = fmt.Sprintf("%s from %s", models[i].ID, provider)
}

result = append(result, models...)
}

sort.Strings(result)
sort.Slice(result, func(i, j int) bool {
return result[i].ID < result[j].ID
})
return
}

2 changes: 1 addition & 1 deletion pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
@@ -145,7 +145,7 @@ func (s *server) listModels(w http.ResponseWriter, r *http.Request) {
return
}

writeResponse(logger, w, map[string]any{"stdout": strings.Join(out, "\n")})
writeResponse(logger, w, map[string]any{"stdout": out})
}

// execHandler is a general handler for executing tools with gptscript. This is mainly responsible for parsing the request body.

0 comments on commit badb126

Please sign in to comment.