diff --git a/tests/test_dynamo.py b/tests/test_dynamo.py index d7bfaadd21a..c050a1afb24 100644 --- a/tests/test_dynamo.py +++ b/tests/test_dynamo.py @@ -122,6 +122,60 @@ def test_batch_get(self): 'z': None, }) + def test_no_memory_sharing_direct(self): + """Ensure that changes to objects retrieved from the database do not leak into other operations.""" + self.table.put({'id': 'key', 'x': 1}) + + # WHEN + retrieved = self.table.get({'id': 'key'}) + retrieved['x'] = 666 + + # THEN + retrieved2 = self.table.get({'id': 'key'}) + self.assertEqual(retrieved2['x'], 1) + + def test_no_memory_sharing_sublists(self): + """Ensure that changes to objects retrieved from the database do not leak into other operations.""" + # GIVEN + self.table.put(dict( + id='with_list', + sublist=[1] + )) + + # WHEN + retrieved = self.table.get(dict(id='with_list')) + retrieved['sublist'].append(2) + + # THEN + retrieved2 = self.table.get(dict(id='with_list')) + self.assertEqual(retrieved2['sublist'], [1]) + + def test_no_memory_sharing_insert_direct(self): + """Ensure that changes to objects we insert aren't accidentally seen by other viewers.""" + # GIVEN + obj = dict(id='lookatme', x=0) + + # WHEN + self.table.put(obj) + obj['x'] = 1 + + # THEN + obj2 = self.table.get(dict(id='lookatme')) + self.assertEqual(obj2['x'], 0) + + def test_no_memory_sharing_insert_sublist(self): + """Ensure that changes to objects we insert aren't accidentally seen by other viewers.""" + # GIVEN + obj = dict(id='lookatme', x=[0]) + + # WHEN + self.table.put(obj) + obj['x'].append(1) + + # THEN + obj2 = self.table.get(dict(id='lookatme')) + self.assertEqual(obj2['x'], [0]) + class TestSortKeysInMemory(unittest.TestCase): """Test that the operations work on an in-memory table with a sort key.""" diff --git a/website/dynamo.py b/website/dynamo.py index 1b965f8ea41..3452b603418 100644 --- a/website/dynamo.py +++ b/website/dynamo.py @@ -777,10 +777,10 @@ def before_or_equal(key0, key1): next_page_key = with_keys[-1][0] # Do a final filtering to mimic DynamoDB FilterExpression - return copy.copy([record - for _, record in with_keys - if self._query_matches(record, filter_eq_conditions, filter_special_conditions) - ]), next_page_key + return copy.deepcopy([record + for _, record in with_keys + if self._query_matches(record, filter_eq_conditions, filter_special_conditions) + ]), next_page_key # NOTE: on purpose not @synchronized here def query_index(self, table_name, index_name, keys, sort_key, reverse=False, limit=None, pagination_token=None, @@ -810,9 +810,9 @@ def put(self, table_name, key, data): records = self.tables.setdefault(table_name, []) index = self._find_index(records, key) if index is None: - records.append(copy.copy(data)) + records.append(copy.deepcopy(data)) else: - records[index] = copy.copy(data) + records[index] = copy.deepcopy(data) self._flush() @lock.synchronized @@ -888,7 +888,7 @@ def scan(self, table_name, limit, pagination_token): next_page_token = {"offset": start_index + limit} items = items[:limit] - items = copy.copy(items) + items = copy.deepcopy(items) return items, next_page_token def _find_index(self, records, key):