From 6bba980e02f37620030ba95a83f4e6ada1c8ae4f Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 17 Sep 2025 15:58:18 +0900 Subject: [PATCH 1/3] Add save/load to Embeddings --- dspy/retrievers/embeddings.py | 93 ++++++++++++++++++++++++++++- tests/retrievers/test_embeddings.py | 62 +++++++++++++++++++ 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/dspy/retrievers/embeddings.py b/dspy/retrievers/embeddings.py index 9b9b31be72..15a4dda498 100644 --- a/dspy/retrievers/embeddings.py +++ b/dspy/retrievers/embeddings.py @@ -1,11 +1,11 @@ +import os +import pickle from typing import Any import numpy as np from dspy.utils.unbatchify import Unbatchify -# TODO: Add .save and .load methods! - class Embeddings: def __init__( @@ -87,3 +87,92 @@ def _rerank_and_predict(self, q_embeds: np.ndarray, candidate_indices: np.ndarra def _normalize(self, embeddings: np.ndarray): norms = np.linalg.norm(embeddings, axis=1, keepdims=True) return embeddings / np.maximum(norms, 1e-10) + + def save(self, path: str): + """ + Save the embeddings index to disk. + + This saves the corpus, embeddings, FAISS index (if present), and configuration + to allow for fast loading without recomputing embeddings. + + Args: + path: Directory path where the embeddings will be saved + """ + os.makedirs(path, exist_ok=True) + + # Save configuration and corpus + config = { + "k": self.k, + "normalize": self.normalize, + "corpus": self.corpus, + "has_faiss_index": self.index is not None, + } + + with open(os.path.join(path, "config.pkl"), "wb") as f: + pickle.dump(config, f) + + # Save embeddings + np.save(os.path.join(path, "corpus_embeddings.npy"), self.corpus_embeddings) + + # Save FAISS index if it exists + if self.index is not None: + try: + import faiss + faiss.write_index(self.index, os.path.join(path, "faiss_index.bin")) + except ImportError: + # If FAISS is not available, we can't save the index + # but we can still save the embeddings for brute force search + pass + + def load(self, path: str, embedder): + """ + Load the embeddings index from disk. + + Args: + path: Directory path where the embeddings were saved + embedder: The embedder function to use for new queries + """ + # Load configuration and corpus + with open(os.path.join(path, "config.pkl"), "rb") as f: + config = pickle.load(f) + + # Restore configuration + self.k = config["k"] + self.normalize = config["normalize"] + self.corpus = config["corpus"] + self.embedder = embedder + + # Load embeddings + self.corpus_embeddings = np.load(os.path.join(path, "corpus_embeddings.npy")) + + # Load FAISS index if it was saved and FAISS is available + faiss_index_path = os.path.join(path, "faiss_index.bin") + if config["has_faiss_index"] and os.path.exists(faiss_index_path): + try: + import faiss + self.index = faiss.read_index(faiss_index_path) + except ImportError: + # If FAISS is not available, fall back to brute force + self.index = None + else: + self.index = None + + # Reinitialize the search function + self.search_fn = Unbatchify(self._batch_forward) + + @classmethod + def from_saved(cls, path: str, embedder): + """ + Create an Embeddings instance from a saved index. + + Args: + path: Directory path where the embeddings were saved + embedder: The embedder function to use for new queries + + Returns: + Embeddings instance loaded from disk + """ + # Create a minimal instance without triggering embedding computation + instance = cls.__new__(cls) + instance.load(path, embedder) + return instance diff --git a/tests/retrievers/test_embeddings.py b/tests/retrievers/test_embeddings.py index 588dfa3a58..f30c0c9c0e 100644 --- a/tests/retrievers/test_embeddings.py +++ b/tests/retrievers/test_embeddings.py @@ -1,6 +1,9 @@ +import os +import tempfile from concurrent.futures import ThreadPoolExecutor import numpy as np +import pytest from dspy.retrievers.embeddings import Embeddings @@ -70,3 +73,62 @@ def worker(query_text, expected_passage): assert results[0] == "The cat sat on the mat." assert results[1] == "The dog barked at the mailman." assert results[2] == "Birds fly in the sky." + + +def test_embeddings_save_load(): + corpus = dummy_corpus() + embedder = dummy_embedder + + original_retriever = Embeddings(corpus=corpus, embedder=embedder, k=2, normalize=False, brute_force_threshold=1000) + + with tempfile.TemporaryDirectory() as temp_dir: + save_path = os.path.join(temp_dir, "test_embeddings") + + # Save original + original_retriever.save(save_path) + + # Verify files were created + assert os.path.exists(os.path.join(save_path, "config.pkl")) + assert os.path.exists(os.path.join(save_path, "corpus_embeddings.npy")) + assert not os.path.exists(os.path.join(save_path, "faiss_index.bin")) # No FAISS for small corpus + + # Load into new instance + new_retriever = Embeddings(corpus=["dummy"], embedder=embedder, k=1, normalize=True, brute_force_threshold=500) + new_retriever.load(save_path, embedder) + + # Verify configuration was loaded correctly + assert new_retriever.corpus == corpus + assert new_retriever.k == 2 + assert new_retriever.normalize is False + assert new_retriever.embedder == embedder + assert new_retriever.index is None + + # Verify search results are preserved + query = "cat sitting" + original_result = original_retriever(query) + loaded_result = new_retriever(query) + assert loaded_result.passages == original_result.passages + assert loaded_result.indices == original_result.indices + + +def test_embeddings_from_saved(): + corpus = dummy_corpus() + embedder = dummy_embedder + + original_retriever = Embeddings(corpus=corpus, embedder=embedder, k=3, normalize=True, brute_force_threshold=1000) + + with tempfile.TemporaryDirectory() as temp_dir: + save_path = os.path.join(temp_dir, "test_embeddings") + + original_retriever.save(save_path) + loaded_retriever = Embeddings.from_saved(save_path, embedder) + + assert loaded_retriever.k == original_retriever.k + assert loaded_retriever.normalize == original_retriever.normalize + assert loaded_retriever.corpus == original_retriever.corpus + + + +def test_embeddings_load_nonexistent_path(): + with pytest.raises((FileNotFoundError, OSError)): + Embeddings.from_saved("/nonexistent/path", dummy_embedder) From 409e007cfc11feb658c8e97ba9bd95fa7fef71ea Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Sat, 20 Sep 2025 00:25:18 +0900 Subject: [PATCH 2/3] comment --- dspy/retrievers/embeddings.py | 55 ++++++++++++++++++++++++----- tests/retrievers/test_embeddings.py | 2 +- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/dspy/retrievers/embeddings.py b/dspy/retrievers/embeddings.py index 15a4dda498..27714bd75f 100644 --- a/dspy/retrievers/embeddings.py +++ b/dspy/retrievers/embeddings.py @@ -1,5 +1,5 @@ +import json import os -import pickle from typing import Any import numpy as np @@ -108,8 +108,8 @@ def save(self, path: str): "has_faiss_index": self.index is not None, } - with open(os.path.join(path, "config.pkl"), "wb") as f: - pickle.dump(config, f) + with open(os.path.join(path, "config.json"), "w") as f: + json.dump(config, f, indent=2) # Save embeddings np.save(os.path.join(path, "corpus_embeddings.npy"), self.corpus_embeddings) @@ -126,15 +126,39 @@ def save(self, path: str): def load(self, path: str, embedder): """ - Load the embeddings index from disk. + Load the embeddings index from disk into the current instance. Args: path: Directory path where the embeddings were saved embedder: The embedder function to use for new queries + + Returns: + self: Returns self for method chaining + + Raises: + FileNotFoundError: If the save directory or required files don't exist + ValueError: If the saved config is invalid or incompatible """ + if not os.path.exists(path): + raise FileNotFoundError(f"Save directory not found: {path}") + + config_path = os.path.join(path, "config.json") + embeddings_path = os.path.join(path, "corpus_embeddings.npy") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + if not os.path.exists(embeddings_path): + raise FileNotFoundError(f"Embeddings file not found: {embeddings_path}") + # Load configuration and corpus - with open(os.path.join(path, "config.pkl"), "rb") as f: - config = pickle.load(f) + with open(config_path) as f: + config = json.load(f) + + # Validate required config fields + required_fields = ["k", "normalize", "corpus", "has_faiss_index"] + for field in required_fields: + if field not in config: + raise ValueError(f"Invalid config: missing required field '{field}'") # Restore configuration self.k = config["k"] @@ -143,7 +167,7 @@ def load(self, path: str, embedder): self.embedder = embedder # Load embeddings - self.corpus_embeddings = np.load(os.path.join(path, "corpus_embeddings.npy")) + self.corpus_embeddings = np.load(embeddings_path) # Load FAISS index if it was saved and FAISS is available faiss_index_path = os.path.join(path, "faiss_index.bin") @@ -157,20 +181,35 @@ def load(self, path: str, embedder): else: self.index = None - # Reinitialize the search function + # Initialize the search function (required since we bypassed __init__) self.search_fn = Unbatchify(self._batch_forward) + return self + @classmethod def from_saved(cls, path: str, embedder): """ Create an Embeddings instance from a saved index. + This is the recommended way to load saved embeddings as it creates a new + instance without unnecessarily computing embeddings. + Args: path: Directory path where the embeddings were saved embedder: The embedder function to use for new queries Returns: Embeddings instance loaded from disk + + Example: + ```python + # Save embeddings + embeddings = Embeddings(corpus, embedder) + embeddings.save("./saved_embeddings") + + # Load embeddings later + loaded_embeddings = Embeddings.from_saved("./saved_embeddings", embedder) + ``` """ # Create a minimal instance without triggering embedding computation instance = cls.__new__(cls) diff --git a/tests/retrievers/test_embeddings.py b/tests/retrievers/test_embeddings.py index f30c0c9c0e..fe1996562b 100644 --- a/tests/retrievers/test_embeddings.py +++ b/tests/retrievers/test_embeddings.py @@ -88,7 +88,7 @@ def test_embeddings_save_load(): original_retriever.save(save_path) # Verify files were created - assert os.path.exists(os.path.join(save_path, "config.pkl")) + assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_path, "corpus_embeddings.npy")) assert not os.path.exists(os.path.join(save_path, "faiss_index.bin")) # No FAISS for small corpus From 9d795e1aa9ec58d431b753f6309b3bda667c6c70 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Wed, 24 Sep 2025 10:40:52 +0900 Subject: [PATCH 3/3] comment --- dspy/retrievers/embeddings.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dspy/retrievers/embeddings.py b/dspy/retrievers/embeddings.py index 27714bd75f..02da760078 100644 --- a/dspy/retrievers/embeddings.py +++ b/dspy/retrievers/embeddings.py @@ -181,9 +181,6 @@ def load(self, path: str, embedder): else: self.index = None - # Initialize the search function (required since we bypassed __init__) - self.search_fn = Unbatchify(self._batch_forward) - return self @classmethod @@ -213,5 +210,7 @@ def from_saved(cls, path: str, embedder): """ # Create a minimal instance without triggering embedding computation instance = cls.__new__(cls) + # Initialize the search function (required since we bypassed __init__) + instance.search_fn = Unbatchify(instance._batch_forward) instance.load(path, embedder) return instance