Skip to content

Commit

Permalink
🐛 fix: claude api restricts channel type (lobehub#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Dec 3, 2024
1 parent 8ed1a9f commit d9b9ae9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
6 changes: 6 additions & 0 deletions model/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func FilterChannelId(skipChannelIds []int) ChannelsFilterFunc {
}
}

func FilterChannelTypes(channelTypes []int) ChannelsFilterFunc {
return func(_ int, choice *ChannelChoice) bool {
return !utils.Contains(choice.Channel.Type, channelTypes)
}
}

func FilterOnlyChat() ChannelsFilterFunc {
return func(channelId int, choice *ChannelChoice) bool {
return choice.Channel.OnlyChat
Expand Down
3 changes: 3 additions & 0 deletions relay/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ import (
"github.com/gin-gonic/gin"
)

var AllowChannelType = []int{config.ChannelTypeAnthropic, config.ChannelTypeVertexAI}

func RelaycClaudeOnly(c *gin.Context) {
request := &claude.ClaudeRequest{}

if err := common.UnmarshalBodyReusable(c, request); err != nil {
common.AbortWithErr(c, http.StatusBadRequest, claude.ErrorToClaudeErr(err))
return
}
c.Set("allow_channel_type", AllowChannelType)

cacheProps := relay_util.NewChatCacheProps(c, true)
cacheProps.SetHash(request)
Expand Down
6 changes: 6 additions & 0 deletions relay/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, erro
filters = append(filters, model.FilterChannelId(skipChannelIds))
}

if types, exists := c.Get("allow_channel_type"); exists {
if allowTypes, ok := types.([]int); ok {
filters = append(filters, model.FilterChannelTypes(allowTypes))
}
}

channel, err := model.ChannelGroup.Next(group, modelName, filters...)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
Expand Down

0 comments on commit d9b9ae9

Please sign in to comment.