Skip to content

Commit

Permalink
#67: Fixed broken code after the first wave of refactoring/restructuring
Browse files Browse the repository at this point in the history
roma-glushko committed Jul 5, 2024
1 parent d842fa0 commit c924140
Showing 22 changed files with 286 additions and 263 deletions.
5 changes: 3 additions & 2 deletions config.dev.yaml
Original file line number Diff line number Diff line change
@@ -8,5 +8,6 @@ routers:
- id: default
models:
- id: openai
openai:
api_key: "${env:OPENAI_API_KEY}"
provider:
openai:
api_key: "${env:OPENAI_API_KEY}"
9 changes: 9 additions & 0 deletions config.sample.yaml
Original file line number Diff line number Diff line change
@@ -6,3 +6,12 @@ telemetry:
#api:
# http:
# ...

routers:
language:
- id: default
models:
- id: openai
provider:
openai:
api_key: "${env:OPENAI_API_KEY}"
17 changes: 14 additions & 3 deletions pkg/models/config.go
Original file line number Diff line number Diff line change
@@ -3,13 +3,16 @@ package models
import (
"fmt"

"github.com/EinStack/glide/pkg/providers"

"github.com/EinStack/glide/pkg/clients"
"github.com/EinStack/glide/pkg/resiliency/health"
"github.com/EinStack/glide/pkg/routers/latency"
"github.com/EinStack/glide/pkg/telemetry"
)

type Config[P any] struct {
// Config defines an extra configuration for a model wrapper around a provider
type Config[P providers.ProviderFactory] struct {
ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router)
Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled?
ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"`
@@ -20,7 +23,15 @@ type Config[P any] struct {
Provider P `yaml:"provider" json:"provider"`
}

func DefaultConfig[P any]() Config[P] {
func NewConfig[P providers.ProviderFactory](ID string) *Config[P] {
config := DefaultConfig[P]()

config.ID = ID

return &config
}

func DefaultConfig[P providers.ProviderFactory]() Config[P] {
return Config[P]{
Enabled: true,
Client: clients.DefaultClientConfig(),
@@ -30,7 +41,7 @@ func DefaultConfig[P any]() Config[P] {
}
}

func (c *Config) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) {
func (c *Config[P]) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) {
client, err := c.Provider.ToClient(tel, c.Client)
if err != nil {
return nil, fmt.Errorf("error initializing client: %w", err)
21 changes: 15 additions & 6 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
@@ -16,7 +16,16 @@ import (
"github.com/EinStack/glide/pkg/telemetry"
)

var ErrProviderNotFound = errors.New("provider not found")
// TODO: ProviderFactory should be more generic, not tied to LangProviders

var ErrNoProviderConfigured = errors.New("exactly one provider must be configured, none is configured")

type ProviderFactory interface {
ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error)
}

// TODO: LangProviders should be decoupled and
// represented as a registry where providers can add their factories dynamically

type LangProviders struct {
// Add other providers like
@@ -29,9 +38,11 @@ type LangProviders struct {
Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"`
}

var _ ProviderFactory = (*LangProviders)(nil)

// ToClient initializes the language model client based on the provided configuration.
// It takes a telemetry object as input and returns a LangModelProvider and an error.
func (c *LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) {
func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) {
switch {
case c.OpenAI != nil:
return openai.NewClient(c.OpenAI, clientConfig, tel)
@@ -83,7 +94,7 @@ func (c *LangProviders) validateOneProvider() error {

// check other providers here
if providersConfigured == 0 {
return fmt.Errorf("exactly one provider must be configured, none is configured")
return ErrNoProviderConfigured
}

if providersConfigured > 1 {
@@ -97,9 +108,7 @@ func (c *LangProviders) validateOneProvider() error {
}

func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error {
*c = DefaultConfig()

type plain LangModelConfig // to avoid recursion
type plain LangProviders // to avoid recursion

if err := unmarshal((*plain)(c)); err != nil {
return err
2 changes: 1 addition & 1 deletion pkg/providers/openai/embed.go
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import (
)

// Embed sends an embedding request to the specified OpenAI model.
func (c *Client) Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
func (c *Client) Embed(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) {
// TODO: implement
return nil, nil
}
3 changes: 3 additions & 0 deletions pkg/providers/provider.go
Original file line number Diff line number Diff line change
@@ -2,11 +2,14 @@ package providers

import (
"context"
"errors"

"github.com/EinStack/glide/pkg/api/schemas"
"github.com/EinStack/glide/pkg/clients"
)

var ErrProviderNotFound = errors.New("provider not found")

// ModelProvider exposes provider context
type ModelProvider interface {
Provider() string
6 changes: 2 additions & 4 deletions pkg/providers/testing/models.go
Original file line number Diff line number Diff line change
@@ -4,10 +4,8 @@ import (
"time"

"github.com/EinStack/glide/pkg/config/fields"

"github.com/EinStack/glide/pkg/models"
"github.com/EinStack/glide/pkg/routers/latency"

"github.com/EinStack/glide/pkg/providers"
)

// LangModelMock
@@ -55,6 +53,6 @@ func (m LangModelMock) Weight() int {
return m.weight
}

func ChatMockLatency(model providers.Model) *latency.MovingAverage {
func ChatMockLatency(model models.Model) *latency.MovingAverage {
return model.(LangModelMock).chatLatency
}
11 changes: 11 additions & 0 deletions pkg/routers/config.go
Original file line number Diff line number Diff line change
@@ -5,9 +5,20 @@ import (
"github.com/EinStack/glide/pkg/routers/routing"
)

// TODO: how to specify other backoff strategies?
// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738

type RouterConfig struct {
ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID
Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled?
Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to the router
RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request
}

func DefaultConfig() RouterConfig {
return RouterConfig{
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
}
}
3 changes: 1 addition & 2 deletions pkg/routers/embed/config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package embed

import (
"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/routers"
)

type EmbeddingRouterConfig struct {
routers.RouterConfig
Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests
// Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests
}
24 changes: 7 additions & 17 deletions pkg/routers/embed/router.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
package embed

import (
"context"

"github.com/EinStack/glide/pkg/api/schemas"
"github.com/EinStack/glide/pkg/resiliency/retry"
"github.com/EinStack/glide/pkg/routers/lang"
"github.com/EinStack/glide/pkg/telemetry"
"go.uber.org/zap"
)

type EmbeddingRouter struct {
routerID lang.RouterID
Config *LangRouterConfig
retry *retry.ExpRetry
tel *telemetry.Telemetry
logger *zap.Logger
// routerID lang.RouterID
// Config *LangRouterConfig
// retry *retry.ExpRetry
// tel *telemetry.Telemetry
// logger *zap.Logger
}

func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) {
}
//func (r *EmbeddingRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) {
//}
37 changes: 29 additions & 8 deletions pkg/routers/lang/config.go
Original file line number Diff line number Diff line change
@@ -15,12 +15,37 @@ import (
"go.uber.org/zap"
)

// TODO: how to specify other backoff strategies?
// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738
type (
ModelConfig = models.Config[providers.LangProviders]
ModelPoolConfig = []ModelConfig
)

// RouterConfig
type RouterConfig struct {
routers.RouterConfig
Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests
Models ModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests
}

type RouterConfigOption = func(*RouterConfig)

func WithModels(models ModelPoolConfig) RouterConfigOption {
return func(c *RouterConfig) {
c.Models = models
}
}

func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig {
config := &RouterConfig{
RouterConfig: routers.DefaultConfig(),
}

config.ID = RouterID

for _, o := range opt {
o(config)
}

return config
}

// BuildModels creates LanguageModel slice out of the given config
@@ -165,11 +190,7 @@ func (c *RouterConfig) BuildRouting(

func DefaultRouterConfig() *RouterConfig {
return &RouterConfig{
RouterConfig: routers.RouterConfig{
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
},
RouterConfig: routers.DefaultConfig(),
}
}

216 changes: 98 additions & 118 deletions pkg/routers/lang/config_test.go
Original file line number Diff line number Diff line change
@@ -4,70 +4,59 @@ import (
"testing"

"github.com/EinStack/glide/pkg/clients"
"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/providers/cohere"
"github.com/EinStack/glide/pkg/providers/openai"
"github.com/EinStack/glide/pkg/resiliency/health"
"github.com/EinStack/glide/pkg/resiliency/retry"
routers2 "github.com/EinStack/glide/pkg/routers"
"github.com/EinStack/glide/pkg/telemetry"

"github.com/EinStack/glide/pkg/routers/routing"

"github.com/EinStack/glide/pkg/routers/latency"

"github.com/EinStack/glide/pkg/providers/openai"

"github.com/EinStack/glide/pkg/providers"

"github.com/EinStack/glide/pkg/routers/routing"
"github.com/EinStack/glide/pkg/telemetry"
"github.com/stretchr/testify/require"
)

func TestRouterConfig_BuildModels(t *testing.T) {
defaultParams := openai.DefaultParams()

cfg := routers2.Config{
LanguageRouters: []RouterConfig{
{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
cfg := RoutersConfig{
*NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
},
},
{
ID: "second_router",
Enabled: true,
RoutingStrategy: routing.LeastLatency,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
}),
),
*NewRouterConfig(
"second_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
},
},
},
}),
),
}

routers, err := cfg.BuildLangRouters(telemetry.NewTelemetryMock())
routers, err := cfg.Build(telemetry.NewTelemetryMock())

require.NoError(t, err)
require.Len(t, routers, 2)
@@ -82,21 +71,20 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) {
openAIParams := openai.DefaultParams()
cohereParams := cohere.DefaultParams()

cfg := LangRouterConfig{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
cfg := NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &openAIParams,
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &openAIParams,
},
},
},
{
@@ -105,13 +93,15 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Cohere: &cohere.Config{
APIKey: "ABC",
DefaultParams: &cohereParams,
Provider: providers.LangProviders{
Cohere: &cohere.Config{
APIKey: "ABC",
DefaultParams: &cohereParams,
},
},
},
},
}
}),
)

chatModels, streamChatModels, err := cfg.BuildModels(tel)

@@ -125,108 +115,98 @@ func TestRouterConfig_InvalidSetups(t *testing.T) {

tests := []struct {
name string
config routers2.Config
config RoutersConfig
}{
{
"duplicated router IDs",
routers2.Config{
LanguageRouters: []LangRouterConfig{
{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
RoutersConfig{
*NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
},
},
{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.LeastLatency,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
}),
),
*NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
},
},
},
}),
),
},
},
{
"duplicated model IDs",
routers2.Config{
LanguageRouters: []LangRouterConfig{
{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
RoutersConfig{
*NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
},
{
ID: "first_model",
Enabled: true,
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
},
},
},
},
}),
),
},
},
{
"no models",
routers2.Config{
LanguageRouters: []LangRouterConfig{
{
ID: "first_router",
Enabled: true,
RoutingStrategy: routing.Priority,
Retry: retry.DefaultExpRetryConfig(),
Models: []providers.LangModelConfig{},
},
},
RoutersConfig{
*NewRouterConfig(
"first_router",
WithModels(ModelPoolConfig{}),
),
},
},
}

for _, test := range tests {
_, err := test.config.BuildLangRouters(telemetry.NewTelemetryMock())
_, err := test.config.Build(telemetry.NewTelemetryMock())

require.Error(t, err)
}
87 changes: 39 additions & 48 deletions pkg/routers/lang/router_test.go
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@ import (
"github.com/EinStack/glide/pkg/resiliency/retry"

"github.com/EinStack/glide/pkg/api/schemas"
"github.com/EinStack/glide/pkg/providers"
ptesting "github.com/EinStack/glide/pkg/providers/testing"
"github.com/EinStack/glide/pkg/routers/latency"
"github.com/EinStack/glide/pkg/routers/routing"
@@ -41,16 +40,15 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
@@ -95,19 +93,18 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

expectedModels := []string{"third", "third"}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
chatRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatStreamRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
@@ -146,17 +143,16 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
chatRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatStreamRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
@@ -190,19 +186,18 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(modelPool),
tel: telemetry.NewTelemetryMock(),
logger: telemetry.NewLoggerMock(),
}
@@ -236,19 +231,18 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(modelPool),
tel: telemetry.NewTelemetryMock(),
logger: telemetry.NewLoggerMock(),
}
@@ -293,18 +287,17 @@ func TestLangRouter_ChatStream(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_stream_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(modelPool),
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
logger: telemetry.NewLoggerMock(),
@@ -363,18 +356,17 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_stream_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(modelPool),
chatStreamModels: langModels,
tel: telemetry.NewTelemetryMock(),
logger: telemetry.NewLoggerMock(),
@@ -433,19 +425,18 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
),
}

models := make([]providers.Model, 0, len(langModels))
modelPool := make([]models.Model, 0, len(langModels))
for _, model := range langModels {
models = append(models, model)
modelPool = append(modelPool, model)
}

router := LangRouter{
router := Router{
routerID: "test_router",
Config: &LangRouterConfig{},
retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
chatRouting: routing.NewPriority(models),
chatRouting: routing.NewPriority(modelPool),
chatModels: langModels,
chatStreamModels: langModels,
chatStreamRouting: routing.NewPriority(models),
chatStreamRouting: routing.NewPriority(modelPool),
tel: telemetry.NewTelemetryMock(),
logger: telemetry.NewLoggerMock(),
}
14 changes: 7 additions & 7 deletions pkg/routers/routing/least_latency.go
Original file line number Diff line number Diff line change
@@ -5,26 +5,26 @@ import (
"sync/atomic"
"time"

"github.com/EinStack/glide/pkg/routers/latency"
"github.com/EinStack/glide/pkg/models"

"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/routers/latency"
)

const (
LeastLatency Strategy = "least_latency"
)

// LatencyGetter defines where to find latency for the specific model action
type LatencyGetter = func(model providers.Model) *latency.MovingAverage
type LatencyGetter = func(model models.Model) *latency.MovingAverage

// ModelSchedule defines latency update schedule for models
type ModelSchedule struct {
mu sync.RWMutex
model providers.Model
model models.Model
expireAt time.Time
}

func NewSchedule(model providers.Model) *ModelSchedule {
func NewSchedule(model models.Model) *ModelSchedule {
schedule := &ModelSchedule{
model: model,
}
@@ -67,7 +67,7 @@ type LeastLatencyRouting struct {
schedules []*ModelSchedule
}

func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting {
func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []models.Model) *LeastLatencyRouting {
schedules := make([]*ModelSchedule, 0, len(models))

for _, model := range models {
@@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator {
// other model latencies that might have improved over time).
// For that, we introduced expiration time after which the model receives a request
// even if it was not the fastest to respond
func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop
func (r *LeastLatencyRouting) Next() (models.Model, error) { //nolint:cyclop
coldSchedules := r.getColdModelSchedules()

if len(coldSchedules) > 0 {
16 changes: 8 additions & 8 deletions pkg/routers/routing/least_latency_test.go
Original file line number Diff line number Diff line change
@@ -5,9 +5,9 @@ import (
"testing"
"time"

ptesting "github.com/EinStack/glide/pkg/providers/testing"
"github.com/EinStack/glide/pkg/models"

"github.com/EinStack/glide/pkg/providers"
ptesting "github.com/EinStack/glide/pkg/providers/testing"

"github.com/stretchr/testify/require"
)
@@ -33,13 +33,13 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
modelPool := make([]models.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
}

routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models)
routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, modelPool)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return back to the beginning of the list
@@ -144,13 +144,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) {

for name, latencies := range tests {
t.Run(name, func(t *testing.T) {
models := make([]providers.Model, 0, len(latencies))
modelPool := make([]models.Model, 0, len(latencies))

for idx, latency := range latencies {
models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
modelPool = append(modelPool, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
}

routing := NewLeastLatencyRouting(models.ChatLatency, models)
routing := NewLeastLatencyRouting(models.ChatLatency, modelPool)
iterator := routing.Iterator()

_, err := iterator.Next()
16 changes: 8 additions & 8 deletions pkg/routers/routing/priority.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ package routing
import (
"sync/atomic"

"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/models"
)

const (
@@ -15,10 +15,10 @@ const (
// Priority of models are defined as position of the model on the list
// (e.g. the first model definition has the highest priority, then the second model definition and so on)
type PriorityRouting struct {
models []providers.Model
models []models.Model
}

func NewPriority(models []providers.Model) *PriorityRouting {
func NewPriority(models []models.Model) *PriorityRouting {
return &PriorityRouting{
models: models,
}
@@ -35,14 +35,14 @@ func (r *PriorityRouting) Iterator() LangModelIterator {

type PriorityIterator struct {
idx *atomic.Uint64
models []providers.Model
models []models.Model
}

func (r PriorityIterator) Next() (providers.Model, error) {
models := r.models
func (r PriorityIterator) Next() (models.Model, error) {
modelPool := r.models

for idx := int(r.idx.Load()); idx < len(models); idx = int(r.idx.Add(1)) {
model := models[idx]
for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) {
model := modelPool[idx]

if !model.Healthy() {
continue
14 changes: 7 additions & 7 deletions pkg/routers/routing/priority_test.go
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@ package routing
import (
"testing"

ptesting "github.com/EinStack/glide/pkg/providers/testing"
"github.com/EinStack/glide/pkg/models"

"github.com/EinStack/glide/pkg/providers"
ptesting "github.com/EinStack/glide/pkg/providers/testing"

"github.com/stretchr/testify/require"
)
@@ -29,13 +29,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
modelPool := make([]models.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}

routing := NewPriority(models)
routing := NewPriority(modelPool)
iterator := routing.Iterator()

// loop three times over the whole pool to check if we return back to the beginning of the list
@@ -49,13 +49,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
}

func TestPriorityRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
modelPool := []models.Model{
ptesting.NewLangModelMock("first", false, 0, 1),
ptesting.NewLangModelMock("second", false, 0, 1),
ptesting.NewLangModelMock("third", false, 0, 1),
}

routing := NewPriority(models)
routing := NewPriority(modelPool)
iterator := routing.Iterator()

_, err := iterator.Next()
8 changes: 4 additions & 4 deletions pkg/routers/routing/round_robin.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ package routing
import (
"sync/atomic"

"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/models"
)

const (
@@ -13,10 +13,10 @@ const (
// RoundRobinRouting routes request to the next model in the list in cycle
type RoundRobinRouting struct {
idx atomic.Uint64
models []providers.Model
models []models.Model
}

func NewRoundRobinRouting(models []providers.Model) *RoundRobinRouting {
func NewRoundRobinRouting(models []models.Model) *RoundRobinRouting {
return &RoundRobinRouting{
models: models,
}
@@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator {
return r
}

func (r *RoundRobinRouting) Next() (providers.Model, error) {
func (r *RoundRobinRouting) Next() (models.Model, error) {
modelLen := len(r.models)

// in order to avoid infinite loop in case of no healthy model is available,
14 changes: 7 additions & 7 deletions pkg/routers/routing/round_robin_test.go
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@ package routing
import (
"testing"

ptesting "github.com/EinStack/glide/pkg/providers/testing"
"github.com/EinStack/glide/pkg/models"

"github.com/EinStack/glide/pkg/providers"
ptesting "github.com/EinStack/glide/pkg/providers/testing"

"github.com/stretchr/testify/require"
)
@@ -30,13 +30,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
modelPool := make([]models.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}

routing := NewRoundRobinRouting(models)
routing := NewRoundRobinRouting(modelPool)
iterator := routing.Iterator()

for i := 0; i < 3; i++ {
@@ -52,13 +52,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
}

func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
modelPool := []models.Model{
ptesting.NewLangModelMock("first", false, 0, 1),
ptesting.NewLangModelMock("second", false, 0, 1),
ptesting.NewLangModelMock("third", false, 0, 1),
}

routing := NewRoundRobinRouting(models)
routing := NewRoundRobinRouting(modelPool)
iterator := routing.Iterator()

_, err := iterator.Next()
4 changes: 2 additions & 2 deletions pkg/routers/routing/strategies.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ package routing
import (
"errors"

"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/models"
)

var ErrNoHealthyModels = errors.New("no healthy models found")
@@ -16,5 +16,5 @@ type LangModelRouting interface {
}

type LangModelIterator interface {
Next() (providers.Model, error)
Next() (models.Model, error)
}
8 changes: 4 additions & 4 deletions pkg/routers/routing/weighted_round_robin.go
Original file line number Diff line number Diff line change
@@ -3,15 +3,15 @@ package routing
import (
"sync"

"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/models"
)

const (
WeightedRoundRobin Strategy = "weighted_round_robin"
)

type Weighter struct {
model providers.Model
model models.Model
currentWeight int
}

@@ -36,7 +36,7 @@ type WRoundRobinRouting struct {
weights []*Weighter
}

func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting {
func NewWeightedRoundRobin(models []models.Model) *WRoundRobinRouting {
weights := make([]*Weighter, 0, len(models))

for _, model := range models {
@@ -55,7 +55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator {
return r
}

func (r *WRoundRobinRouting) Next() (providers.Model, error) {
func (r *WRoundRobinRouting) Next() (models.Model, error) {
r.mu.Lock()
defer r.mu.Unlock()

14 changes: 7 additions & 7 deletions pkg/routers/routing/weighted_round_robin_test.go
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@ package routing
import (
"testing"

ptesting "github.com/EinStack/glide/pkg/providers/testing"
"github.com/EinStack/glide/pkg/models"

"github.com/EinStack/glide/pkg/providers"
ptesting "github.com/EinStack/glide/pkg/providers/testing"

"github.com/stretchr/testify/require"
)
@@ -116,13 +116,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
modelPool := make([]models.Model, 0, len(tc.models))

for _, model := range tc.models {
models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
}

routing := NewWeightedRoundRobin(models)
routing := NewWeightedRoundRobin(modelPool)
iterator := routing.Iterator()

actualDistribution := make(map[string]int, len(tc.models))
@@ -142,13 +142,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {
}

func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
modelPool := []models.Model{
ptesting.NewLangModelMock("first", false, 0, 1),
ptesting.NewLangModelMock("second", false, 0, 2),
ptesting.NewLangModelMock("third", false, 0, 3),
}

routing := NewWeightedRoundRobin(models)
routing := NewWeightedRoundRobin(modelPool)
iterator := routing.Iterator()

_, err := iterator.Next()

0 comments on commit c924140

Please sign in to comment.