Skip to content

Commit

Permalink
llms: Improve json mode support (#683)
Browse files Browse the repository at this point in the history
* llms: Improve json response format coverage, add example
  • Loading branch information
tmc authored Mar 19, 2024
1 parent 325d534 commit 6045596
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 7 deletions.
13 changes: 13 additions & 0 deletions examples/json-mode-example/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module github.com/tmc/langchaingo/examples/json-mode-example

go 1.21

toolchain go1.21.4

require github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
)
21 changes: 21 additions & 0 deletions examples/json-mode-example/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tmc/langchaingo v0.1.5 h1:PNPFu54wn5uVPRt9GS/quRwdFZW4omSab9/dcFAsGmU=
github.com/tmc/langchaingo v0.1.5/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM=
github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002 h1:qM/fnCN2BvGZmDS3gyxeV3m4p6veX/8KCttIMtIYrps=
github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002/go.mod h1:m+VxH55LmyknIgla6GyUu0U/syv03r4wtIfrJYmWXMY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
52 changes: 52 additions & 0 deletions examples/json-mode-example/json_mode_example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"os"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/anthropic"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/llms/openai"
)

// flagBackend selects the LLM backend to use; see initBackend for the
// accepted values ("openai", "ollama", "anthropic", "googleai").
var flagBackend = flag.String("backend", "openai", "backend to use")

// main demonstrates JSON-mode completions: it builds the selected backend,
// issues a single prompt with llms.WithJSONMode() enabled, and prints the
// raw completion to stdout. Any initialization or generation error is fatal.
func main() {
	flag.Parse()
	ctx := context.Background()
	llm, err := initBackend(ctx)
	if err != nil {
		log.Fatal(err)
	}
	// WithTemperature(0.0) makes the output as deterministic as the backend
	// allows; WithJSONMode() requests a JSON-formatted response where the
	// backend supports a native JSON response format.
	completion, err := llms.GenerateFromSinglePrompt(ctx,
		llm,
		"Who was first man to walk on the moon? Respond in json format, include `first_man` in response keys.",
		llms.WithTemperature(0.0),
		llms.WithJSONMode(),
	)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(completion)
}

// initBackend constructs the llms.Model named by the -backend flag.
// Supported backends: "openai", "ollama" (mistral model), "anthropic"
// (claude-2.1), and "googleai" (API key from GOOGLE_AI_API_KEY).
// An unrecognized backend name yields a descriptive error.
func initBackend(ctx context.Context) (llms.Model, error) {
	backend := *flagBackend
	switch backend {
	case "anthropic":
		return anthropic.New(anthropic.WithModel("claude-2.1"))
	case "googleai":
		// Only the Google AI client needs the context at construction time.
		return googleai.New(ctx, googleai.WithAPIKey(os.Getenv("GOOGLE_AI_API_KEY")))
	case "ollama":
		return ollama.New(ollama.WithModel("mistral"))
	case "openai":
		return openai.New()
	}
	return nil, fmt.Errorf("unknown backend: %s", backend)
}
8 changes: 7 additions & 1 deletion llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package anthropic
import (
"context"
"errors"
"fmt"
"net/http"
"os"

Expand Down Expand Up @@ -71,9 +72,14 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
// Assume we get a single text message
msg0 := messages[0]
part := msg0.Parts[0]
partText, ok := part.(llms.TextContent)
if !ok {
return nil, fmt.Errorf("unexpected message type: %T", part)
}
prompt := fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", partText.Text)
result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
Model: opts.Model,
Prompt: part.(llms.TextContent).Text,
Prompt: prompt,
MaxTokens: opts.MaxTokens,
StopWords: opts.StopWords,
Temperature: opts.Temperature,
Expand Down
7 changes: 6 additions & 1 deletion llms/ollama/ollamallm.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
chatMsgs = append(chatMsgs, msg)
}

format := o.options.format
if opts.JSONMode {
format = "json"
}

// Get our ollamaOptions from llms.CallOptions
ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts)
req := &ollamaclient.ChatRequest{
Model: model,
Format: o.options.format,
Format: format,
Messages: chatMsgs,
Options: ollamaOptions,
Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
Expand Down
7 changes: 7 additions & 0 deletions llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type ChatRequest struct {
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`

ResponseFormat ResponseFormat `json:"response_format,omitempty"`

// Function definitions to include in the request.
Functions []FunctionDefinition `json:"functions,omitempty"`
// FunctionCallBehavior is the behavior to use when calling functions.
Expand All @@ -46,6 +48,11 @@ type ChatRequest struct {
StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
}

// ResponseFormat is the format of the response.
// It is serialized as the chat request's `response_format` field;
// setting Type to "json_object" enables the API's JSON mode.
type ResponseFormat struct {
	// Type names the desired response format (e.g. "json_object").
	Type string `json:"type"`
}

// ChatMessage is a message in a chat request.
type ChatMessage struct { //nolint:musttag
// The role of the author of this message. One of system, user, or assistant.
Expand Down
9 changes: 4 additions & 5 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio
}

// GenerateContent implements the Model interface.
//
//nolint:goerr113
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, goerr113, funlen
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}
Expand Down Expand Up @@ -76,7 +74,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten

chatMsgs = append(chatMsgs, msg)
}

req := &openaiclient.ChatRequest{
Model: opts.Model,
StopWords: opts.StopWords,
Expand All @@ -89,7 +86,9 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
PresencePenalty: opts.PresencePenalty,
FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
}

if opts.JSONMode {
req.ResponseFormat = ResponseFormatJSON
}
for _, fn := range opts.Functions {
req.Functions = append(req.Functions, openaiclient.FunctionDefinition{
Name: fn.Name,
Expand Down
16 changes: 16 additions & 0 deletions llms/openai/openaillm_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,24 @@ type options struct {
apiType APIType
httpClient openaiclient.Doer

responseFormat ResponseFormat

// required when APIType is APITypeAzure or APITypeAzureAD
apiVersion string
embeddingModel string

callbackHandler callbacks.Handler
}

// Option is a functional option for the OpenAI client.
type Option func(*options)

// ResponseFormat is the response format for the OpenAI client.
// It aliases the internal client's type so callers can configure the
// response format without importing the internal openaiclient package.
type ResponseFormat = openaiclient.ResponseFormat

// ResponseFormatJSON is the JSON response format.
// Passing it via WithResponseFormat asks the model to emit a valid
// JSON object ("json_object" mode).
var ResponseFormatJSON = ResponseFormat{Type: "json_object"} //nolint:gochecknoglobals

// WithToken passes the OpenAI API token to the client. If not set, the token
// is read from the OPENAI_API_KEY environment variable.
func WithToken(token string) Option {
Expand Down Expand Up @@ -112,3 +121,10 @@ func WithCallback(callbackHandler callbacks.Handler) Option {
opts.callbackHandler = callbackHandler
}
}

// WithResponseFormat allows setting a custom response format.
// The chosen format is stored on the client's options and applied to
// outgoing chat requests.
func WithResponseFormat(responseFormat ResponseFormat) Option {
	return func(o *options) { o.responseFormat = responseFormat }
}
11 changes: 11 additions & 0 deletions llms/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type CallOptions struct {
// PresencePenalty is the presence penalty for sampling.
PresencePenalty float64 `json:"presence_penalty"`

// JSONMode is a flag to enable JSON mode.
JSONMode bool `json:"json"`

// Function defitions to include in the request.
Functions []FunctionDefinition `json:"functions"`
// FunctionCallBehavior is the behavior to use when calling functions.
Expand Down Expand Up @@ -195,3 +198,11 @@ func WithFunctions(functions []FunctionDefinition) CallOption {
o.Functions = functions
}
}

// WithJSONMode will add an option to set the response format to JSON.
// This is useful for models that return structured data.
func WithJSONMode() CallOption {
	return func(opts *CallOptions) {
		opts.JSONMode = true
	}
}

0 comments on commit 6045596

Please sign in to comment.