From 06bc04bdcf4453f96aa17e00ede6d554fcdbf25b Mon Sep 17 00:00:00 2001 From: davidlatwe Date: Tue, 27 Apr 2021 03:03:16 +0800 Subject: [PATCH 1/5] add tests --- tests/test_engine/test_update/test_update.py | 47 ++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_engine/test_update/test_update.py b/tests/test_engine/test_update/test_update.py index b5e8cad..aa9d8cc 100644 --- a/tests/test_engine/test_update/test_update.py +++ b/tests/test_engine/test_update/test_update.py @@ -249,6 +249,53 @@ def test_update_id2(monty_update, mongo_update): assert next(monty_c) == {"a": {"b": {"_id": 3}}} +def test_update_id3(monty_update, mongo_update): + docs = [] + spec = {"$set": {"_id": "some-id", "foo": "baby"}} + find = {"_id": "some-id"} + + monty_c = monty_update(docs, spec, find, upsert=True) + mongo_c = mongo_update(docs, spec, find, upsert=True) + + assert next(mongo_c) == next(monty_c) + + +def test_update_id4(monty_update, mongo_update): + docs = [] + spec = {"$set": {"_id": "some-id", "foo": "baby"}} + find = {"_id": "other-id"} + + with pytest.raises(mongo_write_err) as mongo_err: + mongo_update(docs, spec, find, upsert=True) + + with pytest.raises(monty_write_err) as monty_err: + monty_update(docs, spec, find, upsert=True) + + +def test_update_id5(monty_update, mongo_update): + docs = [{"_id": "other-id", "foo": "bar"}] + spec = {"$set": {"_id": "some-id", "foo": "baby"}} + find = {"foo": "bar"} + + with pytest.raises(mongo_write_err) as mongo_err: + mongo_update(docs, spec, find, upsert=True) + + with pytest.raises(monty_write_err) as monty_err: + monty_update(docs, spec, find, upsert=True) + + +def test_update_id6(monty_update, mongo_update): + docs = [{"_id": "some-id", "foo": "bar"}] + spec = {"$set": {"_id": "some-id", "foo": "baby"}} + find = {"foo": "bar"} + + with pytest.raises(mongo_write_err) as mongo_err: + mongo_update(docs, spec, find, upsert=True) + + with pytest.raises(monty_write_err) as monty_err: + monty_update(docs, spec, find, upsert=True) + + def test_update_positional(monty_update, mongo_update): docs = [ {"a": [{"b": 3, "c": 1}, {"b": 4, "c": 0}]} From d79cefee2fcfa2840fc4bd3e7f0c2863c03dd4d5 Mon Sep 17 00:00:00 2001 From: davidlatwe Date: Tue, 27 Apr 2021 03:04:44 +0800 Subject: [PATCH 2/5] wip --- montydb/collection.py | 7 +++++++ montydb/engine/update.py | 5 ----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/montydb/collection.py b/montydb/collection.py index f05a1af..69dae69 100644 --- a/montydb/collection.py +++ b/montydb/collection.py @@ -28,6 +28,7 @@ from .errors import ( DuplicateKeyError, BulkWriteError, + WriteError, ) from .results import (BulkWriteResult, @@ -290,6 +291,12 @@ def update_one(self, if upsert: self._internal_upsert(filter, updator, raw_result) else: + filter_id = filter.get("_id") + if filter_id and filter_id != fw.doc["_id"]: + msg = ("Performing an update on the path '_id' would " + "modify the immutable field '_id'") + raise WriteError(msg, code=66) + raw_result["n"] = 1 if updator(fw): self._storage.update_one(self, fw.doc) diff --git a/montydb/engine/update.py b/montydb/engine/update.py index 4aae79b..0b5266a 100644 --- a/montydb/engine/update.py +++ b/montydb/engine/update.py @@ -143,11 +143,6 @@ def parser(self, spec): raise WriteError(msg, code=9) for field, value in cmd_doc.items(): - if field == "_id": - msg = ("Performing an update on the path '_id' would " - "modify the immutable field '_id'") - raise WriteError(msg, code=66) - for top in list(idnt_tops): if "$[{}]".format(top) in field: idnt_tops.remove(top) From e0048c04d735a46ea3ad1d8175e21f152b9931f6 Mon Sep 17 00:00:00 2001 From: davidlatwe Date: Tue, 27 Apr 2021 05:47:43 +0800 Subject: [PATCH 3/5] compare error code in test --- tests/test_engine/test_update/test_update.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_engine/test_update/test_update.py b/tests/test_engine/test_update/test_update.py index aa9d8cc..7db36ed 100644 --- a/tests/test_engine/test_update/test_update.py +++ b/tests/test_engine/test_update/test_update.py @@ -231,8 +231,8 @@ def test_update_id(monty_update, mongo_update): with pytest.raises(monty_write_err) as monty_err: monty_update(docs, spec, find) - # ignore comparing error code - # assert mongo_err.value.code == monty_err.value.code + assert mongo_err.value.code == 66 + assert mongo_err.value.code == monty_err.value.code def test_update_id2(monty_update, mongo_update): @@ -271,6 +271,9 @@ def test_update_id4(monty_update, mongo_update): with pytest.raises(monty_write_err) as monty_err: monty_update(docs, spec, find, upsert=True) + assert mongo_err.value.code == 66 + assert mongo_err.value.code == monty_err.value.code + def test_update_id5(monty_update, mongo_update): docs = [{"_id": "other-id", "foo": "bar"}] @@ -283,6 +286,9 @@ def test_update_id5(monty_update, mongo_update): with pytest.raises(monty_write_err) as monty_err: monty_update(docs, spec, find, upsert=True) + assert mongo_err.value.code == 66 + assert mongo_err.value.code == monty_err.value.code + def test_update_id6(monty_update, mongo_update): docs = [{"_id": "some-id", "foo": "bar"}] @@ -295,6 +301,9 @@ def test_update_id6(monty_update, mongo_update): with pytest.raises(monty_write_err) as monty_err: monty_update(docs, spec, find, upsert=True) + assert mongo_err.value.code == 66 + assert mongo_err.value.code == monty_err.value.code + def test_update_positional(monty_update, mongo_update): docs = [ From cd9da70bdfac7a0037c3d67e3b533683ec723f3b Mon Sep 17 00:00:00 2001 From: davidlatwe Date: Tue, 27 Apr 2021 05:47:58 +0800 Subject: [PATCH 4/5] fix #20 --- montydb/collection.py | 18 +++++++++++++----- montydb/engine/update.py | 19 +++++++++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/montydb/collection.py b/montydb/collection.py index 69dae69..fe9eed8 100644 --- a/montydb/collection.py +++ b/montydb/collection.py @@ -267,6 +267,14 @@ def _remove_dollar_key(doc): updator(fieldwalker, do_insert=True) self._storage.write_one(self, fieldwalker.doc) + def _no_id_update(self, updator, filter=None): + id_operator = updator.operations.get("_id") + doc_id = (filter or {}).get("_id") + if id_operator and id_operator._keep() != doc_id: + msg = ("Performing an update on the path '_id' would " + "modify the immutable field '_id'") + raise WriteError(msg, code=66) + def update_one(self, filter, update, @@ -285,17 +293,14 @@ def update_one(self, raw_result = {"n": 0, "nModified": 0} updator = Updator(update, array_filters) + self._no_id_update(updator, filter) try: fw = next(self._internal_scan_query(filter)) except StopIteration: if upsert: self._internal_upsert(filter, updator, raw_result) else: - filter_id = filter.get("_id") - if filter_id and filter_id != fw.doc["_id"]: - msg = ("Performing an update on the path '_id' would " - "modify the immutable field '_id'") - raise WriteError(msg, code=66) + self._no_id_update(updator) raw_result["n"] = 1 if updator(fw): @@ -323,12 +328,15 @@ def update_many(self, raw_result = {"n": 0, "nModified": 0} updator = Updator(update, array_filters) scanner = self._internal_scan_query(filter) + self._no_id_update(updator, filter) try: next(scanner) except StopIteration: if upsert: self._internal_upsert(filter, updator, raw_result) else: + self._no_id_update(updator) + @on_err_close(scanner) def update_docs(): n, m = 0, 0 diff --git a/montydb/engine/update.py b/montydb/engine/update.py index 0b5266a..d579b16 100644 --- a/montydb/engine/update.py +++ b/montydb/engine/update.py @@ -13,6 +13,7 @@ is_numeric_type, is_duckument_type, is_integer_type, + keep, ) @@ -171,6 +172,7 @@ def check_conflict(self, field): self.fields_to_update.append(field) def parse_set_on_insert(self, field, value, array_filters): + @keep(value) def _set_on_insert(fieldwalker): if self.__insert: parse_set(field, value, array_filters)(fieldwalker) @@ -186,6 +188,7 @@ def parse_inc(field, value, array_filters): "{{{0}: {1}}}".format(field, val_repr_)) raise WriteError(msg, code=14) + @keep(value) def _inc(fieldwalker): def evaluator(node, inc_val): @@ -217,7 +220,7 @@ def evaluator(node, inc_val): def parse_min(field, value, array_filters): - + @keep(value) def _min(fieldwalker): def evaluator(node, min_val): @@ -235,7 +238,7 @@ def evaluator(node, min_val): def parse_max(field, value, array_filters): - + @keep(value) def _max(fieldwalker): def evaluator(node, max_val): @@ -260,6 +263,7 @@ def parse_mul(field, value, array_filters): "{{{0}: {1}}}".format(field, val_repr_)) raise WriteError(msg, code=14) + @keep(value) def _mul(fieldwalker): def evaluator(node, mul_val): @@ -312,6 +316,7 @@ def parse_rename(field, new_field, array_filters): "same path: {0}: {1!r}".format(field, new_field)) raise WriteError(msg, code=2) + @keep(new_field) def _rename(fieldwalker): probe = FieldWalker(fieldwalker.doc) @@ -352,7 +357,7 @@ def _rename(fieldwalker): def parse_set(field, value, array_filters): - + @keep(value) def _set(fieldwalker): _update(fieldwalker, field, value, None, array_filters) @@ -361,7 +366,7 @@ def _set(fieldwalker): def parse_unset(field, _, array_filters): - + @keep(field) def _unset(fieldwalker): _drop(fieldwalker, field, array_filters) @@ -396,6 +401,7 @@ def parse_currentDate(field, value, array_filters): else: value = date_type["date"] + @keep(value) def _currentDate(fieldwalker): parse_set(field, value, array_filters)(fieldwalker) @@ -411,6 +417,7 @@ def parse_add_to_set(field, value_or_each, array_filters): value = value_or_each run_each = False + @keep(value) def _add_to_set(fieldwalker): def evaluator(node, new_elem): @@ -452,6 +459,7 @@ def parse_pop(field, value, array_filters): if value not in (1.0, -1.0): raise WriteError(msg_raw.format(field, value), code=9) + @keep(value) def _pop(fieldwalker): def evaluator(node, pop_ind): @@ -488,6 +496,7 @@ def parse_pull(field, value_or_conditions, array_filters): else: queryfilter = QueryFilter({field: value_or_conditions}) + @keep(queryfilter) def _pull(fieldwalker): def evaluator(node, _): @@ -521,6 +530,7 @@ def parse_push(field, value_or_each, array_filters): value = value_or_each run_each = False + @keep(value) def _push(fieldwalker): def evaluator(node, new_elem): @@ -554,6 +564,7 @@ def parse_pull_all(field, value, array_filters): "".format(value_type)) raise WriteError(msg, code=2) + @keep(value) def _pull_all(fieldwalker): def evaluator(node, pull_list): From 678d885faf261cc38e4645c5143e2a7169ea5b84 Mon Sep 17 00:00:00 2001 From: David Lai Date: Tue, 27 Apr 2021 11:12:51 +0800 Subject: [PATCH 5/5] fix test --- tests/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 77bf6f0..d4aace9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -168,6 +168,7 @@ def test_utils_montydump(storage, tmp_monty_repo): def test_MongoQueryRecorder(mongo_client): + mongo_client.drop_database("recordTarget") # ensure clean db db = mongo_client["recordTarget"] _docs_ = [