From c924140ec6b1a4169104fd1396e695ded2370cc2 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Fri, 5 Jul 2024 15:16:43 +0300 Subject: [PATCH] #67: Fixed broken code after the first wave of refactoring/restructuring --- config.dev.yaml | 5 +- config.sample.yaml | 9 + pkg/models/config.go | 17 +- pkg/providers/config.go | 21 +- pkg/providers/openai/embed.go | 2 +- pkg/providers/provider.go | 3 + pkg/providers/testing/models.go | 6 +- pkg/routers/config.go | 11 + pkg/routers/embed/config.go | 3 +- pkg/routers/embed/router.go | 24 +- pkg/routers/lang/config.go | 37 ++- pkg/routers/lang/config_test.go | 216 ++++++++---------- pkg/routers/lang/router_test.go | 87 ++++--- pkg/routers/routing/least_latency.go | 14 +- pkg/routers/routing/least_latency_test.go | 16 +- pkg/routers/routing/priority.go | 16 +- pkg/routers/routing/priority_test.go | 14 +- pkg/routers/routing/round_robin.go | 8 +- pkg/routers/routing/round_robin_test.go | 14 +- pkg/routers/routing/strategies.go | 4 +- pkg/routers/routing/weighted_round_robin.go | 8 +- .../routing/weighted_round_robin_test.go | 14 +- 22 files changed, 286 insertions(+), 263 deletions(-) diff --git a/config.dev.yaml b/config.dev.yaml index 80c77a5b..8bd2af4a 100644 --- a/config.dev.yaml +++ b/config.dev.yaml @@ -8,5 +8,6 @@ routers: - id: default models: - id: openai - openai: - api_key: "${env:OPENAI_API_KEY}" + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/config.sample.yaml b/config.sample.yaml index 3ce72055..7118a6f5 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -6,3 +6,12 @@ telemetry: #api: # http: # ... + +routers: + language: + - id: default + models: + - id: openai + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/pkg/models/config.go b/pkg/models/config.go index 5289ea03..97c93443 100644 --- a/pkg/models/config.go +++ b/pkg/models/config.go @@ -3,13 +3,16 @@ package models import ( "fmt" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/telemetry" ) -type Config[P any] struct { +// Config defines an extra configuration for a model wrapper around a provider +type Config[P providers.ProviderFactory] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -20,7 +23,15 @@ type Config[P any] struct { Provider P `yaml:"provider" json:"provider"` } -func DefaultConfig[P any]() Config[P] { +func NewConfig[P providers.ProviderFactory](ID string) *Config[P] { + config := DefaultConfig[P]() + + config.ID = ID + + return &config +} + +func DefaultConfig[P providers.ProviderFactory]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), @@ -30,7 +41,7 @@ func DefaultConfig[P any]() Config[P] { } } -func (c *Config) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { +func (c *Config[P]) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { client, err := c.Provider.ToClient(tel, c.Client) if err != nil { return nil, fmt.Errorf("error initializing client: %w", err) diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 8e1c66c2..d4145611 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -16,7 +16,16 @@ import ( "github.com/EinStack/glide/pkg/telemetry" ) -var ErrProviderNotFound = errors.New("provider not found") +// TODO: ProviderFactory should be more generic, not tied to LangProviders + +var ErrNoProviderConfigured = errors.New("exactly one provider must be configured, none is configured") + +type ProviderFactory interface { + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} + +// TODO: LangProviders should be decoupled and +// represented as a registry where providers can add their factories dynamically type LangProviders struct { // Add other providers like @@ -29,9 +38,11 @@ type LangProviders struct { Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` } +var _ ProviderFactory = (*LangProviders)(nil) + // ToClient initializes the language model client based on the provided configuration. // It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c *LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { +func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { switch { case c.OpenAI != nil: return openai.NewClient(c.OpenAI, clientConfig, tel) @@ -83,7 +94,7 @@ func (c *LangProviders) validateOneProvider() error { // check other providers here if providersConfigured == 0 { - return fmt.Errorf("exactly one provider must be configured, none is configured") + return ErrNoProviderConfigured } if providersConfigured > 1 { @@ -97,9 +108,7 @@ func (c *LangProviders) validateOneProvider() error { } func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = DefaultConfig() - - type plain LangModelConfig // to avoid recursion + type plain LangProviders // to avoid recursion if err := unmarshal((*plain)(c)); err != nil { return err diff --git a/pkg/providers/openai/embed.go b/pkg/providers/openai/embed.go index 48e69328..69f9aa27 100644 --- a/pkg/providers/openai/embed.go +++ b/pkg/providers/openai/embed.go @@ -7,7 +7,7 @@ import ( ) // Embed sends an embedding request to the specified OpenAI model. -func (c *Client) Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Embed(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { // TODO: implement return nil, nil } diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 2341ddc9..7fdefc25 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -2,11 +2,14 @@ package providers import ( "context" + "errors" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/clients" ) +var ErrProviderNotFound = errors.New("provider not found") + // ModelProvider exposes provider context type ModelProvider interface { Provider() string diff --git a/pkg/providers/testing/models.go b/pkg/providers/testing/models.go index d4ac3840..57500d21 100644 --- a/pkg/providers/testing/models.go +++ b/pkg/providers/testing/models.go @@ -4,10 +4,8 @@ import ( "time" "github.com/EinStack/glide/pkg/config/fields" - + "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers" ) // LangModelMock @@ -55,6 +53,6 @@ func (m LangModelMock) Weight() int { return m.weight } -func ChatMockLatency(model providers.Model) *latency.MovingAverage { +func ChatMockLatency(model models.Model) *latency.MovingAverage { return model.(LangModelMock).chatLatency } diff --git a/pkg/routers/config.go b/pkg/routers/config.go index 99cef09f..a3c8f69a 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -5,9 +5,20 @@ import ( "github.com/EinStack/glide/pkg/routers/routing" ) +// TODO: how to specify other backoff strategies? +// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 + type RouterConfig struct { ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request } + +func DefaultConfig() RouterConfig { + return RouterConfig{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + } +} diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go index 63894e82..49f4821b 100644 --- a/pkg/routers/embed/config.go +++ b/pkg/routers/embed/config.go @@ -1,11 +1,10 @@ package embed import ( - "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/routers" ) type EmbeddingRouterConfig struct { routers.RouterConfig - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + // Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed/router.go index 501f63e6..9068537b 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed/router.go @@ -1,22 +1,12 @@ package embed -import ( - "context" - - "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/routers/lang" - "github.com/EinStack/glide/pkg/telemetry" - "go.uber.org/zap" -) - type EmbeddingRouter struct { - routerID lang.RouterID - Config *LangRouterConfig - retry *retry.ExpRetry - tel *telemetry.Telemetry - logger *zap.Logger + // routerID lang.RouterID + // Config *LangRouterConfig + // retry *retry.ExpRetry + // tel *telemetry.Telemetry + // logger *zap.Logger } -func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { -} +//func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { +//} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index f68d1c8b..0be9ef06 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -15,12 +15,37 @@ import ( "go.uber.org/zap" ) -// TODO: how to specify other backoff strategies? -// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 +type ( + ModelConfig = models.Config[providers.LangProviders] + ModelPoolConfig = []ModelConfig +) + // RouterConfig type RouterConfig struct { routers.RouterConfig - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + Models ModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} + +type RouterConfigOption = func(*RouterConfig) + +func WithModels(models ModelPoolConfig) RouterConfigOption { + return func(c *RouterConfig) { + c.Models = models + } +} + +func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { + config := &RouterConfig{ + RouterConfig: routers.DefaultConfig(), + } + + config.ID = RouterID + + for _, o := range opt { + o(config) + } + + return config } // BuildModels creates LanguageModel slice out of the given config @@ -165,11 +190,7 @@ func (c *RouterConfig) BuildRouting( func DefaultRouterConfig() *RouterConfig { return &RouterConfig{ - RouterConfig: routers.RouterConfig{ - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - }, + RouterConfig: routers.DefaultConfig(), } } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 79cbb210..975cdcb6 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -4,70 +4,59 @@ import ( "testing" "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/providers/cohere" + "github.com/EinStack/glide/pkg/providers/openai" "github.com/EinStack/glide/pkg/resiliency/health" - "github.com/EinStack/glide/pkg/resiliency/retry" - routers2 "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" - "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers/openai" - - "github.com/EinStack/glide/pkg/providers" - + "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() - cfg := routers2.Config{ - LanguageRouters: []RouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + cfg := RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - { - ID: "second_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }), + ), + *NewRouterConfig( + "second_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), } - routers, err := cfg.BuildLangRouters(telemetry.NewTelemetryMock()) + routers, err := cfg.Build(telemetry.NewTelemetryMock()) require.NoError(t, err) require.Len(t, routers, 2) @@ -82,21 +71,20 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { openAIParams := openai.DefaultParams() cohereParams := cohere.DefaultParams() - cfg := LangRouterConfig{ - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ + cfg := NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ { ID: "first_model", Enabled: true, Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &openAIParams, + Provider: providers.LangProviders{ + OpenAI: &openai.Config{ + APIKey: "ABC", + DefaultParams: &openAIParams, + }, }, }, { @@ -105,13 +93,15 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Cohere: &cohere.Config{ - APIKey: "ABC", - DefaultParams: &cohereParams, + Provider: providers.LangProviders{ + Cohere: &cohere.Config{ + APIKey: "ABC", + DefaultParams: &cohereParams, + }, }, }, - }, - } + }), + ) chatModels, streamChatModels, err := cfg.BuildModels(tel) @@ -125,108 +115,98 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { tests := []struct { name string - config routers2.Config + config RoutersConfig }{ { "duplicated router IDs", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }), + ), + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), }, }, { "duplicated model IDs", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), + }, + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: providers.LangProviders{ OpenAI: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, }, }, - }, - }, + }), + ), }, }, { "no models", - routers2.Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{}, - }, - }, + RoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(ModelPoolConfig{}), + ), }, }, } for _, test := range tests { - _, err := test.config.BuildLangRouters(telemetry.NewTelemetryMock()) + _, err := test.config.Build(telemetry.NewTelemetryMock()) require.Error(t, err) } diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 41515958..087e2a71 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -12,7 +12,6 @@ import ( "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/providers" ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/routers/routing" @@ -41,16 +40,15 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -95,19 +93,18 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } expectedModels := []string{"third", "third"} - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), - chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -146,17 +143,16 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), - chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -190,19 +186,18 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } @@ -236,19 +231,18 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } @@ -293,18 +287,17 @@ func TestLangRouter_ChatStream(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), @@ -363,18 +356,17 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), @@ -433,19 +425,18 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]models.Model, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } - router := LangRouter{ + router := Router{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go index 015c044e..e6c56a6f 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/routers/routing/least_latency.go @@ -5,9 +5,9 @@ import ( "sync/atomic" "time" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/routers/latency" ) const ( @@ -15,16 +15,16 @@ const ( ) // LatencyGetter defines where to find latency for the specific model action -type LatencyGetter = func(model providers.Model) *latency.MovingAverage +type LatencyGetter = func(model models.Model) *latency.MovingAverage // ModelSchedule defines latency update schedule for models type ModelSchedule struct { mu sync.RWMutex - model providers.Model + model models.Model expireAt time.Time } -func NewSchedule(model providers.Model) *ModelSchedule { +func NewSchedule(model models.Model) *ModelSchedule { schedule := &ModelSchedule{ model: model, } @@ -67,7 +67,7 @@ type LeastLatencyRouting struct { schedules []*ModelSchedule } -func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting { +func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []models.Model) *LeastLatencyRouting { schedules := make([]*ModelSchedule, 0, len(models)) for _, model := range models { @@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator { // other model latencies that might have improved over time). // For that, we introduced expiration time after which the model receives a request // even if it was not the fastest to respond -func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop +func (r *LeastLatencyRouting) Next() (models.Model, error) { //nolint:cyclop coldSchedules := r.getColdModelSchedules() if len(coldSchedules) > 0 { diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index d65ee0d2..523b0790 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -33,13 +33,13 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) } - routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models) + routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -144,13 +144,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { for name, latencies := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(latencies)) + modelPool := make([]models.Model, 0, len(latencies)) for idx, latency := range latencies { - models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(models.ChatLatency, models) + routing := NewLeastLatencyRouting(models.ChatLatency, modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/priority.go b/pkg/routers/routing/priority.go index f895458c..04d4d94e 100644 --- a/pkg/routers/routing/priority.go +++ b/pkg/routers/routing/priority.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -15,10 +15,10 @@ const ( // Priority of models are defined as position of the model on the list // (e.g. the first model definition has the highest priority, then the second model definition and so on) type PriorityRouting struct { - models []providers.Model + models []models.Model } -func NewPriority(models []providers.Model) *PriorityRouting { +func NewPriority(models []models.Model) *PriorityRouting { return &PriorityRouting{ models: models, } @@ -35,14 +35,14 @@ func (r *PriorityRouting) Iterator() LangModelIterator { type PriorityIterator struct { idx *atomic.Uint64 - models []providers.Model + models []models.Model } -func (r PriorityIterator) Next() (providers.Model, error) { - models := r.models +func (r PriorityIterator) Next() (models.Model, error) { + modelPool := r.models - for idx := int(r.idx.Load()); idx < len(models); idx = int(r.idx.Add(1)) { - model := models[idx] + for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) { + model := modelPool[idx] if !model.Healthy() { continue diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index cee98c60..eb090c76 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -29,13 +29,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -49,13 +49,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { } func TestPriorityRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 1), ptesting.NewLangModelMock("third", false, 0, 1), } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/round_robin.go b/pkg/routers/routing/round_robin.go index e5a0f927..abd2ff96 100644 --- a/pkg/routers/routing/round_robin.go +++ b/pkg/routers/routing/round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -13,10 +13,10 @@ const ( // RoundRobinRouting routes request to the next model in the list in cycle type RoundRobinRouting struct { idx atomic.Uint64 - models []providers.Model + models []models.Model } -func NewRoundRobinRouting(models []providers.Model) *RoundRobinRouting { +func NewRoundRobinRouting(models []models.Model) *RoundRobinRouting { return &RoundRobinRouting{ models: models, } @@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *RoundRobinRouting) Next() (providers.Model, error) { +func (r *RoundRobinRouting) Next() (models.Model, error) { modelLen := len(r.models) // in order to avoid infinite loop in case of no healthy model is available, diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index fc34ec48..2a6e579b 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -30,13 +30,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() for i := 0; i < 3; i++ { @@ -52,13 +52,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { } func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 1), ptesting.NewLangModelMock("third", false, 0, 1), } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/strategies.go b/pkg/routers/routing/strategies.go index 56f03676..960702a4 100644 --- a/pkg/routers/routing/strategies.go +++ b/pkg/routers/routing/strategies.go @@ -3,7 +3,7 @@ package routing import ( "errors" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) var ErrNoHealthyModels = errors.New("no healthy models found") @@ -16,5 +16,5 @@ type LangModelRouting interface { } type LangModelIterator interface { - Next() (providers.Model, error) + Next() (models.Model, error) } diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/routers/routing/weighted_round_robin.go index 2e028408..dfbee414 100644 --- a/pkg/routers/routing/weighted_round_robin.go +++ b/pkg/routers/routing/weighted_round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/models" ) const ( @@ -11,7 +11,7 @@ const ( ) type Weighter struct { - model providers.Model + model models.Model currentWeight int } @@ -36,7 +36,7 @@ type WRoundRobinRouting struct { weights []*Weighter } -func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting { +func NewWeightedRoundRobin(models []models.Model) *WRoundRobinRouting { weights := make([]*Weighter, 0, len(models)) for _, model := range models { @@ -55,7 +55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *WRoundRobinRouting) Next() (providers.Model, error) { +func (r *WRoundRobinRouting) Next() (models.Model, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index f4b59bb3..8e4a9ee2 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -3,9 +3,9 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/models" - "github.com/EinStack/glide/pkg/providers" + ptesting "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" ) @@ -116,13 +116,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]models.Model, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() actualDistribution := make(map[string]int, len(tc.models)) @@ -142,13 +142,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { } func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ + modelPool := []models.Model{ ptesting.NewLangModelMock("first", false, 0, 1), ptesting.NewLangModelMock("second", false, 0, 2), ptesting.NewLangModelMock("third", false, 0, 3), } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() _, err := iterator.Next()