Skip to content

Commit

Permalink
feat: InMemoryDocumentStore serialization (#7888)
Browse files Browse the repository at this point in the history
* Add: InMemoryDocumentStore serialization

* Add: additional chek to test if path exists

* Fix: failing test
  • Loading branch information
davidberenstein1957 committed Jun 21, 2024
1 parent 9c45203 commit 08104e0
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
38 changes: 38 additions & 0 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
import math
import re
import uuid
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -339,6 +341,42 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
"""
return default_from_dict(cls, data)

def save_to_disk(self, path: str) -> None:
"""
Write the database and its' data to disk as a JSON file.
:param path: The path to the JSON file.
"""
data: Dict[str, Any] = self.to_dict()
data["documents"] = [doc.to_dict(flatten=False) for doc in self.storage.values()]
with open(path, "w") as f:
json.dump(data, f)

@classmethod
def load_from_disk(cls, path: str) -> "InMemoryDocumentStore":
"""
Load the database and its' data from disk as a JSON file.
:param path: The path to the JSON file.
:returns: The loaded InMemoryDocumentStore.
"""
if Path(path).exists():
try:
with open(path, "r") as f:
data = json.load(f)
except Exception as e:
raise Exception(f"Error loading InMemoryDocumentStore from disk. error: {e}")

documents = data.pop("documents")
cls_object = default_from_dict(cls, data)
cls_object.write_documents(
documents=[Document(**doc) for doc in documents], policy=DuplicatePolicy.OVERWRITE
)
return cls_object

else:
raise FileNotFoundError(f"File {path} not found.")

def count_documents(self) -> int:
"""
Returns the number of how many documents are present in the DocumentStore.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Added serialization methods save_to_disk and write_to_disk to InMemoryDocumentStore.
18 changes: 18 additions & 0 deletions test/document_stores/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pandas as pd
import pytest
import tempfile

from haystack import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
Expand All @@ -18,6 +19,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
Test InMemoryDocumentStore's specific features
"""

@pytest.fixture
def tmp_dir(self):
with tempfile.TemporaryDirectory() as tmp_dir:
yield tmp_dir

@pytest.fixture
def document_store(self) -> InMemoryDocumentStore:
return InMemoryDocumentStore(bm25_algorithm="BM25L")
Expand Down Expand Up @@ -74,6 +80,18 @@ def test_from_dict(self, mock_regex):
assert store.bm25_parameters == {"key": "value"}
assert store.index == "my_cool_index"

def test_save_to_disk_and_load_from_disk(self, tmp_dir: str):
docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
document_store = InMemoryDocumentStore()
document_store.write_documents(docs)
tmp_dir = tmp_dir + "/document_store.json"
document_store.save_to_disk(tmp_dir)
document_store_loaded = InMemoryDocumentStore.load_from_disk(tmp_dir)

assert document_store_loaded.count_documents() == 2
assert list(document_store_loaded.storage.values()) == docs
assert document_store_loaded.to_dict() == document_store.to_dict()

def test_invalid_bm25_algorithm(self):
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
InMemoryDocumentStore(bm25_algorithm="invalid")
Expand Down

0 comments on commit 08104e0

Please sign in to comment.