Skip to content

Commit 1484398

Browse files
authored
Merge pull request #101 from poissoncorp/RDBC-471-5.0
RDBC-471 Add support for storing @DataClass() objects (for 5.0)
2 parents 9233bc5 + e36398a commit 1484398

File tree

5 files changed

+216
-2
lines changed

5 files changed

+216
-2
lines changed

pyravendb/store/document_session.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .session_timeseries import TimeSeries
1111
from .session_counters import DocumentCounters
1212
from typing import Dict, List
13+
from collections import MutableSet
1314

1415

1516
class _SaveChangesData(object):
@@ -19,6 +20,129 @@ def __init__(self, commands, deferred_command_count, entities=None):
1920
self.deferred_command_count = deferred_command_count
2021

2122

23+
class _RefEq:
24+
def __init__(self, ref):
25+
if isinstance(ref, _RefEq):
26+
self.ref = ref.ref
27+
return
28+
self.ref = ref
29+
30+
# As we split the hashable and unhashable items into separate collections, we only compare _RefEq to other _RefEq
31+
def __eq__(self, other):
32+
if isinstance(other, _RefEq):
33+
return id(self.ref) == id(other.ref)
34+
raise TypeError("Expected _RefEq type object")
35+
36+
def __hash__(self):
37+
return id(self.ref)
38+
39+
40+
class _RefEqEntityHolder(object):
41+
def __init__(self):
42+
self.unhashable_items = dict()
43+
44+
def __len__(self):
45+
return len(self.unhashable_items)
46+
47+
def __contains__(self, item):
48+
return _RefEq(item) in self.unhashable_items
49+
50+
def __delitem__(self, key):
51+
del self.unhashable_items[_RefEq(key)]
52+
53+
def __setitem__(self, key, value):
54+
self.unhashable_items[_RefEq(key)] = value
55+
56+
def __getitem__(self, key):
57+
return self.unhashable_items[_RefEq(key)]
58+
59+
def __getattribute__(self, item):
60+
if item == "unhashable_items":
61+
return super().__getattribute__(item)
62+
return self.unhashable_items.__getattribute__(item)
63+
64+
65+
class _DocumentsByEntityHolder(object):
66+
def __init__(self):
67+
self._hashable_items = dict()
68+
self._unhashable_items = _RefEqEntityHolder()
69+
70+
def __repr__(self):
71+
return f"{self.__class__.__name__}: {[item for item in self.__iter__()]}"
72+
73+
def __len__(self):
74+
return len(self._hashable_items) + len(self._unhashable_items)
75+
76+
def __contains__(self, item):
77+
try:
78+
return item in self._hashable_items
79+
except TypeError as e:
80+
if str(e.args[0]).startswith("unhashable type"):
81+
return item in self._unhashable_items
82+
raise e
83+
84+
def __setitem__(self, key, value):
85+
try:
86+
self._hashable_items[key] = value
87+
except TypeError as e:
88+
if str(e.args[0]).startswith("unhashable type"):
89+
self._unhashable_items[key] = value
90+
return
91+
raise e
92+
93+
def __getitem__(self, key):
94+
try:
95+
return self._hashable_items[key]
96+
except (TypeError, KeyError):
97+
return self._unhashable_items[key]
98+
99+
def __iter__(self):
100+
d = list(map(lambda x: x.ref, self._unhashable_items.keys()))
101+
if len(self._hashable_items) > 0:
102+
d.extend(self._hashable_items.keys())
103+
return (item for item in d)
104+
105+
def get(self, key, default=None):
106+
return self[key] if key in self else default
107+
108+
def pop(self, key, default_value=None):
109+
result = self._hashable_items.pop(key, None)
110+
if result is not None:
111+
return result
112+
return self._unhashable_items.pop(_RefEq(key), default_value)
113+
114+
def clear(self):
115+
self._hashable_items.clear()
116+
self._unhashable_items.clear()
117+
118+
119+
class _DeletedEntitiesHolder(MutableSet):
120+
def __init__(self, items=None):
121+
if items is None:
122+
items = []
123+
self.items = set(map(_RefEq, items))
124+
125+
def __getattribute__(self, item):
126+
if item in ["add", "discard", "items"]:
127+
return super().__getattribute__(item)
128+
return self.items.__getattribute__(item)
129+
130+
def __contains__(self, item: object) -> bool:
131+
return _RefEq(item) in self.items
132+
133+
def __len__(self) -> int:
134+
return len(self.items)
135+
136+
def __iter__(self):
137+
return (item.ref for item in self.items)
138+
139+
def add(self, element: object) -> None:
140+
return self.items.add(_RefEq(element))
141+
142+
def discard(self, element: object) -> None:
143+
return self.items.discard(_RefEq(element))
144+
145+
22146
class DocumentSession(object):
23147
def __init__(self, database, document_store, requests_executor, session_id, **kwargs):
24148
"""
@@ -33,8 +157,8 @@ def __init__(self, database, document_store, requests_executor, session_id, **kw
33157
self._requests_executor = requests_executor
34158
self._documents_by_id = {}
35159
self._included_documents_by_id = {}
36-
self._deleted_entities = set()
37-
self._documents_by_entity = {}
160+
self._deleted_entities = _DeletedEntitiesHolder()
161+
self._documents_by_entity = _DocumentsByEntityHolder()
38162
self._timeseries_defer_commands = {}
39163
self._time_series_by_document_id = {}
40164
self._counters_defer_commands = {}

pyravendb/tests/session_tests/test_delete.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,19 @@
33

44
sys.path.append(os.path.abspath(__file__ + "/../"))
55

6+
from dataclasses import dataclass
67
from pyravendb.tests.test_base import TestBase
78
from pyravendb.custom_exceptions import exceptions
89
import unittest
910

1011

12+
@dataclass()
13+
class Fish:
14+
name: str
15+
weight: int
16+
Id: str
17+
18+
1119
class Product(object):
1220
def __init__(self, Id=None, name=None):
1321
self.Id = Id
@@ -68,6 +76,19 @@ def test_delete_with_entity(self):
6876

6977
self.assertIsNone(session.load(product.Id))
7078

79+
def test_delete_dataclass(self):
80+
fishie = Fish(name="Tuna", weight=100, Id=None)
81+
with self.store.open_session() as session:
82+
session.store(fishie)
83+
session.save_changes()
84+
85+
with self.store.open_session() as session:
86+
fishie = session.load(fishie.Id)
87+
session.delete(fishie)
88+
session.save_changes()
89+
90+
self.assertIsNone(session.load(fishie.Id))
91+
7192

7293
if __name__ == "__main__":
7394
unittest.main()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from pyravendb.tests.test_base import TestBase
2+
from pyravendb.store.document_session import _RefEq
3+
from dataclasses import dataclass
4+
5+
6+
@dataclass()
7+
class TestData:
8+
entry: str
9+
10+
11+
class TestDocumentsByEntity(TestBase):
12+
def setUp(self):
13+
super(TestDocumentsByEntity, self).setUp()
14+
15+
def test_ref_eq(self):
16+
data = TestData(entry="classified")
17+
18+
wrapped_data_1 = _RefEq(data)
19+
wrapped_data_2 = _RefEq(wrapped_data_1)
20+
21+
self.assertTrue(wrapped_data_2 == wrapped_data_1)
22+
with self.assertRaises(TypeError):
23+
wrapped_data_1 == data

pyravendb/tests/session_tests/test_load.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
from pyravendb.commands.raven_commands import PutDocumentCommand
22
from pyravendb.tests.test_base import TestBase
33
from pyravendb.custom_exceptions import exceptions
4+
from dataclasses import dataclass
45
import unittest
56

67

8+
@dataclass()
9+
class OrderData(object):
10+
Id: str
11+
name: str
12+
key: str
13+
product_id: str
14+
15+
16+
@dataclass()
17+
class ProductData(object):
18+
Id: str
19+
name: str
20+
21+
722
class Product(object):
823
def __init__(self, Id=None, name=""):
924
self.Id = Id
@@ -116,6 +131,18 @@ def test_load_with_include(self):
116131
session.load("products/101")
117132
self.assertEqual(session.number_of_requests_in_session, 1)
118133

134+
def test_load_with_include_dataclass(self):
135+
with self.store.open_session() as session:
136+
session.store(OrderData("orderd/1", "some_order", "some_key", "productd/1"))
137+
session.store(ProductData("productd/1", "some_product"))
138+
session.save_changes()
139+
140+
with self.store.open_session() as session:
141+
session.load("orderd/1", includes="product_id")
142+
product = session.load("productd/1")
143+
self.assertEqual(1, session.number_of_requests_in_session)
144+
self.assertEqual("some_product", product.name)
145+
119146

120147
if __name__ == "__main__":
121148
unittest.main()

pyravendb/tests/session_tests/test_store_entities.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from pyravendb.tests.test_base import TestBase
22
from pyravendb.store.document_store import DocumentStore
33
from pyravendb.custom_exceptions import exceptions
4+
from dataclasses import dataclass
45
import unittest
56

67

8+
@dataclass()
9+
class Fish:
10+
name: str
11+
weight: int
12+
13+
714
class Foo(object):
815
def __init__(self, name, key):
916
self.name = name
@@ -91,6 +98,18 @@ def test_store_the_same_documents_should_work(self):
9198
results = list(session.query().raw_query("From Foos"))
9299
self.assertEqual(len(results), 40 * 4)
93100

101+
def test_store_dataclass(self):
102+
with self.store.open_session() as session:
103+
fishie = Fish(name="Tuna", weight=100)
104+
session.store(fishie, "fish/1")
105+
session.save_changes()
106+
107+
with self.store.open_session() as session:
108+
fishie = session.load("fish/1")
109+
self.assertIsNotNone(fishie)
110+
self.assertEqual(100, fishie.weight)
111+
self.assertEqual("Tuna", fishie.name)
112+
94113

95114
if __name__ == "__main__":
96115
unittest.main()

0 commit comments

Comments
 (0)