-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding_models.py
115 lines (91 loc) · 3.12 KB
/
embedding_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from dataclasses import dataclass, field
from typing import List
import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from torch import Tensor
from utils import download_models
@dataclass
class SentenceTransformerEmbeddingModel(Embeddings):
    """
    Embedding model using Sentence Transformers

    Attributes:
        model: Identifier handed to ``download_models`` (as
            ``sent_embedding_model``) to locate the model on disk.
        embedding_model: The loaded ``SentenceTransformer``. It is not an
            ``__init__`` argument; it is populated in ``__post_init__``.
        st_embedding_model: Alias of ``embedding_model`` (plain attribute, not
            a dataclass field), kept for code that uses the older name.
    """

    model: str
    # Excluded from the generated __repr__/__eq__: instances are identified by
    # the configured model name, and the loaded model object is kept out of
    # both the printed form and equality comparisons.
    embedding_model: SentenceTransformer = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        """
        Resolve the model path and load the Sentence Transformer model

        Raises:
            KeyError: If the value returned by ``download_models`` has no
                ``"model_path"`` entry.
        """
        models_info = download_models(sent_embedding_model=self.model)
        # Populate the declared dataclass field; leaving it unset makes the
        # dataclass-generated __repr__/__eq__ raise AttributeError.
        self.embedding_model = SentenceTransformer(model_name_or_path=models_info["model_path"])
        # Same object under the attribute name the methods below read.
        self.st_embedding_model = self.embedding_model

    def embed_documents(self, texts: list) -> List[List[float]]:
        """
        Generate embeddings for a list of documents

        Args:
            texts: The documents to embed.

        Returns:
            The encoder output converted to nested Python lists via
            ``tolist()`` (one vector per input text).
        """
        v_representation = self.st_embedding_model.encode(texts)
        return v_representation.tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embedding for a piece of text

        Args:
            text: The text to embed.

        Returns:
            The encoder output converted to a Python list via ``tolist()``.
        """
        v_representation = self.st_embedding_model.encode(text)
        return v_representation.tolist()

    def check_similarity(self, embeddings_1: np.ndarray, embeddings_2: np.ndarray) -> Tensor:
        """
        Computes the similarity between two embeddings

        Delegates to the loaded model's ``similarity`` method, so the metric
        is whatever that model is configured with (presumably cosine, the
        library default -- confirm for the model in use).

        Args:
            embeddings_1: First embedding (or batch of embeddings).
            embeddings_2: Second embedding (or batch of embeddings).

        Returns:
            The similarity scores produced by the model.
        """
        return self.st_embedding_model.similarity(embeddings_1, embeddings_2)

    def get_model(self):
        """Returns the underlying SentenceTransformer model"""
        return self.st_embedding_model
@dataclass
class OllamaEmbeddingModel(Embeddings):
    """
    Embeddings implementation that delegates to an Ollama backend
    (locally deployed)

    Attributes:
        model: Model name forwarded to ``OllamaEmbeddings``.
        base_url: Base URL forwarded to ``OllamaEmbeddings``.
    """

    model: str
    base_url: str

    def __post_init__(self):
        """
        Build the Ollama embeddings client from the dataclass fields
        """
        client = OllamaEmbeddings(model=self.model, base_url=self.base_url)
        self.ollama_embed_model = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed every document in ``texts`` using the Ollama client
        """
        backend = self.ollama_embed_model
        vectors = backend.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single piece of text using the Ollama client
        """
        backend = self.ollama_embed_model
        vector = backend.embed_query(text=text)
        return vector

    def get_model(self):
        """Hand back the wrapped Ollama embeddings client"""
        wrapped = self.ollama_embed_model
        return wrapped
@dataclass
class OpenAIEmbeddingModel(Embeddings):
    """
    Embedding Model using OpenAI

    Attributes:
        model: Model name forwarded to ``OpenAIEmbeddings``.
    """

    model: str

    def __post_init__(self):
        """
        Create the OpenAI embeddings client for the configured model
        """
        self.openai_embed_model = OpenAIEmbeddings(model=self.model)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents.

        Args:
            texts: The documents to embed.

        Returns:
            Whatever the wrapped client's ``embed_documents`` returns
            (one embedding per input text).
        """
        return self.openai_embed_model.embed_documents(texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embedding for a piece of text

        Args:
            text: The text to embed.

        Returns:
            Whatever the wrapped client's ``embed_query`` returns
            (the embedding of ``text``).
        """
        return self.openai_embed_model.embed_query(text=text)

    def get_model(self) -> OpenAIEmbeddings:
        """Returns the underlying OpenAIEmbeddings client"""
        return self.openai_embed_model