Skip to content

Commit ae2e08d

Browse files
committed
feat: add optional parameters to OpenAIEmbeddingModel for enhanced embedding functionality
1 parent 6b23469 commit ae2e08d

File tree

5 files changed

+64
-51
lines changed

5 files changed

+64
-51
lines changed

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,27 @@
1313

1414
from common import forms
1515
from common.exception.app_exception import AppApiException
16-
from common.forms import BaseForm
16+
from common.forms import BaseForm, TooltipLabel
1717
from models_provider.base_model_provider import BaseModelCredential, ValidCode
1818
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
1919

20+
class BaiLianEmbeddingModelParams(BaseForm):
21+
dimensions = forms.SingleSelect(
22+
TooltipLabel(
23+
_('Dimensions'),
24+
_('')
25+
),
26+
required=True,
27+
default_value=1024,
28+
value_field='value',
29+
text_field='label',
30+
option_list=[
31+
{'label': '1024', 'value': '1024'},
32+
{'label': '768', 'value': '768'},
33+
{'label': '512', 'value': '512'},
34+
]
35+
)
36+
2037

2138
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
2239

@@ -71,4 +88,8 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
7188
api_key = model.get('dashscope_api_key', '')
7289
return {**model, 'dashscope_api_key': super().encryption(api_key)}
7390

91+
92+
def get_model_params_setting_form(self, model_name):
93+
return BaiLianEmbeddingModelParams()
94+
7495
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,43 @@
66
@date:2024/10/16 16:34
77
@desc:
88
"""
9-
from functools import reduce
109
from typing import Dict, List
1110

12-
from langchain_community.embeddings import DashScopeEmbeddings
13-
from langchain_community.embeddings.dashscope import embed_with_retry
11+
from openai import OpenAI
1412

1513
from models_provider.base_model_provider import MaxKBBaseModel
1614

1715

18-
def proxy_embed_documents(texts: List[str], step_size, embed_documents):
19-
value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
20-
range(0, len(texts), step_size)]
21-
return reduce(lambda x, y: [*x, *y], value, [])
16+
class AliyunBaiLianEmbedding(MaxKBBaseModel):
17+
model_name: str
18+
optional_params: dict
2219

20+
def __init__(self, api_key, model_name: str, optional_params: dict):
21+
self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
22+
self.model_name = model_name
23+
self.optional_params = optional_params
2324

24-
class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
2525
@staticmethod
2626
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
27+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
2728
return AliyunBaiLianEmbedding(
28-
model=model_name,
29-
dashscope_api_key=model_credential.get('dashscope_api_key')
29+
api_key=model_credential.get('dashscope_api_key'),
30+
model_name=model_name,
31+
optional_params=optional_params
3032
)
3133

32-
def embed_documents(self, texts: List[str]) -> List[List[float]]:
33-
if self.model == 'text-embedding-v3':
34-
return proxy_embed_documents(texts, 6, self._embed_documents)
35-
return self._embed_documents(texts)
36-
37-
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
38-
"""Call out to DashScope's embedding endpoint for embedding search docs.
39-
40-
Args:
41-
texts: The list of texts to embed.
42-
chunk_size: The chunk size of embeddings. If None, will use the chunk size
43-
specified by the class.
44-
45-
Returns:
46-
List of embeddings, one for each text.
47-
"""
48-
embeddings = embed_with_retry(
49-
self, input=texts, text_type="document", model=self.model
50-
)
51-
embedding_list = [item["embedding"] for item in embeddings]
52-
return embedding_list
53-
54-
def embed_query(self, text: str) -> List[float]:
55-
"""Call out to DashScope's embedding endpoint for embedding query text.
56-
57-
Args:
58-
text: The text to embed.
59-
60-
Returns:
61-
Embedding for the text.
62-
"""
63-
embedding = embed_with_retry(
64-
self, input=[text], text_type="document", model=self.model
65-
)[0]["embedding"]
66-
return embedding
34+
def embed_query(self, text: str):
35+
res = self.embed_documents([text])
36+
return res[0]
37+
38+
def embed_documents(
39+
self, texts: List[str], chunk_size: int | None = None
40+
) -> List[List[float]]:
41+
if len(self.optional_params) > 0:
42+
res = self.client.create(
43+
input=texts, model=self.model_name, encoding_format="float",
44+
**self.optional_params
45+
)
46+
else:
47+
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
48+
return [e.embedding for e in res.data]

apps/models_provider/impl/openai_model_provider/model/embedding.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,21 @@
1515

1616
class OpenAIEmbeddingModel(MaxKBBaseModel):
1717
model_name: str
18+
optional_params: dict
1819

19-
def __init__(self, api_key, base_url, model_name: str):
20+
def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
2021
self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
2122
self.model_name = model_name
23+
self.optional_params = optional_params
2224

2325
@staticmethod
2426
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
27+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
2528
return OpenAIEmbeddingModel(
2629
api_key=model_credential.get('api_key'),
2730
model_name=model_name,
2831
base_url=model_credential.get('api_base'),
32+
optional_params=optional_params
2933
)
3034

3135
def embed_query(self, text: str):
@@ -35,5 +39,11 @@ def embed_query(self, text: str):
3539
def embed_documents(
3640
self, texts: List[str], chunk_size: int | None = None
3741
) -> List[List[float]]:
38-
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
42+
if len(self.optional_params) > 0:
43+
res = self.client.create(
44+
input=texts, model=self.model_name, encoding_format="float",
45+
**self.optional_params
46+
)
47+
else:
48+
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
3949
return [e.embedding for e in res.data]

ui/src/views/model/component/CreateModelDialog.vue

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@
140140
/>
141141
<el-empty
142142
v-else-if="
143-
base_form_data.model_type === 'RERANKER' ||
144-
base_form_data.model_type === 'EMBEDDING'
143+
base_form_data.model_type === 'RERANKER'
145144
"
146145
:description="$t('views.model.tip.emptyMessage2')"
147146
/>
@@ -150,7 +149,7 @@
150149
<el-button
151150
type="text"
152151
@click.stop="openAddDrawer()"
153-
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT'].includes(base_form_data.model_type)"
152+
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT', 'EMBEDDING'].includes(base_form_data.model_type)"
154153
>
155154
<AppIcon iconName="app-add-outlined" class="mr-4"/> {{ $t('common.add') }}
156155
</el-button>

ui/src/views/model/component/ModelCard.vue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
currentModel.model_type === 'IMAGE' ||
9696
currentModel.model_type === 'TTI' ||
9797
currentModel.model_type === 'ITV' ||
98+
currentModel.model_type === 'EMBEDDING' ||
9899
currentModel.model_type === 'TTV') &&
99100
permissionPrecise.paramSetting(model.id)
100101
"

0 commit comments

Comments
 (0)