1515
1616class  OpenAIEmbeddingModel (MaxKBBaseModel ):
1717    model_name : str 
18+     optional_params : dict 
1819
19-     def  __init__ (self , api_key , base_url , model_name : str ):
20+     def  __init__ (self , api_key , base_url , model_name : str ,  optional_params :  dict ):
2021        self .client  =  openai .OpenAI (api_key = api_key , base_url = base_url ).embeddings 
2122        self .model_name  =  model_name 
23+         self .optional_params  =  optional_params 
2224
2325    @staticmethod  
2426    def  new_instance (model_type , model_name , model_credential : Dict [str , object ], ** model_kwargs ):
27+         optional_params  =  MaxKBBaseModel .filter_optional_params (model_kwargs )
2528        return  OpenAIEmbeddingModel (
2629            api_key = model_credential .get ('api_key' ),
2730            model_name = model_name ,
2831            base_url = model_credential .get ('api_base' ),
32+             optional_params = optional_params 
2933        )
3034
3135    def  embed_query (self , text : str ):
@@ -35,5 +39,12 @@ def embed_query(self, text: str):
3539    def  embed_documents (
3640            self , texts : List [str ], chunk_size : int  |  None  =  None 
3741    ) ->  List [List [float ]]:
38-         res  =  self .client .create (input = texts , model = self .model_name , encoding_format = "float" )
42+         print (self .optional_params )
43+         if  len (self .optional_params ) >  0 :
44+             res  =  self .client .create (
45+                 input = texts , model = self .model_name , encoding_format = "float" ,
46+                 ** self .optional_params 
47+             )
48+         else :
49+             res  =  self .client .create (input = texts , model = self .model_name , encoding_format = "float" )
3950        return  [e .embedding  for  e  in  res .data ]
0 commit comments