adding support for finetuning your own embedding models #39

Merged: 1 commit, merged May 6, 2024
3 changes: 2 additions & 1 deletion src/beyondllm/embeddings/__init__.py
@@ -3,4 +3,5 @@
 from .qdrantfast import FastEmbedEmbeddings
 from .azure import AzureAIEmbeddings
 from .hf_inference import HuggingFaceInferenceAPIEmbeddings
-from .gemini_embed import GeminiEmbeddings
+from .gemini_embed import GeminiEmbeddings
+from .finetune import FineTuneEmbeddings
86 changes: 86 additions & 0 deletions src/beyondllm/embeddings/finetune.py
@@ -0,0 +1,86 @@
from llama_index.core.node_parser import SimpleNodeParser
from beyondllm.embeddings.utils import generate_qa_embedding_pairs, resolve_embed_model
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
from llama_index.core import SimpleDirectoryReader
from dataclasses import dataclass, field

try:
    from llama_index.finetuning import SentenceTransformersFinetuneEngine
except ImportError:
    user_agree = input("The feature you're trying to use requires additional packages. Would you like to install them now? [y/N]: ")
    if user_agree.lower() == 'y':
        import subprocess
        import sys
        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-finetuning"])
        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-embeddings-huggingface"])

        from llama_index.finetuning import SentenceTransformersFinetuneEngine
    else:
        raise ImportError("The required 'llama-index-finetuning' package is not installed.")





@dataclass
class FineTuneEmbeddings:
    output_path: str = None
    docs: list = field(init=False, default_factory=list)

    def load_data(self, files):
        print(f"Loading files {files}")
        reader = SimpleDirectoryReader(input_files=files)
        docs = reader.load_data()
        print(f'Loaded {len(docs)} docs')
        return docs

    def load_corpus(self, docs, for_training=False, verbose=False):
        parser = SimpleNodeParser.from_defaults()
        # first 70% of the documents go to training, the remainder to validation
        split_index = int(len(docs) * 0.7)

        if for_training:
            nodes = parser.get_nodes_from_documents(docs[:split_index], show_progress=verbose)
        else:
            nodes = parser.get_nodes_from_documents(docs[split_index:], show_progress=verbose)

        if verbose:
            print(f'Parsed {len(nodes)} nodes')
        return nodes

    def generate_and_save_datasets(self, docs, llm):
        train_nodes = self.load_corpus(docs, for_training=True, verbose=True)
        val_nodes = self.load_corpus(docs, for_training=False, verbose=True)

        train_dataset = generate_qa_embedding_pairs(train_nodes, llm)
        val_dataset = generate_qa_embedding_pairs(val_nodes, llm)

        train_dataset.save_json("train_dataset.json")
        val_dataset.save_json("val_dataset.json")
        return "train_dataset.json", "val_dataset.json"

    def finetune_model(self, train_file, val_file, model_name):
        train_dataset = EmbeddingQAFinetuneDataset.from_json(train_file)
        val_dataset = EmbeddingQAFinetuneDataset.from_json(val_file)
        model_output_path = self.output_path or "finetuned_embedding_model"

        finetune_engine = SentenceTransformersFinetuneEngine(
            train_dataset,
            model_id=model_name,
            model_output_path=model_output_path,
            val_dataset=val_dataset
        )
        finetune_engine.finetune()
        return finetune_engine.get_finetuned_model()

    def train(self, files, model_name, llm, output_path=None):
        self.output_path = output_path
        self.docs = self.load_data(files)
        train_file, val_file = self.generate_and_save_datasets(self.docs, llm)
        embed_model = self.finetune_model(train_file, val_file, model_name)
        return embed_model

    def load_model(self, model_path):
        return resolve_embed_model("local:" + model_path)
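For context, a minimal usage sketch of the new FineTuneEmbeddings class (not part of this diff). The file paths, the BAAI/bge-small-en-v1.5 model id, and the GeminiModel import are illustrative assumptions; any beyondllm LLM exposing a .predict() method should work, since generate_qa_embedding_pairs only calls llm.predict(query).

# Illustrative only: paths, model id, and the LLM wrapper below are placeholders.
from beyondllm.embeddings import FineTuneEmbeddings
from beyondllm.llms import GeminiModel  # assumption: any LLM with a .predict() method works

llm = GeminiModel(model_name="gemini-pro", google_api_key="...")  # hypothetical config
fine_tuner = FineTuneEmbeddings()

# Generates QA pairs from the documents, fine-tunes the sentence-transformers model,
# and returns the fine-tuned embedding model.
embed_model = fine_tuner.train(
    files=["data/report.pdf"],            # placeholder input files
    model_name="BAAI/bge-small-en-v1.5",  # any sentence-transformers model id
    llm=llm,
    output_path="fine_tuned_model",
)

# Later, reload the saved weights as a local HuggingFace embedding model.
embed_model = fine_tuner.load_model("fine_tuned_model")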



186 changes: 186 additions & 0 deletions src/beyondllm/embeddings/utils.py
@@ -0,0 +1,186 @@
"""Common utils for embeddings."""
import re
import uuid

from llama_index.core.schema import MetadataMode, TextNode
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
import os
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:
from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.callbacks import CallbackManager
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.utils import get_cache_dir

EmbedType = Union[BaseEmbedding, "LCEmbeddings", str]
from tqdm import tqdm




DEFAULT_QA_GENERATE_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and no prior knowledge,
generate only questions based on the below query.

You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided."
"""


# generate queries as a convenience function
def generate_qa_embedding_pairs(
    nodes: List[TextNode],
    llm: Any,
    qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> EmbeddingQAFinetuneDataset:
    """Generate examples given a set of nodes."""
    node_dict = {
        node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
        for node in nodes
    }

    queries = {}
    relevant_docs = {}
    for node_id, text in tqdm(node_dict.items()):
        query = qa_generate_prompt_tmpl.format(
            context_str=text, num_questions_per_chunk=num_questions_per_chunk
        )
        response = llm.predict(query)

        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]

    # construct dataset
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=node_dict, relevant_docs=relevant_docs
    )


def resolve_embed_model(
    embed_model: Optional[EmbedType] = None,
    callback_manager: Optional[CallbackManager] = None,
) -> BaseEmbedding:
    """Resolve embed model."""
    from llama_index.core.settings import Settings

    try:
        from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings
    except ImportError:
        LCEmbeddings = None  # type: ignore

    if embed_model == "default":
        if os.getenv("IS_TESTING"):
            embed_model = MockEmbedding(embed_dim=8)
            embed_model.callback_manager = callback_manager or Settings.callback_manager
            return embed_model

        try:
            from llama_index.embeddings.openai import (
                OpenAIEmbedding,
            )  # pants: no-infer-dep

            from llama_index.embeddings.openai.utils import (
                validate_openai_api_key,
            )  # pants: no-infer-dep

            embed_model = OpenAIEmbedding()
            validate_openai_api_key(embed_model.api_key)
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-openai` package not found, "
                "please run `pip install llama-index-embeddings-openai`"
            )
        except ValueError as e:
            raise ValueError(
                "\n******\n"
                "Could not load OpenAI embedding model. "
                "If you intended to use OpenAI, please check your OPENAI_API_KEY.\n"
                "Original error:\n"
                f"{e!s}"
                "\nConsider using embed_model='local'.\n"
                "Visit our documentation for more embedding options: "
                "https://docs.llamaindex.ai/en/stable/module_guides/models/"
                "embeddings.html#modules"
                "\n******"
            )

    # for image multi-modal embeddings
    elif isinstance(embed_model, str) and embed_model.startswith("clip"):
        try:
            from llama_index.embeddings.clip import ClipEmbedding  # pants: no-infer-dep

            clip_model_name = (
                embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32"
            )
            embed_model = ClipEmbedding(model_name=clip_model_name)
        except ImportError as e:
            raise ImportError(
                "`llama-index-embeddings-clip` package not found, "
                "please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`"
            )

    if isinstance(embed_model, str):
        try:
            from llama_index.embeddings.huggingface import (
                HuggingFaceEmbedding,
            )  # pants: no-infer-dep

            splits = embed_model.split(":", 1)
            is_local = splits[0]
            model_name = splits[1] if len(splits) > 1 else None
            if is_local != "local":
                raise ValueError(
                    "embed_model must start with str 'local' or of type BaseEmbedding"
                )

            cache_folder = os.path.join(get_cache_dir(), "models")
            os.makedirs(cache_folder, exist_ok=True)

            embed_model = HuggingFaceEmbedding(
                model_name=model_name, cache_folder=cache_folder
            )
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-huggingface` package not found, "
                "please run `pip install llama-index-embeddings-huggingface`"
            )

    if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings):
        try:
            from llama_index.embeddings.langchain import (
                LangchainEmbedding,
            )  # pants: no-infer-dep

            embed_model = LangchainEmbedding(embed_model)
        except ImportError as e:
            raise ImportError(
                "`llama-index-embeddings-langchain` package not found, "
                "please run `pip install llama-index-embeddings-langchain`"
            )

    if embed_model is None:
        print("Embeddings have been explicitly disabled. Using MockEmbedding.")
        embed_model = MockEmbedding(embed_dim=1)

    embed_model.callback_manager = callback_manager or Settings.callback_manager

    return embed_model
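A short sketch (not part of this diff) of how the "local:" convention handled by resolve_embed_model is used; the model path is a placeholder. This is the same convention FineTuneEmbeddings.load_model relies on.

from beyondllm.embeddings.utils import resolve_embed_model

# A bare string must start with "local"; everything after the colon is treated as a
# HuggingFace model id or local path and cached under <cache_dir>/models.
embed_model = resolve_embed_model("local:fine_tuned_model")  # placeholder path
vector = embed_model.get_text_embedding("hello world")       # standard BaseEmbedding call
print(len(vector))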