1515
1616class OpenAIEmbeddingModel (MaxKBBaseModel ):
1717 model_name : str
18+ optional_params : dict
1819
19- def __init__ (self , api_key , base_url , model_name : str ):
20+ def __init__ (self , api_key , base_url , model_name : str , optional_params : dict ):
2021 self .client = openai .OpenAI (api_key = api_key , base_url = base_url ).embeddings
2122 self .model_name = model_name
23+ self .optional_params = optional_params
2224
2325 @staticmethod
2426 def new_instance (model_type , model_name , model_credential : Dict [str , object ], ** model_kwargs ):
27+ optional_params = MaxKBBaseModel .filter_optional_params (model_kwargs )
2528 return OpenAIEmbeddingModel (
2629 api_key = model_credential .get ('api_key' ),
2730 model_name = model_name ,
2831 base_url = model_credential .get ('api_base' ),
32+ optional_params = optional_params
2933 )
3034
3135 def embed_query (self , text : str ):
@@ -35,5 +39,11 @@ def embed_query(self, text: str):
3539 def embed_documents (
3640 self , texts : List [str ], chunk_size : int | None = None
3741 ) -> List [List [float ]]:
38- res = self .client .create (input = texts , model = self .model_name , encoding_format = "float" )
42+ if len (self .optional_params ) > 0 :
43+ res = self .client .create (
44+ input = texts , model = self .model_name , encoding_format = "float" ,
45+ ** self .optional_params
46+ )
47+ else :
48+ res = self .client .create (input = texts , model = self .model_name , encoding_format = "float" )
3949 return [e .embedding for e in res .data ]
0 commit comments