Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

265: unifying roles, setting default as user #269

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
19 changes: 18 additions & 1 deletion docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ const docTemplate = `{
},
"role": {
"description": "The role of the author of this message. One of system, user, or assistant.",
"type": "string"
"allOf": [
{
"$ref": "#/definitions/schemas.Role"
}
]
}
}
},
Expand Down Expand Up @@ -308,6 +312,19 @@ const docTemplate = `{
}
}
},
"schemas.Role": {
"type": "string",
"enum": [
"system",
"user",
"assistant"
],
"x-enum-varnames": [
"RoleSystem",
"RoleUser",
"RoleAssistant"
]
},
"schemas.RouterListSchema": {
"type": "object",
"properties": {
Expand Down
19 changes: 18 additions & 1 deletion docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@
},
"role": {
"description": "The role of the author of this message. One of system, user, or assistant.",
"type": "string"
"allOf": [
{
"$ref": "#/definitions/schemas.Role"
}
]
}
}
},
Expand Down Expand Up @@ -305,6 +309,19 @@
}
}
},
"schemas.Role": {
"type": "string",
"enum": [
"system",
"user",
"assistant"
],
"x-enum-varnames": [
"RoleSystem",
"RoleUser",
"RoleAssistant"
]
},
"schemas.RouterListSchema": {
"type": "object",
"properties": {
Expand Down
13 changes: 12 additions & 1 deletion docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ definitions:
description: The content of the message.
type: string
role:
allOf:
- $ref: '#/definitions/schemas.Role'
description: The role of the author of this message. One of system, user,
or assistant.
type: string
required:
- content
- role
Expand Down Expand Up @@ -77,6 +78,16 @@ definitions:
token_usage:
$ref: '#/definitions/schemas.TokenUsage'
type: object
schemas.Role:
enum:
- system
- user
- assistant
type: string
x-enum-varnames:
- RoleSystem
- RoleUser
- RoleAssistant
schemas.RouterListSchema:
properties:
routers:
Expand Down
12 changes: 10 additions & 2 deletions pkg/api/schemas/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *ChatRequest) Params(modelID string, modelName string) *ChatParams {
func NewChatFromStr(message string) *ChatRequest {
return &ChatRequest{
Message: ChatMessage{
"user",
RoleUser,
message,
},
}
Expand Down Expand Up @@ -93,10 +93,18 @@ type TokenUsage struct {
TotalTokens int `json:"total_tokens"`
}

type Role string

const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)

// ChatMessage is a message in a chat request.
type ChatMessage struct {
// The role of the author of this message. One of system, user, or assistant.
Role string `json:"role" validate:"required"`
Role Role `json:"role" validate:"required"`
// The content of the message.
Content string `json:"content" validate:"required"`
}
2 changes: 1 addition & 1 deletion pkg/api/schemas/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewChatStreamFromStr(message string) *ChatStreamRequest {
return &ChatStreamRequest{
ChatRequest: &ChatRequest{
Message: ChatMessage{
"user",
RoleUser,
message,
},
},
Expand Down
18 changes: 9 additions & 9 deletions pkg/api/schemas/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestChatRequest_DefaultParams(t *testing.T) {

chatReq := ChatRequest{
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: defaultMessage,
},
MessageHistory: []ChatMessage{
Expand All @@ -42,7 +42,7 @@ func TestChatRequest_DefaultParams(t *testing.T) {
OverrideParams: &map[string]ModelParamsOverride{
modelID: {
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: myModelMessage,
},
},
Expand All @@ -66,7 +66,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) {

chatReq := ChatRequest{
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: defaultMessage,
},
MessageHistory: []ChatMessage{
Expand All @@ -78,7 +78,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) {
OverrideParams: &map[string]ModelParamsOverride{
modelID: {
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: myModelMessage,
},
},
Expand All @@ -102,7 +102,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) {

chatReq := ChatRequest{
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: defaultMessage,
},
MessageHistory: []ChatMessage{
Expand All @@ -114,7 +114,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) {
OverrideParams: &map[string]ModelParamsOverride{
modelName: {
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: myModelMessage,
},
},
Expand All @@ -139,7 +139,7 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) {

chatReq := ChatRequest{
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: defaultMessage,
},
MessageHistory: []ChatMessage{
Expand All @@ -151,13 +151,13 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) {
OverrideParams: &map[string]ModelParamsOverride{
modelName: {
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: myModelNameMessage,
},
},
modelID: {
Message: ChatMessage{
Role: "user",
Role: RoleUser,
Content: myModelIDMessage,
},
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/anthropic/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
ModelResponse: schemas.ModelResponse{
Metadata: map[string]string{},
Message: schemas.ChatMessage{
Role: completion.Type,
Role: schemas.Role(completion.Type),
Content: completion.Text,
},
TokenUsage: schemas.TokenUsage{
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/azureopenai/chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the capital of the United Kingdom?",
}}}

Expand Down Expand Up @@ -140,7 +140,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the biggest animal?",
}}}

Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/azureopenai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the capital of the United Kingdom?",
}}}

Expand Down Expand Up @@ -116,7 +116,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the dealio?",
}}}

Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/bedrock/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
Cached: false,
ModelResponse: schemas.ModelResponse{
Message: schemas.ChatMessage{
Role: "assistant",
Role: schemas.RoleAssistant,
Content: modelResult.OutputText,
},
TokenUsage: schemas.TokenUsage{
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/bedrock/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the biggest animal?",
}}}

Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/bedrock/testdata/chat.req.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"model": "amazon.titan-text-express-v1",
"messages": [
{
"role": "user",
"role": schemas.RoleUser,
"content": "What's the biggest animal?"
}
],
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
"responseId": cohereCompletion.ResponseID,
},
Message: schemas.ChatMessage{
Role: "assistant",
Role: payload.Role,
Content: cohereCompletion.Text,
},
TokenUsage: schemas.TokenUsage{
Expand Down
25 changes: 22 additions & 3 deletions pkg/providers/cohere/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cohere

import "github.com/EinStack/glide/pkg/api/schemas"

// Cohere Chat Response
// ChatCompletion Cohere Chat Response
type ChatCompletion struct {
Text string `json:"text"`
GenerationID string `json:"generation_id"`
Expand Down Expand Up @@ -92,6 +92,7 @@ type FinalResponse struct {
type ChatRequest struct {
Model string `json:"model"`
Message string `json:"message"`
Role schemas.Role `json:"role"`
ChatHistory []schemas.ChatMessage `json:"chat_history"`
Temperature float64 `json:"temperature,omitempty"`
Preamble string `json:"preamble,omitempty"`
Expand All @@ -112,8 +113,26 @@ func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
message := params.Messages[len(params.Messages)-1]
messageHistory := params.Messages[:len(params.Messages)-1]

// TODO: Map chat message roles to Cohere roles: CHATBOT, SYSTEM, USER

mapRole := func(role schemas.Role) string {
switch role {
case schemas.RoleSystem:
return "SYSTEM"
case schemas.RoleUser:
return "USER"
case schemas.RoleAssistant:
return "CHATBOT"
default:
return "USER"
}
}

for i := range messageHistory {
messageHistory[i].Role = schemas.Role(mapRole(messageHistory[i].Role))
}

message.Role = schemas.Role(mapRole(message.Role))

r.Role = message.Role
r.Message = message.Content
r.ChatHistory = messageHistory
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/octoml/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) {

// Create a chat request payload
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the dealeo?",
}}}

Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
Cached: false,
ModelResponse: schemas.ModelResponse{
Message: schemas.ChatMessage{
Role: ollamaCompletion.Message.Role,
Role: schemas.Role(ollamaCompletion.Message.Role),
Content: ollamaCompletion.Message.Content,
},
TokenUsage: schemas.TokenUsage{
Expand Down
8 changes: 4 additions & 4 deletions pkg/providers/ollama/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the biggest animal?",
}}}

Expand Down Expand Up @@ -85,7 +85,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the capital of the United Kingdom?",
}}}

Expand Down Expand Up @@ -122,14 +122,14 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) {
require.NoError(t, err)

chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
Role: "user",
Role: schemas.RoleUser,
Content: "What's the capital of the United Kingdom?",
}}}

response, err := client.Chat(context.Background(), &chatParams)

require.NoError(t, err)
require.NotNil(t, response)
require.Equal(t, "assistant", response.ModelResponse.Message.Role)
require.Equal(t, schemas.RoleAssistant, response.ModelResponse.Message.Role)
require.Equal(t, "London", response.ModelResponse.Message.Content)
}
Loading
Loading