Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a decorator for creating table via class-style. #823

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ referenced as attributes on instances of ``pypika.Table``.
customers = Table('customers')
q = Query.from_(customers).select(customers.id, customers.fname, customers.lname, customers.phone)

The table also can create via class-style:

.. code-block:: python

from pypika import table_class, Table, Field, Query

@table_class('customers')
class Customer(Table):
id = Field('id')
first_name = Field('fname')
last_name = Field('lname')
phone = Field('phone')

q = Query.from_(Customer).select(Customer.id, Customer.first_name, Customer.last_name, Customer.phone)

Both of the above examples result in the following SQL:

.. code-block:: sql
Expand Down
2 changes: 2 additions & 0 deletions pypika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
Database,
make_tables as Tables,
make_columns as Columns,
table_class,
)

# noinspection PyUnresolvedReferences
Expand Down Expand Up @@ -141,6 +142,7 @@
'Database',
'Tables',
'Columns',
'table_class',
'Array',
'Bracket',
'Case',
Expand Down
31 changes: 30 additions & 1 deletion pypika/queries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import copy
from functools import reduce
from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union
from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, TypeVar

from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation
from pypika.terms import (
Expand Down Expand Up @@ -241,6 +241,35 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui
return self._query_cls.into(self).insert(*terms)


T = TypeVar("T", bound=Table)


def table_class(
name: str,
schema: Optional[Union[Schema, str]] = None,
alias: Optional[str] = None,
query_cls: Optional[Type["Query"]] = None,
):
"""
A decorator for creating a new table via class-style syntax.

>>> @table_class("user")
... class User(Table):
... id = Field("_id")
... name = Field("name", alias="username")
"""
def builder(cls: Type[T]) -> T:
if not issubclass(cls, Table):
raise TypeError(f"{cls.__name__} must be a subclass of Table.")
table = cls(name=name, schema=schema, alias=alias, query_cls=query_cls)
for field in cls.__dict__.values():
if isinstance(field, Field):
field.table = table
return table

return builder


def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[Table]:
"""
Shortcut to create many tables. If `names` param is a tuple, the first
Expand Down
245 changes: 245 additions & 0 deletions pypika/tests/test_table_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Most of test cases are copied and modified from test_tables.py

import unittest

from pypika import Database, Dialects, Schema, SQLLiteQuery, Table, SYSTEM_TIME, table_class, Field


class TableStructureTests(unittest.TestCase):
def test_table_sql(self):
@table_class("test_table")
class T(Table):
pass

self.assertEqual('"test_table"', str(T))

def test_table_with_no_superclass(self):
with self.assertRaises(TypeError):
@table_class("test_table")
class T:
pass

def test_table_with_bad_superclass(self):
with self.assertRaises(TypeError):
@table_class("test_table")
class T(object):
pass

def test_table_with_alias(self):
@table_class("test_table")
class T(Table):
pass

table = T.as_("my_table")

self.assertEqual('"test_table" "my_table"', table.get_sql(with_alias=True, quote_char='"'))

def test_table_with_schema_arg(self):
@table_class("test_table", schema=Schema("x_schema"))
class T(Table):
pass

self.assertEqual('"x_schema"."test_table"', str(T))

def test_table_with_field(self):
@table_class("test_table")
class T(Table):
f = Field('f')

self.assertEqual('"f"', T.f.get_sql(with_alias=True, quote_char='"'))
self.assertEqual(id(T), id(T.f.table))

def test_table_with_field_and_ailas(self):
@table_class("test_table")
class T(Table):
f = Field('f', alias='my_f')

self.assertEqual('"f" "my_f"', T.f.get_sql(with_alias=True, quote_char='"'))
self.assertEqual(id(T), id(T.f.table))

def test_table_with_unset_field(self):
@table_class("test_table")
class T(Table):
pass

self.assertEqual('"f"', T.f.get_sql(with_alias=True, quote_char='"'))
self.assertEqual(id(T), id(T.f.table))

def test_table_with_schema_and_schema_parent_arg(self):
@table_class("test_table", schema=Schema("x_schema", parent=Database("x_db")))
class T(Table):
pass

self.assertEqual('"x_db"."x_schema"."test_table"', str(T))

def test_table_for_system_time_sql(self):
with self.subTest("with between criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(SYSTEM_TIME.between('2020-01-01', '2020-02-01'))

self.assertEqual('"test_table" FOR SYSTEM_TIME BETWEEN \'2020-01-01\' AND \'2020-02-01\'', str(table))

with self.subTest("with as of criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(SYSTEM_TIME.as_of('2020-01-01'))

self.assertEqual('"test_table" FOR SYSTEM_TIME AS OF \'2020-01-01\'', str(table))

with self.subTest("with from to criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(SYSTEM_TIME.from_to('2020-01-01', '2020-02-01'))

self.assertEqual('"test_table" FOR SYSTEM_TIME FROM \'2020-01-01\' TO \'2020-02-01\'', str(table))

def test_table_for_period_sql(self):
with self.subTest("with between criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(T.valid_period.between('2020-01-01', '2020-02-01'))

self.assertEqual('"test_table" FOR "valid_period" BETWEEN \'2020-01-01\' AND \'2020-02-01\'', str(table))

with self.subTest("with as of criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(T.valid_period.as_of('2020-01-01'))

self.assertEqual('"test_table" FOR "valid_period" AS OF \'2020-01-01\'', str(table))

with self.subTest("with from to criterion"):
@table_class("test_table")
class T(Table):
pass

table = T.for_(T.valid_period.from_to('2020-01-01', '2020-02-01'))

self.assertEqual('"test_table" FOR "valid_period" FROM \'2020-01-01\' TO \'2020-02-01\'', str(table))


class TableEqualityTests(unittest.TestCase):
def test_tables_equal_by_name(self):
@table_class("test_table")
class T1(Table):
pass

@table_class("test_table")
class T2(Table):
pass

self.assertEqual(T1, T2)

def test_tables_equal_by_schema_and_name(self):
@table_class("test_table", schema="a")
class T1(Table):
pass

@table_class("test_table", schema="a")
class T2(Table):
pass

self.assertEqual(T1, T2)

def test_tables_equal_by_schema_and_name_using_schema(self):
a = Schema("a")

@table_class("test_table", schema=a)
class T1(Table):
pass

@table_class("test_table", schema=a)
class T2(Table):
pass

self.assertEqual(T1, T2)

def test_tables_equal_by_schema_and_name_using_schema_with_parent(self):
parent = Schema("parent")
a = Schema("a", parent=parent)

@table_class("test_table", schema=a)
class T1(Table):
pass

@table_class("test_table", schema=a)
class T2(Table):
pass

self.assertEqual(T1, T2)

def test_tables_not_equal_by_schema_and_name_using_schema_with_different_parents(
self,
):
parent = Schema("parent")
a = Schema("a", parent=parent)

@table_class("test_table", schema=a)
class T1(Table):
pass

@table_class("test_table", schema=Schema("a"))
class T2(Table):
pass

self.assertNotEqual(T1, T2)

def test_tables_not_equal_with_different_schemas(self):

@table_class("test_table", schema="a")
class T1(Table):
pass

@table_class("test_table", schema="b")
class T2(Table):
pass

self.assertNotEqual(T1, T2)

def test_tables_not_equal_with_different_names(self):

@table_class("t", schema="a")
class T1(Table):
pass

@table_class("q", schema="a")
class T2(Table):
pass

self.assertNotEqual(T1, T2)


class TableDialectTests(unittest.TestCase):
def test_table_with_default_query_cls(self):
@table_class("test_table")
class T(Table):
pass

q = T.select("1")
self.assertIs(q.dialect, None)

def test_table_with_dialect_query_cls(self):

@table_class("test_table", query_cls=SQLLiteQuery)
class T(Table):
pass

q = T.select("1")
self.assertIs(q.dialect, Dialects.SQLLITE)

def test_table_with_bad_query_cls(self):
with self.assertRaises(TypeError):
@table_class("test_table", query_cls=object)
class T(Table):
pass