From 97d245d3e8cdf7522bceb17db815448e623962ce Mon Sep 17 00:00:00 2001
From: Rico Hermans
Date: Tue, 2 Apr 2024 18:44:34 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=B2=20Fix=20accidental=20object=20shar?=
 =?UTF-8?q?ing=20in=20Dynamo=20layer=20(#5353)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

As discovered in #5350 by @jpelay, the Dynamo emulation layer has a bug
that can lead to accidental data sharing between different requests,
because the DB layer hands out references to long-lived objects. If a
reader mutates those objects in place, the mutations also end up in the
stored in-memory copy of the database.

How I did this:

- Write the tests first to encode the contract that mutations to
  retrieved (or inserted) objects shouldn't leak into objects retrieved
  by other queries.
- Replace `copy.copy()` with `copy.deepcopy()` everywhere 😳 oops!

**How to test**

Follow the steps in #5350.
---
 tests/test_dynamo.py | 54 ++++++++++++++++++++++++++++++++++++++++++++
 website/dynamo.py    | 14 ++++++------
 2 files changed, 61 insertions(+), 7 deletions(-)

diff --git a/tests/test_dynamo.py b/tests/test_dynamo.py
index d7bfaadd21a..c050a1afb24 100644
--- a/tests/test_dynamo.py
+++ b/tests/test_dynamo.py
@@ -122,6 +122,60 @@ def test_batch_get(self):
             'z': None,
         })

+    def test_no_memory_sharing_direct(self):
+        """Ensure that changes to objects retrieved from the database do not leak into other operations."""
+        self.table.put({'id': 'key', 'x': 1})
+
+        # WHEN
+        retrieved = self.table.get({'id': 'key'})
+        retrieved['x'] = 666
+
+        # THEN
+        retrieved2 = self.table.get({'id': 'key'})
+        self.assertEqual(retrieved2['x'], 1)
+
+    def test_no_memory_sharing_sublists(self):
+        """Ensure that changes to objects retrieved from the database do not leak into other operations."""
+        # GIVEN
+        self.table.put(dict(
+            id='with_list',
+            sublist=[1]
+        ))
+
+        # WHEN
+        retrieved = self.table.get(dict(id='with_list'))
+        retrieved['sublist'].append(2)
+
+        # THEN
+        retrieved2 = self.table.get(dict(id='with_list'))
+        self.assertEqual(retrieved2['sublist'], [1])
+
+    def test_no_memory_sharing_insert_direct(self):
+        """Ensure that changes to objects we insert aren't accidentally seen by other viewers."""
+        # GIVEN
+        obj = dict(id='lookatme', x=0)
+
+        # WHEN
+        self.table.put(obj)
+        obj['x'] = 1
+
+        # THEN
+        obj2 = self.table.get(dict(id='lookatme'))
+        self.assertEqual(obj2['x'], 0)
+
+    def test_no_memory_sharing_insert_sublist(self):
+        """Ensure that changes to objects we insert aren't accidentally seen by other viewers."""
+        # GIVEN
+        obj = dict(id='lookatme', x=[0])
+
+        # WHEN
+        self.table.put(obj)
+        obj['x'].append(1)
+
+        # THEN
+        obj2 = self.table.get(dict(id='lookatme'))
+        self.assertEqual(obj2['x'], [0])
+

 class TestSortKeysInMemory(unittest.TestCase):
     """Test that the operations work on an in-memory table with a sort key."""
diff --git a/website/dynamo.py b/website/dynamo.py
index 1b965f8ea41..3452b603418 100644
--- a/website/dynamo.py
+++ b/website/dynamo.py
@@ -777,10 +777,10 @@ def before_or_equal(key0, key1):
             next_page_key = with_keys[-1][0]

         # Do a final filtering to mimic DynamoDB FilterExpression
-        return copy.copy([record
-                          for _, record in with_keys
-                          if self._query_matches(record, filter_eq_conditions, filter_special_conditions)
-                          ]), next_page_key
+        return copy.deepcopy([record
+                              for _, record in with_keys
+                              if self._query_matches(record, filter_eq_conditions, filter_special_conditions)
+                              ]), next_page_key

     # NOTE: on purpose not @synchronized here
     def query_index(self, table_name, index_name, keys, sort_key, reverse=False, limit=None, pagination_token=None,
@@ -810,9 +810,9 @@ def put(self, table_name, key, data):
         records = self.tables.setdefault(table_name, [])
         index = self._find_index(records, key)
         if index is None:
-            records.append(copy.copy(data))
+            records.append(copy.deepcopy(data))
         else:
-            records[index] = copy.copy(data)
+            records[index] = copy.deepcopy(data)
         self._flush()

     @lock.synchronized
@@ -888,7 +888,7 @@ def scan(self, table_name, limit, pagination_token):
             next_page_token = {"offset": start_index + limit}
             items = items[:limit]

-        items = copy.copy(items)
+        items = copy.deepcopy(items)
         return items, next_page_token

     def _find_index(self, records, key):
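
For reference, the root cause is easy to reproduce outside the emulation layer. The sketch below is illustrative only (the `FakeTable` class is hypothetical, not Hedy code): it shows why handing out shallow `copy.copy()` copies still shares nested containers between the stored record and its readers, while `copy.deepcopy()` isolates them.

```python
import copy


class FakeTable:
    """Toy in-memory store: put() and get() hand out copies of the stored record."""

    def __init__(self, deep):
        self._records = {}
        self._copy = copy.deepcopy if deep else copy.copy

    def put(self, key, record):
        self._records[key] = self._copy(record)

    def get(self, key):
        return self._copy(self._records[key])


# Shallow copies: the top-level dict is copied, but 'sublist' is still the
# same list object as the one stored inside the table.
shallow = FakeTable(deep=False)
shallow.put('k', {'id': 'k', 'sublist': [1]})
shallow.get('k')['sublist'].append(2)   # reader mutates "its own" copy...
print(shallow.get('k')['sublist'])       # [1, 2] <- ...and the store sees it too

# Deep copies: the stored record is fully isolated from readers and writers.
deep = FakeTable(deep=True)
deep.put('k', {'id': 'k', 'sublist': [1]})
deep.get('k')['sublist'].append(2)
print(deep.get('k')['sublist'])          # [1]
```

The shallow variant is the leak exercised by `test_no_memory_sharing_sublists` above; the deep variant matches the behaviour after this patch.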