Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/src/piccolo/schema/m2m.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,36 @@ given genre:
{"name": "Classical", "bands": ["C-Sharps"]},
]

Bidirectional select queries
----------------------------

The ``bidirectional`` argument is **only** used for self-referencing tables
in many to many relationships. If set to ``True``, a bidirectional
query is performed to obtain the correct result in a symmetric
many to many relationships on self-referencing tables.

.. code-block:: python

class Member(Table):
name = Varchar()
# self-reference many to many
followers = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)
followings = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)


class MemberToFollower(Table):
follower_id = ForeignKey(Member)
following_id = ForeignKey(Member)

>>> await Member.select(
Member.followers(Member.name, as_list=True, bidirectional=True)
).where(Member.name == "Bob")
[{"followers": ["Fred", "John", "Mia"]}]

-------------------------------------------------------------------------------

Objects queries
Expand Down
116 changes: 90 additions & 26 deletions piccolo/columns/m2m.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
m2m: M2M,
as_list: bool = False,
load_json: bool = False,
bidirectional: Optional[bool] = False,
):
"""
:param columns:
Expand All @@ -40,12 +41,16 @@ def __init__(
flattened list will be returned, rather than a list of objects.
:param load_json:
If ``True``, any JSON strings are loaded as Python objects.
:param bidirectional:
Only used for self-referencing tables. If ``True``, a
bidirectional query is performed against self-referencing tables.

"""
self.as_list = as_list
self.columns = columns
self.m2m = m2m
self.load_json = load_json
self.bidirectional = bidirectional

safe_types = (int, str)

Expand Down Expand Up @@ -75,20 +80,50 @@ def get_select_string(
fk_2 = self.m2m._meta.secondary_foreign_key
fk_2_name = fk_2._meta.db_column_name
table_2 = fk_2._foreign_key_meta.resolved_references
table_2_name = table_2._meta.tablename
table_2_name_with_schema = table_2._meta.get_formatted_tablename()
table_2_pk_name = table_2._meta.primary_key._meta.db_column_name

inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501
# self-reference table (if primary and secondary table are the same)
if table_1 == table_2:
table_2_name = table_1._meta.tablename
table_2_name_with_schema = table_1._meta.get_formatted_tablename()
table_2_pk_name = table_1._meta.primary_key._meta.db_column_name
# check bidirectional argument. If True change direction in query
if self.bidirectional:
inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_2_name}" = "{table_2_name}"."{table_2_pk_name}"
""" # noqa: E501
else:
inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501
else:
table_1_name = table_1._meta.tablename
table_1_name_with_schema = table_1._meta.get_formatted_tablename()
table_1_pk_name = table_1._meta.primary_key._meta.db_column_name

fk_2 = self.m2m._meta.secondary_foreign_key
fk_2_name = fk_2._meta.db_column_name
table_2 = fk_2._foreign_key_meta.resolved_references
table_2_name = table_2._meta.tablename
table_2_name_with_schema = table_2._meta.get_formatted_tablename()
table_2_pk_name = table_2._meta.primary_key._meta.db_column_name

inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501

if engine_type in ("postgres", "cockroach"):
if self.as_list:
Expand Down Expand Up @@ -248,7 +283,11 @@ def secondary_foreign_key(self) -> ForeignKey:
for fk_column in self.foreign_key_columns:
if fk_column._foreign_key_meta.resolved_references != self.table:
return fk_column

if (
fk_column._foreign_key_meta.resolved_references
== self.primary_table
):
return self.foreign_key_columns[-1]
raise ValueError("No matching foreign key column found!")

@property
Expand Down Expand Up @@ -367,23 +406,39 @@ def __await__(self):
class M2MGetRelated:
row: Table
m2m: M2M
bidirectional: Optional[bool] = False

async def run(self):
joining_table = self.m2m._meta.resolved_joining_table

secondary_table = self.m2m._meta.secondary_table

# use a subquery to make only one db query
results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(
joining_table.select(
getattr(
self.m2m._meta.secondary_foreign_key,
secondary_table._meta.primary_key._meta.name,
)
).where(self.m2m._meta.primary_foreign_key == self.row)
# bidirectional argument which is used to distinguish
# the direction in which we execute queries in the
# self-reference table (reference the same table)
if self.bidirectional:
results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(
joining_table.select(
getattr(
self.m2m._meta.primary_foreign_key,
secondary_table._meta.primary_key._meta.name,
)
).where(self.m2m._meta.secondary_foreign_key == self.row)
)
)
else:
# use a subquery to make only one db query
results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(
joining_table.select(
getattr(
self.m2m._meta.secondary_foreign_key,
secondary_table._meta.primary_key._meta.name,
)
).where(self.m2m._meta.primary_foreign_key == self.row)
)
)
)

return results

Expand Down Expand Up @@ -424,6 +479,7 @@ def __call__(
*columns: Union[Column, list[Column]],
as_list: bool = False,
load_json: bool = False,
bidirectional: Optional[bool] = False,
) -> M2MSelect:
"""
:param columns:
Expand All @@ -434,6 +490,10 @@ def __call__(
flattened list will be returned, rather than a list of objects.
:param load_json:
If ``True``, any JSON strings are loaded as Python objects.
:param bidirectional:
Only used for self-referencing tables. If ``True``, a
bidirectional query is performed against self-referencing tables.

"""
columns_ = flatten(columns)

Expand All @@ -446,5 +506,9 @@ def __call__(
)

return M2MSelect(
*columns_, m2m=self, as_list=as_list, load_json=load_json
*columns_,
m2m=self,
as_list=as_list,
load_json=load_json,
bidirectional=bidirectional,
)
34 changes: 31 additions & 3 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,9 @@ def get_related(

return GetRelated(foreign_key=foreign_key, row=self)

def get_m2m(self, m2m: M2M) -> M2MGetRelated:
def get_m2m(
self, m2m: M2M, bidirectional: Optional[bool] = None
) -> M2MGetRelated:
"""
Get all matching rows via the join table.

Expand All @@ -657,8 +659,34 @@ def get_m2m(self, m2m: M2M) -> M2MGetRelated:
>>> await band.get_m2m(Band.genres)
[<Genre: 1>, <Genre: 2>]

"""
return M2MGetRelated(row=self, m2m=m2m)
The ``bidirectional`` argument is only used for self-referencing tables
in many to many relationships. If set to ``True``, a bidirectional
query is performed to obtain the correct result in a symmetric
many-to-many relationships on self-referencing tables.

.. code-block:: python

class Member(Table):
name = Varchar()
# self-reference many to many
followers = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)
followings = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)


class MemberToFollower(Table):
follower_id = ForeignKey(Member)
following_id = ForeignKey(Member)

>>> member = await Member.objects().get(Member.name == "Bob")
>>> await member.get_m2m(Member.followers, bidirectional=True)
[<Member: 3>, <Member: 4>, <Member: 5>]

""" # noqa: E501
return M2MGetRelated(row=self, m2m=m2m, bidirectional=bidirectional)

def add_m2m(
self,
Expand Down
103 changes: 103 additions & 0 deletions tests/columns/m2m/test_m2m.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,106 @@ def test_select_single(self):
returned_value,
msg=f"{column_name} doesn't match",
)


###############################################################################

# A schema using self-reference tables


class Member(Table):
name = Varchar()
# self-reference many to many
followers = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)
followings = M2M(
LazyTableReference("MemberToFollower", module_path=__name__)
)


class MemberToFollower(Table):
follower_id = ForeignKey(Member)
following_id = ForeignKey(Member)


SELF_REFERENCE_SCHEMA = [Member, MemberToFollower]


class TestM2MSelfReference(TestCase):
"""
Make sure the M2M functionality works correctly when the tables is
the same (self-reference tables).
"""

def setUp(self):
create_db_tables_sync(*SELF_REFERENCE_SCHEMA, if_not_exists=True)

bob = Member.objects().create(name="Bob").run_sync()
sally = Member.objects().create(name="Sally").run_sync()
fred = Member.objects().create(name="Fred").run_sync()
john = Member.objects().create(name="John").run_sync()
mia = Member.objects().create(name="Mia").run_sync()

MemberToFollower.insert(
MemberToFollower(follower_id=fred, following_id=bob),
MemberToFollower(follower_id=bob, following_id=sally),
MemberToFollower(follower_id=fred, following_id=sally),
MemberToFollower(follower_id=john, following_id=bob),
MemberToFollower(follower_id=mia, following_id=bob),
MemberToFollower(follower_id=bob, following_id=john),
).run_sync()

def tearDown(self):
drop_db_tables_sync(*SELF_REFERENCE_SCHEMA)

def test_select_bidirectional(self):
"""
Make sure we can select related items for self-reference table.
"""
followings = (
Member.select(Member.followings(Member.name, as_list=True))
.where(Member.name == "Bob")
.run_sync()
)

self.assertEqual(followings, [{"followings": ["Sally", "John"]}])

# Now we use the bidirectional argument to get the correct result.
# Without it, we cannot get the correct result for symmetric
# self-referencing many to many relations.
followers = (
Member.select(
Member.followers(Member.name, as_list=True, bidirectional=True)
)
.where(Member.name == "Bob")
.run_sync()
)

self.assertEqual(followers, [{"followers": ["Fred", "John", "Mia"]}])

def test_get_m2m_bidirectional(self):
"""
Make sure we can get related items for self-reference table.
"""
member = Member.objects().get(Member.name == "Bob").run_sync()
assert member is not None

followings = member.get_m2m(Member.followings).run_sync()

self.assertTrue(all(isinstance(i, Table) for i in followings))

self.assertCountEqual([i.name for i in followings], ["Sally", "John"])

# Now we use the bidirectional argument to get the correct result.
# Without it, we cannot get the correct result for symmetric
# self-referencing many to many relations.
followers = member.get_m2m(
Member.followers, bidirectional=True
).run_sync()

self.assertTrue(all(isinstance(i, Table) for i in followers))

self.assertCountEqual(
[i.name for i in followers], ["Fred", "John", "Mia"]
)