adding support for finetuning your own embedding models
taha-aiplanet authored and tarun-aiplanet committed May 6, 2024
1 parent 29c598a commit 04d185e
Showing 3 changed files with 274 additions and 1 deletion.
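
In short, this commit adds a FineTuneEmbeddings class that loads documents, generates question/answer pairs with an LLM, and fine-tunes a sentence-transformers embedding model on them. A minimal usage sketch against the API added below (the file path, model id, and llm object are placeholder assumptions, not part of this commit):

from beyondllm.embeddings import FineTuneEmbeddings

# llm is assumed to be any LLM wrapper exposing .predict(prompt) -> str
fine_tuner = FineTuneEmbeddings()
embed_model = fine_tuner.train(
    files=["data/my_document.pdf"],        # placeholder input file
    model_name="BAAI/bge-small-en-v1.5",   # example sentence-transformers model id
    llm=llm,
    output_path="finetuned_embedding_model",
)

# later, reload the fine-tuned model from disk
embed_model = fine_tuner.load_model("finetuned_embedding_model")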
3 changes: 2 additions & 1 deletion src/beyondllm/embeddings/__init__.py
@@ -3,4 +3,5 @@
from .qdrantfast import FastEmbedEmbeddings
from .azure import AzureAIEmbeddings
from .hf_inference import HuggingFaceInferenceAPIEmbeddings
from .gemini_embed import GeminiEmbeddings
from .finetune import FineTuneEmbeddings
86 changes: 86 additions & 0 deletions src/beyondllm/embeddings/finetune.py
@@ -0,0 +1,86 @@
from llama_index.core.node_parser import SimpleNodeParser
from beyondllm.embeddings.utils import generate_qa_embedding_pairs, resolve_embed_model
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
from llama_index.core import SimpleDirectoryReader
from dataclasses import dataclass, field
from typing import Optional

# Optional dependency: prompt the user to install the llama-index finetuning extras on first use.
try:
    from llama_index.finetuning import SentenceTransformersFinetuneEngine
except ImportError:
    user_agree = input("The feature you're trying to use requires additional packages. Would you like to install them now? [y/N]: ")
if user_agree.lower() == 'y':
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-finetuning"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-embeddings-huggingface"])

from llama_index.finetuning import SentenceTransformersFinetuneEngine
else:
        raise ImportError("The required 'llama-index-finetuning' package is not installed.")


@dataclass
class FineTuneEmbeddings:
    output_path: Optional[str] = None
docs: list = field(init=False, default_factory=list)

def load_data(self, files):
print(f"Loading files {files}")
reader = SimpleDirectoryReader(input_files=files)
docs = reader.load_data()
print(f'Loaded {len(docs)} docs')
return docs

def load_corpus(self, docs, for_training=False, verbose=False):
parser = SimpleNodeParser.from_defaults()
        split_index = int(len(docs) * 0.7)  # 70/30 train/validation split

if for_training:
nodes = parser.get_nodes_from_documents(docs[:split_index], show_progress=verbose)
else:
nodes = parser.get_nodes_from_documents(docs[split_index:], show_progress=verbose)

if verbose:
print(f'Parsed {len(nodes)} nodes')
return nodes

def generate_and_save_datasets(self, docs, llm):
train_nodes = self.load_corpus(docs, for_training=True, verbose=True)
val_nodes = self.load_corpus(docs, for_training=False, verbose=True)

train_dataset = generate_qa_embedding_pairs(train_nodes, llm)
val_dataset = generate_qa_embedding_pairs(val_nodes, llm)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")
return "train_dataset.json", "val_dataset.json"

def finetune_model(self, train_file, val_file, model_name):
train_dataset = EmbeddingQAFinetuneDataset.from_json(train_file)
val_dataset = EmbeddingQAFinetuneDataset.from_json(val_file)
model_output_path = self.output_path or "finetuned_embedding_model"

finetune_engine = SentenceTransformersFinetuneEngine(
train_dataset,
model_id=model_name,
model_output_path=model_output_path,
val_dataset=val_dataset
)
finetune_engine.finetune()
return finetune_engine.get_finetuned_model()

def train(self, files, model_name, llm, output_path=None):
self.output_path = output_path
self.docs = self.load_data(files)
train_file, val_file = self.generate_and_save_datasets(self.docs, llm)
embed_model = self.finetune_model(train_file, val_file, model_name)
return embed_model

def load_model(self, model_path):
return resolve_embed_model("local:" + model_path)
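
A note on the flow above: load_corpus splits the loaded documents 70/30 into training and validation nodes, and generate_and_save_datasets writes train_dataset.json and val_dataset.json to the working directory. A short sketch of reloading and inspecting one of those files (counts are illustrative):

from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
print(len(train_dataset.queries))        # generated questions, keyed by uuid
print(len(train_dataset.corpus))         # source text chunks, keyed by node id
print(len(train_dataset.relevant_docs))  # question id -> [node id] mapping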



186 changes: 186 additions & 0 deletions src/beyondllm/embeddings/utils.py
@@ -0,0 +1,186 @@
"""Common utils for embeddings."""
import re
import uuid

from llama_index.core.schema import MetadataMode, TextNode
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
import os
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:
from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.callbacks import CallbackManager
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.utils import get_cache_dir

EmbedType = Union[BaseEmbedding, "LCEmbeddings", str]
from tqdm import tqdm




DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and no prior knowledge, \
generate only questions based on the below query.
You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""


# generate queries as a convenience function
def generate_qa_embedding_pairs(
nodes: List[TextNode],
llm: any,
qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
"""Generate examples given a set of nodes."""
node_dict = {
node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
for node in nodes
}

queries = {}
relevant_docs = {}
for node_id, text in tqdm(node_dict.items()):
query = qa_generate_prompt_tmpl.format(
context_str=text, num_questions_per_chunk=num_questions_per_chunk
)
response = llm.predict(query)

result = str(response).strip().split("\n")
        # strip leading list numbering such as "1." or "2)" from generated questions
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
questions = [question for question in questions if len(question) > 0]

for question in questions:
question_id = str(uuid.uuid4())
queries[question_id] = question
relevant_docs[question_id] = [node_id]

# construct dataset
return EmbeddingQAFinetuneDataset(
queries=queries, corpus=node_dict, relevant_docs=relevant_docs
)


def resolve_embed_model(
embed_model: Optional[EmbedType] = None,
callback_manager: Optional[CallbackManager] = None,
) -> BaseEmbedding:
"""Resolve embed model."""
from llama_index.core.settings import Settings

try:
from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings
except ImportError:
LCEmbeddings = None # type: ignore

if embed_model == "default":
if os.getenv("IS_TESTING"):
embed_model = MockEmbedding(embed_dim=8)
embed_model.callback_manager = callback_manager or Settings.callback_manager
return embed_model

try:
from llama_index.embeddings.openai import (
OpenAIEmbedding,
) # pants: no-infer-dep

from llama_index.embeddings.openai.utils import (
validate_openai_api_key,
) # pants: no-infer-dep

embed_model = OpenAIEmbedding()
validate_openai_api_key(embed_model.api_key)
except ImportError:
raise ImportError(
"`llama-index-embeddings-openai` package not found, "
"please run `pip install llama-index-embeddings-openai`"
)
except ValueError as e:
raise ValueError(
"\n******\n"
"Could not load OpenAI embedding model. "
"If you intended to use OpenAI, please check your OPENAI_API_KEY.\n"
"Original error:\n"
f"{e!s}"
"\nConsider using embed_model='local'.\n"
"Visit our documentation for more embedding options: "
"https://docs.llamaindex.ai/en/stable/module_guides/models/"
"embeddings.html#modules"
"\n******"
)
# for image multi-modal embeddings
elif isinstance(embed_model, str) and embed_model.startswith("clip"):
try:
from llama_index.embeddings.clip import ClipEmbedding # pants: no-infer-dep

clip_model_name = (
embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32"
)
embed_model = ClipEmbedding(model_name=clip_model_name)
        except ImportError:
raise ImportError(
"`llama-index-embeddings-clip` package not found, "
"please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`"
)

if isinstance(embed_model, str):
try:
from llama_index.embeddings.huggingface import (
HuggingFaceEmbedding,
) # pants: no-infer-dep

splits = embed_model.split(":", 1)
is_local = splits[0]
model_name = splits[1] if len(splits) > 1 else None
if is_local != "local":
raise ValueError(
"embed_model must start with str 'local' or of type BaseEmbedding"
)

cache_folder = os.path.join(get_cache_dir(), "models")
os.makedirs(cache_folder, exist_ok=True)

embed_model = HuggingFaceEmbedding(
model_name=model_name, cache_folder=cache_folder
)
except ImportError:
raise ImportError(
"`llama-index-embeddings-huggingface` package not found, "
"please run `pip install llama-index-embeddings-huggingface`"
)

if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings):
try:
from llama_index.embeddings.langchain import (
LangchainEmbedding,
) # pants: no-infer-dep

embed_model = LangchainEmbedding(embed_model)
        except ImportError:
raise ImportError(
"`llama-index-embeddings-langchain` package not found, "
"please run `pip install llama-index-embeddings-langchain`"
)

if embed_model is None:
print("Embeddings have been explicitly disabled. Using MockEmbedding.")
embed_model = MockEmbedding(embed_dim=1)

embed_model.callback_manager = callback_manager or Settings.callback_manager

return embed_model
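
For reference, resolve_embed_model mirrors llama_index's resolution logic: "local:<model_name>" loads a HuggingFace model into the local cache folder, a string starting with "clip" loads a CLIP model (defaulting to ViT-B/32), "default" falls back to OpenAI embeddings, and None disables embeddings entirely. A hedged sketch of the string forms (the model id is an example, not a requirement):

from beyondllm.embeddings.utils import resolve_embed_model

# "local:<model_name>" -> HuggingFaceEmbedding cached under <cache_dir>/models
embed_model = resolve_embed_model("local:BAAI/bge-small-en-v1.5")

# None -> embeddings explicitly disabled, returns MockEmbedding(embed_dim=1)
mock_model = resolve_embed_model(None)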
