Skip to content

Commit

Permalink
Merge pull request #23 from davidlatwe/fix-#20
Browse files Browse the repository at this point in the history
Fix #20
  • Loading branch information
davidlatwe authored Apr 27, 2021
2 parents 993b1b2 + 678d885 commit 5207284
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 11 deletions.
15 changes: 15 additions & 0 deletions montydb/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .errors import (
DuplicateKeyError,
BulkWriteError,
WriteError,
)

from .results import (BulkWriteResult,
Expand Down Expand Up @@ -266,6 +267,14 @@ def _remove_dollar_key(doc):
updator(fieldwalker, do_insert=True)
self._storage.write_one(self, fieldwalker.doc)

def _no_id_update(self, updator, filter=None):
id_operator = updator.operations.get("_id")
doc_id = (filter or {}).get("_id")
if id_operator and id_operator._keep() != doc_id:
msg = ("Performing an update on the path '_id' would "
"modify the immutable field '_id'")
raise WriteError(msg, code=66)

def update_one(self,
filter,
update,
Expand All @@ -284,12 +293,15 @@ def update_one(self,

raw_result = {"n": 0, "nModified": 0}
updator = Updator(update, array_filters)
self._no_id_update(updator, filter)
try:
fw = next(self._internal_scan_query(filter))
except StopIteration:
if upsert:
self._internal_upsert(filter, updator, raw_result)
else:
self._no_id_update(updator)

raw_result["n"] = 1
if updator(fw):
self._storage.update_one(self, fw.doc)
Expand All @@ -316,12 +328,15 @@ def update_many(self,
raw_result = {"n": 0, "nModified": 0}
updator = Updator(update, array_filters)
scanner = self._internal_scan_query(filter)
self._no_id_update(updator, filter)
try:
next(scanner)
except StopIteration:
if upsert:
self._internal_upsert(filter, updator, raw_result)
else:
self._no_id_update(updator)

@on_err_close(scanner)
def update_docs():
n, m = 0, 0
Expand Down
24 changes: 15 additions & 9 deletions montydb/engine/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
is_numeric_type,
is_duckument_type,
is_integer_type,
keep,
)


Expand Down Expand Up @@ -143,11 +144,6 @@ def parser(self, spec):
raise WriteError(msg, code=9)

for field, value in cmd_doc.items():
if field == "_id":
msg = ("Performing an update on the path '_id' would "
"modify the immutable field '_id'")
raise WriteError(msg, code=66)

for top in list(idnt_tops):
if "$[{}]".format(top) in field:
idnt_tops.remove(top)
Expand Down Expand Up @@ -176,6 +172,7 @@ def check_conflict(self, field):
self.fields_to_update.append(field)

def parse_set_on_insert(self, field, value, array_filters):
@keep(value)
def _set_on_insert(fieldwalker):
if self.__insert:
parse_set(field, value, array_filters)(fieldwalker)
Expand All @@ -191,6 +188,7 @@ def parse_inc(field, value, array_filters):
"{{{0}: {1}}}".format(field, val_repr_))
raise WriteError(msg, code=14)

@keep(value)
def _inc(fieldwalker):

def evaluator(node, inc_val):
Expand Down Expand Up @@ -222,7 +220,7 @@ def evaluator(node, inc_val):


def parse_min(field, value, array_filters):

@keep(value)
def _min(fieldwalker):

def evaluator(node, min_val):
Expand All @@ -240,7 +238,7 @@ def evaluator(node, min_val):


def parse_max(field, value, array_filters):

@keep(value)
def _max(fieldwalker):

def evaluator(node, max_val):
Expand All @@ -265,6 +263,7 @@ def parse_mul(field, value, array_filters):
"{{{0}: {1}}}".format(field, val_repr_))
raise WriteError(msg, code=14)

@keep(value)
def _mul(fieldwalker):

def evaluator(node, mul_val):
Expand Down Expand Up @@ -317,6 +316,7 @@ def parse_rename(field, new_field, array_filters):
"same path: {0}: {1!r}".format(field, new_field))
raise WriteError(msg, code=2)

@keep(new_field)
def _rename(fieldwalker):

probe = FieldWalker(fieldwalker.doc)
Expand Down Expand Up @@ -357,7 +357,7 @@ def _rename(fieldwalker):


def parse_set(field, value, array_filters):

@keep(value)
def _set(fieldwalker):

_update(fieldwalker, field, value, None, array_filters)
Expand All @@ -366,7 +366,7 @@ def _set(fieldwalker):


def parse_unset(field, _, array_filters):

@keep(field)
def _unset(fieldwalker):

_drop(fieldwalker, field, array_filters)
Expand Down Expand Up @@ -401,6 +401,7 @@ def parse_currentDate(field, value, array_filters):
else:
value = date_type["date"]

@keep(value)
def _currentDate(fieldwalker):
parse_set(field, value, array_filters)(fieldwalker)

Expand All @@ -416,6 +417,7 @@ def parse_add_to_set(field, value_or_each, array_filters):
value = value_or_each
run_each = False

@keep(value)
def _add_to_set(fieldwalker):

def evaluator(node, new_elem):
Expand Down Expand Up @@ -457,6 +459,7 @@ def parse_pop(field, value, array_filters):
if value not in (1.0, -1.0):
raise WriteError(msg_raw.format(field, value), code=9)

@keep(value)
def _pop(fieldwalker):

def evaluator(node, pop_ind):
Expand Down Expand Up @@ -493,6 +496,7 @@ def parse_pull(field, value_or_conditions, array_filters):
else:
queryfilter = QueryFilter({field: value_or_conditions})

@keep(queryfilter)
def _pull(fieldwalker):

def evaluator(node, _):
Expand Down Expand Up @@ -526,6 +530,7 @@ def parse_push(field, value_or_each, array_filters):
value = value_or_each
run_each = False

@keep(value)
def _push(fieldwalker):

def evaluator(node, new_elem):
Expand Down Expand Up @@ -559,6 +564,7 @@ def parse_pull_all(field, value, array_filters):
"".format(value_type))
raise WriteError(msg, code=2)

@keep(value)
def _pull_all(fieldwalker):

def evaluator(node, pull_list):
Expand Down
60 changes: 58 additions & 2 deletions tests/test_engine/test_update/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def test_update_id(monty_update, mongo_update):
with pytest.raises(monty_write_err) as monty_err:
monty_update(docs, spec, find)

# ignore comparing error code
# assert mongo_err.value.code == monty_err.value.code
assert mongo_err.value.code == 66
assert mongo_err.value.code == monty_err.value.code


def test_update_id2(monty_update, mongo_update):
Expand All @@ -249,6 +249,62 @@ def test_update_id2(monty_update, mongo_update):
assert next(monty_c) == {"a": {"b": {"_id": 3}}}


def test_update_id3(monty_update, mongo_update):
docs = []
spec = {"$set": {"_id": "some-id", "foo": "baby"}}
find = {"_id": "some-id"}

monty_c = monty_update(docs, spec, find, upsert=True)
mongo_c = mongo_update(docs, spec, find, upsert=True)

assert next(mongo_c) == next(monty_c)


def test_update_id4(monty_update, mongo_update):
docs = []
spec = {"$set": {"_id": "some-id", "foo": "baby"}}
find = {"_id": "other-id"}

with pytest.raises(mongo_write_err) as mongo_err:
mongo_update(docs, spec, find, upsert=True)

with pytest.raises(monty_write_err) as monty_err:
monty_update(docs, spec, find, upsert=True)

assert mongo_err.value.code == 66
assert mongo_err.value.code == monty_err.value.code


def test_update_id5(monty_update, mongo_update):
docs = [{"_id": "other-id", "foo": "bar"}]
spec = {"$set": {"_id": "some-id", "foo": "baby"}}
find = {"foo": "bar"}

with pytest.raises(mongo_write_err) as mongo_err:
mongo_update(docs, spec, find, upsert=True)

with pytest.raises(monty_write_err) as monty_err:
monty_update(docs, spec, find, upsert=True)

assert mongo_err.value.code == 66
assert mongo_err.value.code == monty_err.value.code


def test_update_id6(monty_update, mongo_update):
docs = [{"_id": "some-id", "foo": "bar"}]
spec = {"$set": {"_id": "some-id", "foo": "baby"}}
find = {"foo": "bar"}

with pytest.raises(mongo_write_err) as mongo_err:
mongo_update(docs, spec, find, upsert=True)

with pytest.raises(monty_write_err) as monty_err:
monty_update(docs, spec, find, upsert=True)

assert mongo_err.value.code == 66
assert mongo_err.value.code == monty_err.value.code


def test_update_positional(monty_update, mongo_update):
docs = [
{"a": [{"b": 3, "c": 1}, {"b": 4, "c": 0}]}
Expand Down
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_utils_montydump(storage, tmp_monty_repo):


def test_MongoQueryRecorder(mongo_client):
mongo_client.drop_database("recordTarget") # ensure clean db
db = mongo_client["recordTarget"]

_docs_ = [
Expand Down

0 comments on commit 5207284

Please sign in to comment.