From 04d185ed9f03e0347f90ccdd11447b1c4f42ed57 Mon Sep 17 00:00:00 2001 From: Muhammad Taha Date: Tue, 30 Apr 2024 14:21:42 +0800 Subject: [PATCH] adding support for finetuning your own embedding models --- src/beyondllm/embeddings/__init__.py | 3 +- src/beyondllm/embeddings/finetune.py | 86 +++++++++++++ src/beyondllm/embeddings/utils.py | 186 +++++++++++++++++++++++++++ 3 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 src/beyondllm/embeddings/finetune.py create mode 100644 src/beyondllm/embeddings/utils.py diff --git a/src/beyondllm/embeddings/__init__.py b/src/beyondllm/embeddings/__init__.py index 886dd1a..fefefb0 100644 --- a/src/beyondllm/embeddings/__init__.py +++ b/src/beyondllm/embeddings/__init__.py @@ -3,4 +3,5 @@ from .qdrantfast import FastEmbedEmbeddings from .azure import AzureAIEmbeddings from .hf_inference import HuggingFaceInferenceAPIEmbeddings -from .gemini_embed import GeminiEmbeddings \ No newline at end of file +from .gemini_embed import GeminiEmbeddings +from .finetune import FineTuneEmbeddings \ No newline at end of file diff --git a/src/beyondllm/embeddings/finetune.py b/src/beyondllm/embeddings/finetune.py new file mode 100644 index 0000000..11a9627 --- /dev/null +++ b/src/beyondllm/embeddings/finetune.py @@ -0,0 +1,86 @@ +from llama_index.core.node_parser import SimpleNodeParser +from beyondllm.embeddings.utils import generate_qa_embedding_pairs, resolve_embed_model +from llama_index.core.evaluation import EmbeddingQAFinetuneDataset +from llama_index.core import SimpleDirectoryReader +from dataclasses import dataclass, field + +try: + from llama_index.finetuning import SentenceTransformersFinetuneEngine +except ImportError: + user_agree = input("The feature you're trying to use require additional packages. Would you like to install it now? [y/N]: ") + if user_agree.lower() == 'y': + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-finetuning"]) + subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-index-embeddings-huggingface"]) + + from llama_index.finetuning import SentenceTransformersFinetuneEngine + else: + raise ImportError("The required 'llama-index-finetuning' package is not installed.") + + + + + +@dataclass +class FineTuneEmbeddings: + output_path: str = None + docs: list = field(init=False, default_factory=list) + + def load_data(self, files): + print(f"Loading files {files}") + reader = SimpleDirectoryReader(input_files=files) + docs = reader.load_data() + print(f'Loaded {len(docs)} docs') + return docs + + def load_corpus(self, docs, for_training=False, verbose=False): + parser = SimpleNodeParser.from_defaults() + split_index = int(len(docs) * 0.7) + + if for_training: + nodes = parser.get_nodes_from_documents(docs[:split_index], show_progress=verbose) + else: + nodes = parser.get_nodes_from_documents(docs[split_index:], show_progress=verbose) + + if verbose: + print(f'Parsed {len(nodes)} nodes') + return nodes + + def generate_and_save_datasets(self, docs, llm): + train_nodes = self.load_corpus(docs, for_training=True, verbose=True) + val_nodes = self.load_corpus(docs, for_training=False, verbose=True) + + train_dataset = generate_qa_embedding_pairs(train_nodes, llm) + val_dataset = generate_qa_embedding_pairs(val_nodes, llm) + + train_dataset.save_json("train_dataset.json") + val_dataset.save_json("val_dataset.json") + return "train_dataset.json", "val_dataset.json" + + def finetune_model(self, train_file, val_file, model_name): + train_dataset = EmbeddingQAFinetuneDataset.from_json(train_file) + val_dataset = EmbeddingQAFinetuneDataset.from_json(val_file) + model_output_path = self.output_path or "finetuned_embedding_model" + + finetune_engine = SentenceTransformersFinetuneEngine( + train_dataset, + model_id=model_name, + model_output_path=model_output_path, + val_dataset=val_dataset + ) + finetune_engine.finetune() + return finetune_engine.get_finetuned_model() + + def train(self, files, model_name, llm, output_path=None): + self.output_path = output_path + self.docs = self.load_data(files) + train_file, val_file = self.generate_and_save_datasets(self.docs, llm) + embed_model = self.finetune_model(train_file, val_file, model_name) + return embed_model + + def load_model(self, model_path): + return resolve_embed_model("local:" + model_path) + + + diff --git a/src/beyondllm/embeddings/utils.py b/src/beyondllm/embeddings/utils.py new file mode 100644 index 0000000..44eb39f --- /dev/null +++ b/src/beyondllm/embeddings/utils.py @@ -0,0 +1,186 @@ +"""Common utils for embeddings.""" +import re +import uuid + +from llama_index.core.schema import MetadataMode, TextNode +from llama_index.core.evaluation import EmbeddingQAFinetuneDataset +import os +from typing import TYPE_CHECKING, List, Optional, Union + +if TYPE_CHECKING: + from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.callbacks import CallbackManager +from llama_index.core.embeddings.mock_embed_model import MockEmbedding +from llama_index.core.utils import get_cache_dir + +EmbedType = Union[BaseEmbedding, "LCEmbeddings", str] +from tqdm import tqdm + + + + +DEFAULT_QA_GENERATE_PROMPT_TMPL = """\ +Context information is below. + +--------------------- +{context_str} +--------------------- + +Given the context information and no prior knowledge. +generate only questions based on the below query. + +You are a Teacher/ Professor. Your task is to setup \ +{num_questions_per_chunk} questions for an upcoming \ +quiz/examination. The questions should be diverse in nature \ +across the document. Restrict the questions to the \ +context information provided." +""" + + +# generate queries as a convenience function +def generate_qa_embedding_pairs( + nodes: List[TextNode], + llm: any, + qa_generate_prompt_tmpl: str = DEFAULT_QA_GENERATE_PROMPT_TMPL, + num_questions_per_chunk: int = 2, +) -> EmbeddingQAFinetuneDataset: + """Generate examples given a set of nodes.""" + node_dict = { + node.node_id: node.get_content(metadata_mode=MetadataMode.NONE) + for node in nodes + } + + queries = {} + relevant_docs = {} + for node_id, text in tqdm(node_dict.items()): + query = qa_generate_prompt_tmpl.format( + context_str=text, num_questions_per_chunk=num_questions_per_chunk + ) + response = llm.predict(query) + + result = str(response).strip().split("\n") + questions = [ + re.sub(r"^\d+[\).\s]", "", question).strip() for question in result + ] + questions = [question for question in questions if len(question) > 0] + + for question in questions: + question_id = str(uuid.uuid4()) + queries[question_id] = question + relevant_docs[question_id] = [node_id] + + # construct dataset + return EmbeddingQAFinetuneDataset( + queries=queries, corpus=node_dict, relevant_docs=relevant_docs + ) + + +def resolve_embed_model( + embed_model: Optional[EmbedType] = None, + callback_manager: Optional[CallbackManager] = None, +) -> BaseEmbedding: + """Resolve embed model.""" + from llama_index.core.settings import Settings + + try: + from llama_index.core.bridge.langchain import Embeddings as LCEmbeddings + except ImportError: + LCEmbeddings = None # type: ignore + + if embed_model == "default": + if os.getenv("IS_TESTING"): + embed_model = MockEmbedding(embed_dim=8) + embed_model.callback_manager = callback_manager or Settings.callback_manager + return embed_model + + try: + from llama_index.embeddings.openai import ( + OpenAIEmbedding, + ) # pants: no-infer-dep + + from llama_index.embeddings.openai.utils import ( + validate_openai_api_key, + ) # pants: no-infer-dep + + embed_model = OpenAIEmbedding() + validate_openai_api_key(embed_model.api_key) + except ImportError: + raise ImportError( + "`llama-index-embeddings-openai` package not found, " + "please run `pip install llama-index-embeddings-openai`" + ) + except ValueError as e: + raise ValueError( + "\n******\n" + "Could not load OpenAI embedding model. " + "If you intended to use OpenAI, please check your OPENAI_API_KEY.\n" + "Original error:\n" + f"{e!s}" + "\nConsider using embed_model='local'.\n" + "Visit our documentation for more embedding options: " + "https://docs.llamaindex.ai/en/stable/module_guides/models/" + "embeddings.html#modules" + "\n******" + ) + # for image multi-modal embeddings + elif isinstance(embed_model, str) and embed_model.startswith("clip"): + try: + from llama_index.embeddings.clip import ClipEmbedding # pants: no-infer-dep + + clip_model_name = ( + embed_model.split(":")[1] if ":" in embed_model else "ViT-B/32" + ) + embed_model = ClipEmbedding(model_name=clip_model_name) + except ImportError as e: + raise ImportError( + "`llama-index-embeddings-clip` package not found, " + "please run `pip install llama-index-embeddings-clip` and `pip install git+https://github.com/openai/CLIP.git`" + ) + + if isinstance(embed_model, str): + try: + from llama_index.embeddings.huggingface import ( + HuggingFaceEmbedding, + ) # pants: no-infer-dep + + splits = embed_model.split(":", 1) + is_local = splits[0] + model_name = splits[1] if len(splits) > 1 else None + if is_local != "local": + raise ValueError( + "embed_model must start with str 'local' or of type BaseEmbedding" + ) + + cache_folder = os.path.join(get_cache_dir(), "models") + os.makedirs(cache_folder, exist_ok=True) + + embed_model = HuggingFaceEmbedding( + model_name=model_name, cache_folder=cache_folder + ) + except ImportError: + raise ImportError( + "`llama-index-embeddings-huggingface` package not found, " + "please run `pip install llama-index-embeddings-huggingface`" + ) + + if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings): + try: + from llama_index.embeddings.langchain import ( + LangchainEmbedding, + ) # pants: no-infer-dep + + embed_model = LangchainEmbedding(embed_model) + except ImportError as e: + raise ImportError( + "`llama-index-embeddings-langchain` package not found, " + "please run `pip install llama-index-embeddings-langchain`" + ) + + if embed_model is None: + print("Embeddings have been explicitly disabled. Using MockEmbedding.") + embed_model = MockEmbedding(embed_dim=1) + + embed_model.callback_manager = callback_manager or Settings.callback_manager + + return embed_model \ No newline at end of file