From 08104e0042cff827231e30a37f53cac4d60cbc02 Mon Sep 17 00:00:00 2001
From: David Berenstein
Date: Fri, 21 Jun 2024 16:45:25 +0200
Subject: [PATCH] feat: InMemoryDocumentStore serialization (#7888)

* Add: InMemoryDocumentStore serialization

* Add: additional check to test if path exists

* Fix: failing test
---
 .../in_memory/document_store.py               | 41 +++++++++++++++++++
 ...nmemorydocumentstore-2aa4d9ac85b961c5.yaml |  4 ++
 test/document_stores/test_in_memory.py        | 18 ++++++++
 3 files changed, 63 insertions(+)
 create mode 100644 releasenotes/notes/add-serialization-to-inmemorydocumentstore-2aa4d9ac85b961c5.yaml

diff --git a/haystack/document_stores/in_memory/document_store.py b/haystack/document_stores/in_memory/document_store.py
index 5ff6cb1fe6..4fd10e1cd1 100644
--- a/haystack/document_stores/in_memory/document_store.py
+++ b/haystack/document_stores/in_memory/document_store.py
@@ -2,11 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import math
 import re
 import uuid
 from collections import Counter
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
 
 import numpy as np
@@ -339,6 +341,45 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
         """
         return default_from_dict(cls, data)
 
+    def save_to_disk(self, path: str) -> None:
+        """
+        Write the database and its data to disk as a JSON file.
+
+        :param path: The path to the JSON file.
+        """
+        data: Dict[str, Any] = self.to_dict()
+        data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
+        # Explicit encoding: do not depend on the platform's default locale encoding.
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(data, f)
+
+    @classmethod
+    def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
+        """
+        Load the database and its data from a JSON file on disk.
+
+        :param path: The path to the JSON file.
+        :returns: The loaded InMemoryDocumentStore.
+        :raises FileNotFoundError: If `path` does not exist.
+        :raises Exception: If the file cannot be opened or parsed as JSON.
+        """
+        if not Path(path).exists():
+            raise FileNotFoundError(f"File {path} not found.")
+
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except Exception as e:
+            # Chain the cause so the underlying I/O or JSON error stays in the traceback.
+            raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}") from e
+
+        documents = data.pop("documents")
+        cls_object = default_from_dict(cls, data)
+        cls_object.write_documents(
+            documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
+        )
+        return cls_object
+
     def count_documents(self) -> int:
         """
         Returns the number of how many documents are present in the DocumentStore.
diff --git a/releasenotes/notes/add-serialization-to-inmemorydocumentstore-2aa4d9ac85b961c5.yaml b/releasenotes/notes/add-serialization-to-inmemorydocumentstore-2aa4d9ac85b961c5.yaml
new file mode 100644
index 0000000000..48e3c8e427
--- /dev/null
+++ b/releasenotes/notes/add-serialization-to-inmemorydocumentstore-2aa4d9ac85b961c5.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Added serialization methods save_to_disk and load_from_disk to InMemoryDocumentStore.
diff --git a/test/document_stores/test_in_memory.py b/test/document_stores/test_in_memory.py
index 2a8679502b..8b8ed0e5fa 100644
--- a/test/document_stores/test_in_memory.py
+++ b/test/document_stores/test_in_memory.py
@@ -6,6 +6,7 @@
 
 import pandas as pd
 import pytest
+import tempfile
 
 from haystack import Document
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
@@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
     Test InMemoryDocumentStore's specific features
     """
 
+    @pytest.fixture
+    def tmp_dir(self):
+        with tempfile.TemporaryDirectory() as dir_path:
+            yield dir_path
+
     @pytest.fixture
     def document_store(self) -> InMemoryDocumentStore:
         return InMemoryDocumentStore(bm25_algorithm="BM25L")
@@ -74,6 +80,18 @@ def test_from_dict(self, mock_regex):
         assert store.bm25_parameters == {"key": "value"}
         assert store.index == "my_cool_index"
 
+    def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
+        """A store written with save_to_disk is restored unchanged by load_from_disk."""
+        expected_docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
+        original_store = InMemoryDocumentStore()
+        original_store.write_documents(expected_docs)
+        file_path = tmp_dir + "/document_store.json"
+        original_store.save_to_disk(file_path)
+        restored_store = InMemoryDocumentStore.load_from_disk(file_path)
+        assert restored_store.count_documents() == 2
+        assert list(restored_store.storage.values()) == expected_docs
+        assert restored_store.to_dict() == original_store.to_dict()
+
     def test_invalid_bm25_algorithm(self):
         with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
             InMemoryDocumentStore(bm25_algorithm="invalid")