10
10
from .session_timeseries import TimeSeries
11
11
from .session_counters import DocumentCounters
12
12
from typing import Dict , List
13
+ from collections import MutableSet
13
14
14
15
15
16
class _SaveChangesData (object ):
@@ -19,6 +20,129 @@ def __init__(self, commands, deferred_command_count, entities=None):
19
20
self .deferred_command_count = deferred_command_count
20
21
21
22
23
+ class _RefEq :
24
+ def __init__ (self , ref ):
25
+ if isinstance (ref , _RefEq ):
26
+ self .ref = ref .ref
27
+ return
28
+ self .ref = ref
29
+
30
+ # As we split the hashable and unhashable items into separate collections, we only compare _RefEq to other _RefEq
31
+ def __eq__ (self , other ):
32
+ if isinstance (other , _RefEq ):
33
+ return id (self .ref ) == id (other .ref )
34
+ raise TypeError ("Expected _RefEq type object" )
35
+
36
+ def __hash__ (self ):
37
+ return id (self .ref )
38
+
39
+
40
+ class _RefEqEntityHolder (object ):
41
+ def __init__ (self ):
42
+ self .unhashable_items = dict ()
43
+
44
+ def __len__ (self ):
45
+ return len (self .unhashable_items )
46
+
47
+ def __contains__ (self , item ):
48
+ return _RefEq (item ) in self .unhashable_items
49
+
50
+ def __delitem__ (self , key ):
51
+ del self .unhashable_items [_RefEq (key )]
52
+
53
+ def __setitem__ (self , key , value ):
54
+ self .unhashable_items [_RefEq (key )] = value
55
+
56
+ def __getitem__ (self , key ):
57
+ return self .unhashable_items [_RefEq (key )]
58
+
59
+ def __getattribute__ (self , item ):
60
+ if item == "unhashable_items" :
61
+ return super ().__getattribute__ (item )
62
+ return self .unhashable_items .__getattribute__ (item )
63
+
64
+
65
+ class _DocumentsByEntityHolder (object ):
66
+ def __init__ (self ):
67
+ self ._hashable_items = dict ()
68
+ self ._unhashable_items = _RefEqEntityHolder ()
69
+
70
+ def __repr__ (self ):
71
+ return f"{ self .__class__ .__name__ } : { [item for item in self .__iter__ ()]} "
72
+
73
+ def __len__ (self ):
74
+ return len (self ._hashable_items ) + len (self ._unhashable_items )
75
+
76
+ def __contains__ (self , item ):
77
+ try :
78
+ return item in self ._hashable_items
79
+ except TypeError as e :
80
+ if str (e .args [0 ]).startswith ("unhashable type" ):
81
+ return item in self ._unhashable_items
82
+ raise e
83
+
84
+ def __setitem__ (self , key , value ):
85
+ try :
86
+ self ._hashable_items [key ] = value
87
+ except TypeError as e :
88
+ if str (e .args [0 ]).startswith ("unhashable type" ):
89
+ self ._unhashable_items [key ] = value
90
+ return
91
+ raise e
92
+
93
+ def __getitem__ (self , key ):
94
+ try :
95
+ return self ._hashable_items [key ]
96
+ except (TypeError , KeyError ):
97
+ return self ._unhashable_items [key ]
98
+
99
+ def __iter__ (self ):
100
+ d = list (map (lambda x : x .ref , self ._unhashable_items .keys ()))
101
+ if len (self ._hashable_items ) > 0 :
102
+ d .extend (self ._hashable_items .keys ())
103
+ return (item for item in d )
104
+
105
+ def get (self , key , default = None ):
106
+ return self [key ] if key in self else default
107
+
108
+ def pop (self , key , default_value = None ):
109
+ result = self ._hashable_items .pop (key , None )
110
+ if result is not None :
111
+ return result
112
+ return self ._unhashable_items .pop (_RefEq (key ), default_value )
113
+
114
+ def clear (self ):
115
+ self ._hashable_items .clear ()
116
+ self ._unhashable_items .clear ()
117
+
118
+
119
+ class _DeletedEntitiesHolder (MutableSet ):
120
+ def __init__ (self , items = None ):
121
+ if items is None :
122
+ items = []
123
+ self .items = set (map (_RefEq , items ))
124
+
125
+ def __getattribute__ (self , item ):
126
+ if item in ["add" , "discard" , "items" ]:
127
+ return super ().__getattribute__ (item )
128
+ return self .items .__getattribute__ (item )
129
+
130
+ def __contains__ (self , item : object ) -> bool :
131
+ return _RefEq (item ) in self .items
132
+
133
+ def __len__ (self ) -> int :
134
+ return len (self .items )
135
+
136
+ def __iter__ (self ):
137
+ return (item .ref for item in self .items )
138
+
139
+ def add (self , element : object ) -> None :
140
+ return self .items .add (_RefEq (element ))
141
+
142
+ def discard (self , element : object ) -> None :
143
+ return self .items .discard (_RefEq (element ))
144
+
145
+
22
146
class DocumentSession (object ):
23
147
def __init__ (self , database , document_store , requests_executor , session_id , ** kwargs ):
24
148
"""
@@ -33,8 +157,8 @@ def __init__(self, database, document_store, requests_executor, session_id, **kw
33
157
self ._requests_executor = requests_executor
34
158
self ._documents_by_id = {}
35
159
self ._included_documents_by_id = {}
36
- self ._deleted_entities = set ()
37
- self ._documents_by_entity = {}
160
+ self ._deleted_entities = _DeletedEntitiesHolder ()
161
+ self ._documents_by_entity = _DocumentsByEntityHolder ()
38
162
self ._timeseries_defer_commands = {}
39
163
self ._time_series_by_document_id = {}
40
164
self ._counters_defer_commands = {}
0 commit comments