Skip to content

Commit

Permalink
Fix ORM self references (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Jan 30, 2023
1 parent fee3932 commit a78571c
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
2 changes: 1 addition & 1 deletion emmett/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.4.13"
__version__ = "2.4.14"
7 changes: 4 additions & 3 deletions emmett/orm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def define_models(self, *models):
obj._define_props_()
obj._define_relations_()
obj._define_virtuals_()
obj._build_rowclass_()
# define table and store in model
args = dict(
migrate=obj.migrate,
Expand All @@ -204,10 +203,12 @@ def define_models(self, *models):
obj.tablename, *obj.fields, **args
)
model.table._model_ = obj
# load user's definitions
obj._define_()
# set reference in db for model name
self.__setattr__(model.__name__, obj.table)
# configure structured rows
obj._build_rowclass_()
# load user's definitions
obj._define_()
if self._auto_migrate and not self._do_connect:
self.connection_close()

Expand Down
26 changes: 18 additions & 8 deletions emmett/orm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any, Callable

from ..datastructures import sdict
from ..utils import cachedprop
from .apis import (
compute,
rowattr,
Expand Down Expand Up @@ -244,6 +243,7 @@ def __init__(self):
self.format = None
if not hasattr(self, 'primary_keys'):
self.primary_keys = []
self._fieldset_pk = set(self.primary_keys or ['id'])

@property
def config(self):
Expand Down Expand Up @@ -401,7 +401,11 @@ def _define_relations_(self):
raise RuntimeError(bad_args_error)
reference = self.__parse_belongs_relation(item, ondelete)
reference.is_refers = not is_belongs
refmodel = self.db[reference.model]._model_
refmodel = (
self.db[reference.model]._model_
if reference.model != self.__class__.__name__
else self
)
ref_multi_pk = len(refmodel._fieldset_pk) > 1
fk_def_key, fks_data, multi_fk = None, {}, []
if ref_multi_pk and reference.fk:
Expand All @@ -418,15 +422,15 @@ def _define_relations_(self):
elif ref_multi_pk and not reference.fk:
multi_fk = list(refmodel.primary_keys)
elif not reference.fk:
reference.fk = refmodel.table._id.name
reference.fk = list(refmodel._fieldset_pk)[0]
if multi_fk:
references = []
fks_data["fields"] = []
fks_data["foreign_fields"] = []
for fk in multi_fk:
refclone = sdict(reference)
refclone.fk = fk
refclone.ftype = refmodel.table[refclone.fk].type
refclone.ftype = getattr(refmodel, refclone.fk).type
refclone.name = f"{refclone.name}_{refclone.fk}"
refclone.compound = reference.name
references.append(refclone)
Expand All @@ -450,7 +454,7 @@ def _define_relations_(self):
coupled_fields=belongs_fks[reference.name].coupled_fields,
)
else:
reference.ftype = refmodel.table[reference.fk].type
reference.ftype = getattr(refmodel, reference.fk).type
references = [reference]
belongs_fks[reference.name] = sdict(
model=reference.model,
Expand Down Expand Up @@ -531,8 +535,15 @@ def __define_fks(self):
implicit_defs = {}
grouped_rels = {}
for rname, rel in self._belongs_ref_.items():
rmodel = self.db[rel.model]._model_
if not rmodel.primary_keys and rmodel.table._id.type == 'id':
rmodel = (
self.db[rel.model]._model_
if rel.model != self.__class__.__name__
else self
)
if (
not rmodel.primary_keys and
getattr(rmodel, list(rmodel._fieldset_pk)[0]).type == 'id'
):
continue
if len(rmodel._fieldset_pk) > 1:
match = self.__find_matching_fk_definition([rel.fk], rmodel)
Expand Down Expand Up @@ -619,7 +630,6 @@ def _unset_row_persistence(self, row):

def _build_rowclass_(self):
#: build helpers for rows
self._fieldset_pk = set(self.primary_keys or ['id'])
save_excluded_fields = (
set(
field.name for field in self.fields if
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "Emmett"
version = "2.4.13"
version = "2.4.14"
description = "The web framework for inventors"
authors = ["Giovanni Barillari <[email protected]>"]
license = "BSD-3-Clause"
Expand Down
9 changes: 8 additions & 1 deletion tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
before_destroy, after_destroy,
before_commit, after_commit,
rowattr, rowmethod,
has_one, has_many, belongs_to,
has_one, has_many, belongs_to, refers_to,
scope
)
from emmett.orm.migrations.utils import generate_runtime_migration
Expand Down Expand Up @@ -361,6 +361,12 @@ def _compute_price(self, row):
return row.quantity * row.product.price


class SelfRef(Model):
refers_to({'parent': 'self'})

name = Field.string()


class CustomPKType(Model):
id = Field.string()

Expand Down Expand Up @@ -421,6 +427,7 @@ def _db():
User, Organization, Membership,
House, Mouse, NeedSplit, Zoo, Animal, Elephant,
Product, Cart, CartElement,
SelfRef,
CustomPKType, CustomPKName, CustomPKMulti,
CommitWatcher
)
Expand Down

0 comments on commit a78571c

Please sign in to comment.