adding support for finetuning your own embedding models
1 parent 29c598a · commit 04d185e
Showing 3 changed files with 274 additions and 1 deletion.
@@ -0,0 +1,86 @@
from llama_index.core.node_parser import SimpleNodeParser
from beyondllm.embeddings.utils import generate_qa_embedding_pairs, resolve_embed_model
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
from llama_index.core import SimpleDirectoryReader
from dataclasses import dataclass, field
from typing import Optional

try:
    from llama_index.finetuning import SentenceTransformersFinetuneEngine
except ImportError:
    user_agree = input("The feature you're trying to use requires additional packages. Would you like to install them now? [y/N]: ")
    if user_agree.lower() == 'y':
        import subprocess
        import sys
        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-finetuning"])
        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-embeddings-huggingface"])

        from llama_index.finetuning import SentenceTransformersFinetuneEngine
    else:
        raise ImportError("The required 'llama-index-finetuning' package is not installed.")

@dataclass
class FineTuneEmbeddings:
    output_path: Optional[str] = None
    docs: list = field(init=False, default_factory=list)

    def load_data(self, files):
        """Load documents from the given input files."""
        print(f"Loading files {files}")
        reader = SimpleDirectoryReader(input_files=files)
        docs = reader.load_data()
        print(f"Loaded {len(docs)} docs")
        return docs

    def load_corpus(self, docs, for_training=False, verbose=False):
        """Parse documents into nodes, using a 70/30 train/validation split."""
        parser = SimpleNodeParser.from_defaults()
        split_index = int(len(docs) * 0.7)

        if for_training:
            nodes = parser.get_nodes_from_documents(docs[:split_index], show_progress=verbose)
        else:
            nodes = parser.get_nodes_from_documents(docs[split_index:], show_progress=verbose)

        if verbose:
            print(f"Parsed {len(nodes)} nodes")
        return nodes

    def generate_and_save_datasets(self, docs, llm):
        """Generate QA pairs for both splits and save them to JSON."""
        train_nodes = self.load_corpus(docs, for_training=True, verbose=True)
        val_nodes = self.load_corpus(docs, for_training=False, verbose=True)

        train_dataset = generate_qa_embedding_pairs(train_nodes, llm)
        val_dataset = generate_qa_embedding_pairs(val_nodes, llm)

        train_dataset.save_json("train_dataset.json")
        val_dataset.save_json("val_dataset.json")
        return "train_dataset.json", "val_dataset.json"

    def finetune_model(self, train_file, val_file, model_name):
        """Fine-tune a sentence-transformers model on the saved datasets."""
        train_dataset = EmbeddingQAFinetuneDataset.from_json(train_file)
        val_dataset = EmbeddingQAFinetuneDataset.from_json(val_file)
        model_output_path = self.output_path or "finetuned_embedding_model"

        finetune_engine = SentenceTransformersFinetuneEngine(
            train_dataset,
            model_id=model_name,
            model_output_path=model_output_path,
            val_dataset=val_dataset,
        )
        finetune_engine.finetune()
        return finetune_engine.get_finetuned_model()

    def train(self, files, model_name, llm, output_path=None):
        """End-to-end pipeline: load data, build datasets, and fine-tune."""
        self.output_path = output_path
        self.docs = self.load_data(files)
        train_file, val_file = self.generate_and_save_datasets(self.docs, llm)
        embed_model = self.finetune_model(train_file, val_file, model_name)
        return embed_model

    def load_model(self, model_path):
        """Load a previously fine-tuned model from disk as a local embedding model."""
        return resolve_embed_model("local:" + model_path)
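
A minimal usage sketch of the class above (not part of the commit). The import path is an assumption, since the diff view does not show where the new module lives; `llm` stands for any model object exposing the `.predict(prompt)` method that `generate_qa_embedding_pairs` calls, and the file and model names are examples:

from beyondllm.embeddings.fine_tune import FineTuneEmbeddings  # hypothetical module path

fine_tuner = FineTuneEmbeddings()
embed_model = fine_tuner.train(
    files=["docs/guide.pdf"],             # any local files readable by SimpleDirectoryReader
    model_name="BAAI/bge-small-en-v1.5",  # any sentence-transformers-compatible base model
    llm=llm,                              # must implement .predict(prompt) -> str
    output_path="finetuned_embedding_model",
)

# Later, reload the fine-tuned weights as a local embedding model:
embed_model = fine_tuner.load_model("finetuned_embedding_model")
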
@@ -0,0 +1,186 @@
"""Common utils for embeddings.""" | ||
import re | ||
import uuid | ||
|
||
from llama_index.core.schema import MetadataMode, TextNode | ||
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset | ||
import os | ||
from typing import TYPE_CHECKING, List, Optional, Union | ||
|
||
if TYPE_CHECKING: | ||
from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings | ||
from llama_index.core.base.embeddings.base import BaseEmbedding | ||
from llama_index.core.callbacks import CallbackManager | ||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding | ||
from llama_index.core.utils import get_cache_dir | ||
|
||
EmbedType = Union[BaseEmbedding, "LCEmbeddings", str] | ||
from tqdm import tqdm | ||
|
||
|
||
|
||
|
||
DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and no prior knowledge, \
generate only questions based on the below query.
You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""


# generate queries as a convenience function
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: Any,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate examples given a set of nodes."""
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    queries = {}
    relevant_docs = {}
    for node_id, text in tqdm(node_dict.items()):
        query = qa_generate_prompt_tmpl.format(
            context_str=text, num_questions_per_chunk=num_questions_per_chunk
        )
        response = llm.predict(query)

        # split the LLM response into lines and strip any leading list numbering
        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # construct dataset
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )
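

# Illustrative call (not part of the commit): given parsed nodes and an LLM that
# implements .predict(), this returns an EmbeddingQAFinetuneDataset whose
# generated questions map back to their source node ids, e.g.:
#
#     dataset = generate_qa_embedding_pairs(nodes, llm, num_questions_per_chunk=2)
#     dataset.save_json("qa_dataset.json")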


def resolve_embed_model(
    embed_model: Optional[EmbedType] = None,
    callback_manager: Optional[CallbackManager] = None,
) -> BaseEmbedding:
    """Resolve embed model."""
    from llama_index.core.settings import Settings

    try:
        from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings
    except ImportError:
        LCEmbeddings = None  # type: ignore

    if embed_model == "default":
        if os.getenv("IS_TESTING"):
            embed_model = MockEmbedding(embed_dim=8)
            embed_model.callback_manager = callback_manager or Settings.callback_manager
            return embed_model

        try:
            from llama_index.embeddings.openai import (
                OpenAIEmbedding,
            )  # pants: no-infer-dep
            from llama_index.embeddings.openai.utils import (
                validate_openai_api_key,
            )  # pants: no-infer-dep

            embed_model = OpenAIEmbedding()
            validate_openai_api_key(embed_model.api_key)
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-openai` package not found, "
                "please run `pip install llama-index-embeddings-openai`"
            )
        except ValueError as e:
            raise ValueError(
                "\n******\n"
                "Could not load OpenAI embedding model. "
                "If you intended to use OpenAI, please check your OPENAI_API_KEY.\n"
                "Original error:\n"
                f"{e!s}"
                "\nConsider using embed_model='local'.\n"
                "Visit our documentation for more embedding options: "
                "https://docs.llamaindex.ai/en/stable/module_guides/models/"
                "embeddings.html#modules"
                "\n******"
            )

    # for image multi-modal embeddings
    elif isinstance(embed_model, str) and embed_model.startswith("clip"):
        try:
            from llama_index.embeddings.clip import ClipEmbedding  # pants: no-infer-dep

            clip_model_name = (
                embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32"
            )
            embed_model = ClipEmbedding(model_name=clip_model_name)
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-clip` package not found, "
                "please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`"
            )

    # "local:<model_name>" strings resolve to a cached HuggingFace model
    if isinstance(embed_model, str):
        try:
            from llama_index.embeddings.huggingface import (
                HuggingFaceEmbedding,
            )  # pants: no-infer-dep

            splits = embed_model.split(":", 1)
            is_local = splits[0]
            model_name = splits[1] if len(splits) > 1 else None
            if is_local != "local":
                raise ValueError(
                    "embed_model must start with str 'local' or of type BaseEmbedding"
                )

            cache_folder = os.path.join(get_cache_dir(), "models")
            os.makedirs(cache_folder, exist_ok=True)

            embed_model = HuggingFaceEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-huggingface` package not found, "
                "please run `pip install llama-index-embeddings-huggingface`"
            )

    if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings):
        try:
            from llama_index.embeddings.langchain import (
                LangchainEmbedding,
            )  # pants: no-infer-dep

            embed_model = LangchainEmbedding(embed_model)
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-langchain` package not found, "
                "please run `pip install llama-index-embeddings-langchain`"
            )

    if embed_model is None:
        print("Embeddings have been explicitly disabled. Using MockEmbedding.")
        embed_model = MockEmbedding(embed_dim=1)

    embed_model.callback_manager = callback_manager or Settings.callback_manager

    return embed_model
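
For completeness, a short sketch of the call forms `resolve_embed_model` accepts, following the branches above (the HuggingFace model name is an example):

# "local:<model>" resolves to a HuggingFace model cached under get_cache_dir()/models
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")

# "default" falls back to OpenAIEmbedding and validates OPENAI_API_KEY
embed_model = resolve_embed_model("default")

# None disables embeddings and returns a MockEmbedding
embed_model = resolve_embed_model(None)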