Skip to content

Commit

Permalink
Add prompt_tokens and completion_tokens to conv_turns
Browse files Browse the repository at this point in the history
  • Loading branch information
xwjdsh authored and lyricat committed Apr 3, 2023
1 parent 9e46d9a commit a26f871
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 70 deletions.
2 changes: 1 addition & 1 deletion cmd/httpd/httpd.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func NewCmdHttpd() *cobra.Command {
ExtraRate: cfg.Sys.ExtraRate,
InitUserCredits: cfg.Sys.InitUserCredits,
}, client, users)
indexService := indexServ.NewService(ctx, gptHandler, indexes, userz)
indexService := indexServ.NewService(ctx, gptHandler, indexes, userz, models)
appz := appServ.New(appServ.Config{
SecretKey: cfg.Sys.SecretKey,
}, apps, indexService)
Expand Down
32 changes: 19 additions & 13 deletions core/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ type (
}

ConvTurn struct {
ID uint64 `yaml:"id" json:"id"`
ConversationID string `yaml:"conversation_id" json:"conversation_id"`
BotID uint64 `yaml:"bot_id" json:"bot_id"`
AppID uint64 `yaml:"app_id" json:"app_id"`
UserID uint64 `yaml:"user_id" json:"user_id"`
UserIdentity string `yaml:"user_identity" json:"user_identity"`
Request string `yaml:"request" json:"request"`
Response string `yaml:"response" json:"response"`
TotalTokens int `yaml:"total_tokens" json:"total_tokens"`
Status int `yaml:"status" json:"status"`
CreatedAt *time.Time `yaml:"created_at" json:"created_at"`
UpdatedAt *time.Time `yaml:"updated_at" json:"updated_at"`
ID uint64 `yaml:"id" json:"id"`
ConversationID string `yaml:"conversation_id" json:"conversation_id"`
BotID uint64 `yaml:"bot_id" json:"bot_id"`
AppID uint64 `yaml:"app_id" json:"app_id"`
UserID uint64 `yaml:"user_id" json:"user_id"`
UserIdentity string `yaml:"user_identity" json:"user_identity"`
Request string `yaml:"request" json:"request"`
Response string `yaml:"response" json:"response"`
PromptTokens int `yaml:"prompt_tokens" json:"prompt_tokens"`
CompletionTokens int `yaml:"completion_tokens" json:"completion_tokens"`
TotalTokens int `yaml:"total_tokens" json:"total_tokens"`
Status int `yaml:"status" json:"status"`
CreatedAt *time.Time `yaml:"created_at" json:"created_at"`
UpdatedAt *time.Time `yaml:"updated_at" json:"updated_at"`
}

ConversationStore interface {
Expand All @@ -63,6 +65,7 @@ type (
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens",
// "created_at", "updated_at"
// FROM "conv_turns" WHERE
// "id" IN (@ids)
Expand All @@ -73,6 +76,7 @@ type (
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens",
// "created_at", "updated_at"
// FROM "conv_turns" WHERE
// "id" = @id
Expand All @@ -96,13 +100,15 @@ type (
// UPDATE "conv_turns"
// {{set}}
// "response"=@response,
// "prompt_tokens"=@promptTokens,
// "completion_tokens"=@completionTokens,
// "total_tokens"=@totalTokens,
// "status"=@status,
// "updated_at"=NOW()
// {{end}}
// WHERE
// "id"=@id
UpdateConvTurn(ctx context.Context, id uint64, response string, totalTokens int64, status int) error
UpdateConvTurn(ctx context.Context, id uint64, response string, promptTokens, completionTokens, totalTokens int64, status int) error
}

ConversationService interface {
Expand Down
48 changes: 24 additions & 24 deletions handler/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ func New(cfg Config, s *session.Session,
models core.ModelStore,
appz core.AppService,
botz core.BotService,
indexService core.IndexService,
indexz core.IndexService,
userz core.UserService,
convz core.ConversationService,
orderz core.OrderService,
hub *chanhub.Hub,
) Server {
return Server{
cfg: cfg,
apps: apps,
indexes: indexs,
users: users,
appz: appz,
models: models,
indexService: indexService,
botz: botz,
convz: convz,
userz: userz,
session: s,
convs: convs,
orderz: orderz,
hub: hub,
cfg: cfg,
apps: apps,
indexes: indexs,
users: users,
appz: appz,
models: models,
indexz: indexz,
botz: botz,
convz: convz,
userz: userz,
session: s,
convs: convs,
orderz: orderz,
hub: hub,
}
}

Expand All @@ -65,12 +65,12 @@ type (
convs core.ConversationStore
models core.ModelStore

botz core.BotService
appz core.AppService
indexService core.IndexService
convz core.ConversationService
userz core.UserService
orderz core.OrderService
botz core.BotService
appz core.AppService
indexz core.IndexService
convz core.ConversationService
userz core.UserService
orderz core.OrderService

hub *chanhub.Hub
}
Expand All @@ -83,9 +83,9 @@ func (s Server) HandleRest() http.Handler {
r.Use(auth.HandleAuthentication(s.session, s.users))

r.Route("/indexes", func(r chi.Router) {
r.With(auth.HandleAppSecretRequired(), auth.UserCreditRequired(s.users)).Post("/", indexHandler.CreateIndex(s.indexService))
r.With(auth.HandleAppSecretRequired()).Post("/reset", indexHandler.ResetIndexes(s.indexService))
r.With(auth.UserCreditRequired(s.users)).Get("/search", indexHandler.Search(s.apps, s.indexService))
r.With(auth.HandleAppSecretRequired(), auth.UserCreditRequired(s.users)).Post("/", indexHandler.CreateIndex(s.indexz))
r.With(auth.HandleAppSecretRequired()).Post("/reset", indexHandler.ResetIndexes(s.indexz))
r.With(auth.UserCreditRequired(s.users)).Get("/search", indexHandler.Search(s.apps, s.indexz))
r.With(auth.HandleAppSecretRequired()).Delete("/{objectID}", indexHandler.Delete(s.apps, s.indexes))
})

Expand Down
4 changes: 1 addition & 3 deletions service/conv/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ func (s *service) GetConversation(ctx context.Context, convID string) (*core.Con
turn := conv.History[ix]
existed, ok := turnMap[turn.ID]
if ok && existed.Status != turn.Status {
turn.Status = existed.Status
turn.Response = existed.Response
turn.UpdatedAt = existed.UpdatedAt
*turn = *existed
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion service/index/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ import (
gogpt "github.com/sashabaranov/go-openai"
)

func NewService(ctx context.Context, gptHandler *gpt.Handler, indexes core.IndexStore, userz core.UserService) core.IndexService {
func NewService(ctx context.Context, gptHandler *gpt.Handler, indexes core.IndexStore, userz core.UserService, models core.ModelStore) core.IndexService {
return &serviceImpl{
gptHandler: gptHandler,
indexes: indexes,
userz: userz,
models: models,
createEmbeddingsLimitChan: make(chan struct{}, 20),
}
}
Expand Down
64 changes: 39 additions & 25 deletions store/conv/dao/conv_turns.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions store/migrations/20230402185050_add_prompt_completion_tokens.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- +goose Up
-- +goose StatementBegin
SELECT 'up SQL query';
ALTER TABLE conv_turns ADD COLUMN "prompt_tokens" int DEFAULT 0;
ALTER TABLE conv_turns ADD COLUMN "completion_tokens" int DEFAULT 0;
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
SELECT 'down SQL query';
ALTER TABLE conv_turns DROP COLUMN IF EXISTS "prompt_tokens";
ALTER TABLE conv_turns DROP COLUMN IF EXISTS "completion_tokens";
-- +goose StatementEnd
6 changes: 3 additions & 3 deletions worker/rotater/rotater.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (w *Worker) run(ctx context.Context) error {
}

if turn.Status == core.ConvTurnStatusInit {
if err := w.convs.UpdateConvTurn(ctx, turn.ID, "", 0, core.ConvTurnStatusPending); err != nil {
if err := w.convs.UpdateConvTurn(ctx, turn.ID, "", 0, 0, 0, core.ConvTurnStatusPending); err != nil {
continue
}
}
Expand All @@ -179,7 +179,7 @@ func (w *Worker) run(ctx context.Context) error {

func (w *Worker) UpdateConvTurnAsError(ctx context.Context, id uint64, errMsg string) error {
fmt.Printf("errMsg: %v, %d\n", errMsg, id)
if err := w.convs.UpdateConvTurn(ctx, id, "Something wrong happened", 0, core.ConvTurnStatusError); err != nil {
if err := w.convs.UpdateConvTurn(ctx, id, "Something wrong happened", 0, 0, 0, core.ConvTurnStatusError); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -214,7 +214,7 @@ func (w *Worker) subworker(ctx context.Context, id int) {
return
}

if err := w.convs.UpdateConvTurn(ctx, turn.ID, rr.respText, rr.totalTokens, core.ConvTurnStatusCompleted); err != nil {
if err := w.convs.UpdateConvTurn(ctx, turn.ID, rr.respText, rr.promptTokenCount, rr.completionTokenCount, rr.totalTokens, core.ConvTurnStatusCompleted); err != nil {
return
}

Expand Down

0 comments on commit a26f871

Please sign in to comment.