diff --git a/dspy/retrievers/embeddings.py b/dspy/retrievers/embeddings.py
index 9b9b31be72..02da760078 100644
--- a/dspy/retrievers/embeddings.py
+++ b/dspy/retrievers/embeddings.py
@@ -1,11 +1,11 @@
+import json
+import os
 from typing import Any
 
 import numpy as np
 
 from dspy.utils.unbatchify import Unbatchify
 
-# TODO: Add .save and .load methods!
-
 
 class Embeddings:
     def __init__(
@@ -87,3 +87,150 @@ def _rerank_and_predict(self, q_embeds: np.ndarray, candidate_indices: np.ndarra
     def _normalize(self, embeddings: np.ndarray):
         norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
         return embeddings / np.maximum(norms, 1e-10)
+
+    def save(self, path: str) -> None:
+        """
+        Save the embeddings index to disk.
+
+        This saves the corpus, embeddings, FAISS index (if present), and configuration
+        to allow for fast loading without recomputing embeddings.
+
+        Args:
+            path: Directory path where the embeddings will be saved. It is created
+                if it does not exist yet.
+        """
+        os.makedirs(path, exist_ok=True)
+
+        # Save embeddings
+        np.save(os.path.join(path, "corpus_embeddings.npy"), self.corpus_embeddings)
+
+        # Save FAISS index if it exists. Track whether the file was really written
+        # so that the config never advertises an index that is not on disk.
+        has_faiss_index = False
+        if self.index is not None:
+            try:
+                import faiss
+            except ImportError:
+                # If FAISS is not available, we can't save the index
+                # but we can still save the embeddings for brute force search
+                pass
+            else:
+                faiss.write_index(self.index, os.path.join(path, "faiss_index.bin"))
+                has_faiss_index = True
+
+        # Save configuration and corpus last so `has_faiss_index` reflects what was written
+        config = {
+            "k": self.k,
+            "normalize": self.normalize,
+            "corpus": self.corpus,
+            "has_faiss_index": has_faiss_index,
+        }
+
+        with open(os.path.join(path, "config.json"), "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2)
+
+    def load(self, path: str, embedder):
+        """
+        Load the embeddings index from disk into the current instance.
+
+        The instance is only modified once everything has been read and validated,
+        so a failed load leaves the previous state untouched.
+
+        Args:
+            path: Directory path where the embeddings were saved
+            embedder: The embedder function to use for new queries
+
+        Returns:
+            self: Returns self for method chaining
+
+        Raises:
+            FileNotFoundError: If the save directory or required files don't exist
+            ValueError: If the saved config is invalid or incompatible
+        """
+        if not os.path.exists(path):
+            raise FileNotFoundError(f"Save directory not found: {path}")
+
+        config_path = os.path.join(path, "config.json")
+        embeddings_path = os.path.join(path, "corpus_embeddings.npy")
+
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config file not found: {config_path}")
+        if not os.path.exists(embeddings_path):
+            raise FileNotFoundError(f"Embeddings file not found: {embeddings_path}")
+
+        # Load configuration and corpus
+        with open(config_path, encoding="utf-8") as f:
+            config = json.load(f)
+
+        # Validate required config fields
+        required_fields = ["k", "normalize", "corpus", "has_faiss_index"]
+        for field in required_fields:
+            if field not in config:
+                raise ValueError(f"Invalid config: missing required field '{field}'")
+
+        # Load embeddings and make sure there is exactly one row per passage
+        corpus = config["corpus"]
+        corpus_embeddings = np.load(embeddings_path)
+        if corpus_embeddings.shape[:1] != (len(corpus),):
+            raise ValueError(
+                f"Invalid save: {len(corpus)} passages in config but embeddings "
+                f"have shape {corpus_embeddings.shape}"
+            )
+
+        # Load FAISS index if it was saved and FAISS is available
+        index = None
+        faiss_index_path = os.path.join(path, "faiss_index.bin")
+        if config["has_faiss_index"] and os.path.exists(faiss_index_path):
+            try:
+                import faiss
+            except ImportError:
+                # If FAISS is not available, fall back to brute force
+                pass
+            else:
+                index = faiss.read_index(faiss_index_path)
+
+        # Restore state only after every piece loaded successfully
+        self.k = config["k"]
+        self.normalize = config["normalize"]
+        self.corpus = corpus
+        self.embedder = embedder
+        self.corpus_embeddings = corpus_embeddings
+        self.index = index
+
+        return self
+
+    @classmethod
+    def from_saved(cls, path: str, embedder):
+        """
+        Create an Embeddings instance from a saved index.
+
+        This is the recommended way to load saved embeddings as it creates a new
+        instance without unnecessarily computing embeddings.
+
+        Args:
+            path: Directory path where the embeddings were saved
+            embedder: The embedder function to use for new queries
+
+        Returns:
+            Embeddings instance loaded from disk
+
+        Raises:
+            FileNotFoundError: If the save directory or required files don't exist
+            ValueError: If the saved config is invalid or incompatible
+
+        Example:
+            ```python
+            # Save embeddings
+            embeddings = Embeddings(corpus, embedder)
+            embeddings.save("./saved_embeddings")
+
+            # Load embeddings later
+            loaded_embeddings = Embeddings.from_saved("./saved_embeddings", embedder)
+            ```
+        """
+        # Create a minimal instance without triggering embedding computation
+        instance = cls.__new__(cls)
+        # Initialize the search function (required since we bypassed __init__)
+        instance.search_fn = Unbatchify(instance._batch_forward)
+        instance.load(path, embedder)
+        return instance
diff --git a/tests/retrievers/test_embeddings.py b/tests/retrievers/test_embeddings.py
index 588dfa3a58..fe1996562b 100644
--- a/tests/retrievers/test_embeddings.py
+++ b/tests/retrievers/test_embeddings.py
@@ -1,6 +1,9 @@
+import os
+import tempfile
 from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
+import pytest
 
 from dspy.retrievers.embeddings import Embeddings
 
@@ -70,3 +73,62 @@ def worker(query_text, expected_passage):
         assert results[0] == "The cat sat on the mat."
         assert results[1] == "The dog barked at the mailman."
         assert results[2] == "Birds fly in the sky."
+
+
+def test_embeddings_save_load():
+    """save() followed by load() on an existing instance restores config and search results."""
+    corpus = dummy_corpus()
+    embedder = dummy_embedder
+
+    original_retriever = Embeddings(corpus=corpus, embedder=embedder, k=2, normalize=False, brute_force_threshold=1000)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        save_path = os.path.join(temp_dir, "test_embeddings")
+        original_retriever.save(save_path)
+
+        # Verify files were created
+        assert os.path.exists(os.path.join(save_path, "config.json"))
+        assert os.path.exists(os.path.join(save_path, "corpus_embeddings.npy"))
+        assert not os.path.exists(os.path.join(save_path, "faiss_index.bin"))  # No FAISS for small corpus
+
+        # Load into an instance that was built with a different corpus and settings
+        new_retriever = Embeddings(corpus=["dummy"], embedder=embedder, k=1, normalize=True, brute_force_threshold=500)
+        new_retriever.load(save_path, embedder)
+
+        # Verify configuration was loaded correctly
+        assert new_retriever.corpus == corpus
+        assert new_retriever.k == 2
+        assert new_retriever.normalize is False
+        assert new_retriever.embedder == embedder
+        assert new_retriever.index is None
+
+        # Verify search results are preserved
+        query = "cat sitting"
+        original_result = original_retriever(query)
+        loaded_result = new_retriever(query)
+        assert loaded_result.passages == original_result.passages
+        assert loaded_result.indices == original_result.indices
+
+
+def test_embeddings_from_saved():
+    """from_saved() bypasses __init__ but must still return a fully usable retriever."""
+    corpus = dummy_corpus()
+    embedder = dummy_embedder
+    original_retriever = Embeddings(corpus=corpus, embedder=embedder, k=3, normalize=True, brute_force_threshold=1000)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        save_path = os.path.join(temp_dir, "test_embeddings")
+        original_retriever.save(save_path)
+        loaded_retriever = Embeddings.from_saved(save_path, embedder)
+
+        assert loaded_retriever.k == original_retriever.k
+        assert loaded_retriever.normalize == original_retriever.normalize
+        assert loaded_retriever.corpus == original_retriever.corpus
+        # The search function is wired up manually in from_saved, so exercise it too
+        assert loaded_retriever("cat sitting").passages == original_retriever("cat sitting").passages
+
+
+def test_embeddings_load_nonexistent_path():
+    """Loading from a directory that does not exist raises FileNotFoundError."""
+    with pytest.raises(FileNotFoundError):
+        Embeddings.from_saved("/nonexistent/path", dummy_embedder)