diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index d4be6053..b72c4ff7 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -13,18 +13,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: 3.8
- - name: Install and set up Poetry
- run: |
- curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py
- python get-poetry.py -y
- source $HOME/.poetry/env
- poetry config virtualenvs.in-project true
+ python-version: 3.9
+ - name: Install and configure Poetry
+ uses: gi0baro/setup-poetry-bin@v1
+ with:
+ virtualenvs-in-project: true
- name: Publish
run: |
- source $HOME/.poetry/env
- poetry config pypi-token.pypi $PUBLISH_TOKEN
poetry build
poetry publish
env:
- PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
+ POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c2d4eaa2..49f0393e 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -13,7 +13,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: [3.7, 3.8, 3.9, '3.10']
+
+ services:
+ postgres:
+ image: postgis/postgis:12-3.2
+ env:
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_DB: test
+ ports:
+ - 5432:5432
steps:
- uses: actions/checkout@v2
@@ -21,26 +30,24 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - name: Install and set up Poetry
- run: |
- curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py
- python get-poetry.py -y
- source $HOME/.poetry/env
- poetry config virtualenvs.in-project true
+ - name: Install and configure Poetry
+ uses: gi0baro/setup-poetry-bin@v1
+ with:
+ virtualenvs-in-project: true
- name: Install dependencies
run: |
- source $HOME/.poetry/env
- poetry install -v
+ poetry install -v --extras crypto
- name: Test
+ env:
+ POSTGRES_URI: postgres:postgres@localhost:5432/test
run: |
- source $HOME/.poetry/env
poetry run pytest -v tests
MacOS:
runs-on: macos-latest
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: [3.7, 3.8, 3.9, '3.10']
steps:
- uses: actions/checkout@v2
@@ -48,26 +55,22 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - name: Install and set up Poetry
- run: |
- curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py
- python get-poetry.py -y
- source $HOME/.poetry/env
- poetry config virtualenvs.in-project true
+ - name: Install and configure Poetry
+ uses: gi0baro/setup-poetry-bin@v1
+ with:
+ virtualenvs-in-project: true
- name: Install dependencies
run: |
- source $HOME/.poetry/env
- poetry install -v
+ poetry install -v --extras crypto
- name: Test
run: |
- source $HOME/.poetry/env
poetry run pytest -v tests
Windows:
runs-on: windows-latest
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: [3.7, 3.8, 3.9, '3.10']
steps:
- uses: actions/checkout@v2
@@ -75,17 +78,15 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - name: Install and setup Poetry
- run: |
- curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py
- python get-poetry.py -y
- $env:Path += ";$env:Userprofile\.poetry\bin"
- poetry config virtualenvs.in-project true
+ - name: Install and configure Poetry
+ uses: gi0baro/setup-poetry-bin@v1
+ with:
+ virtualenvs-in-project: true
- name: Install dependencies
+ shell: bash
run: |
- $env:Path += ";$env:Userprofile\.poetry\bin"
- poetry install -v
+ poetry install -v --extras crypto
- name: Test
+ shell: bash
run: |
- $env:Path += ";$env:Userprofile\.poetry\bin"
poetry run pytest -v tests
diff --git a/CHANGES.md b/CHANGES.md
index 38190c92..f2518b89 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,6 +1,28 @@
Emmett changelog
================
+Version 2.4
+-----------
+
+Released on January 10th 2022, codename Edison
+
+- Added official Python 3.10 support
+- Added relative path support in templates
+- Added support for spatial columns in ORM
+- Added support for custom/multiple primary keys in ORM
+- Added support for custom/multiple foreign keys in ORM
+- Added support for relations over custom and multiple primary keys in ORM
+- Added `watch` parameter to ORM's `compute` decorator
+- Added `save` method to ORM's rows and relevant callbacks
+- Added `destroy` method to ORM's rows and relevant callbacks
+- Added `refresh` method to ORM's rows
+- Added `before_commit` and `after_commit` ORM callbacks
+- Added changes tracking to ORM's rows
+- Added support for calling super `rowmethod` decorated methods in ORM models
+- Added `migrations set` command to CLI
+- Added `skip_callbacks` parameter to relevant methods in ORM
+- ORM now automatically adds appropriate indexes for `unique` fields
+
Version 2.3
-----------
@@ -12,8 +34,8 @@ Released on August 12th 2021, codename Da Vinci
- Added `dict` values support for `in` validations
- Use optional `emmett-crypto` package for cryptographic functions
- Deprecated `security.secure_dumps` and `security.secure_loads` in favour of new crypto package
-- Add `on_delete` option to `belongs_to` and `refers_to`
-- Add `--dry-run` option to migrations `up` and `down` commands
+- Added `on_delete` option to `belongs_to` and `refers_to`
+- Added `--dry-run` option to migrations `up` and `down` commands
Version 2.2
-----------
diff --git a/docs/orm/advanced.md b/docs/orm/advanced.md
index 0ed3b104..6b1570ae 100644
--- a/docs/orm/advanced.md
+++ b/docs/orm/advanced.md
@@ -257,3 +257,43 @@ class Post(Model):
```
The `where` option accepts a lambda function that should accept the model as first parameter and should return any valid query using the Emmett query language.
+
+Custom primary keys
+-------------------
+
+*New in version 2.4*
+
+### Customise primary key type
+
+Under default behaviour, all models in Emmett have an integer `id` primary key. If you need to change the type of this field, just define your own `id` field.
+
+For example, we might want to have a model with UUID strings as primary key:
+
+```python
+from uuid import uuid4
+
+class Ticket(Model):
+ id = Field.string(default=lambda: uuid4().hex)
+```
+
+### Using custom primary keys
+
+Sometimes you need models without an `id` field, since a specific field should act as the primary key. Other times, you need compound primary keys, where multiple fields produce your records' identifier.
+
+Under these circumstances, using the `primary_keys` attribute of the `Model` class will be enough:
+
+```python
+from uuid import uuid4
+
+
+class Ticket(Model):
+ primary_keys = ["code"]
+
+ code = Field.string(default=lambda: uuid4().hex)
+
+
+class MultiPK(Model):
+ primary_keys = ["key1", "key2"]
+
+ key1 = Field.int()
+ key2 = Field.int()
+```
+
+> **Note:** Emmett's [relations](./relations) system is fully compatible with custom and multiple primary keys.
diff --git a/docs/orm/callbacks.md b/docs/orm/callbacks.md
index 6872d0ac..b6d6877b 100644
--- a/docs/orm/callbacks.md
+++ b/docs/orm/callbacks.md
@@ -3,7 +3,7 @@ Callbacks
Callbacks are methods that get called when specific database operations are performed on your data.
-When you need to perform actions on one of these specific conditions, Emmett helps you with six different callbacks decorators that you can use inside your models, corresponding to the moments before and after a database insert, update and delete operations. The methods you decorate using these helpers will be invoked automatically when the database operation is performed.
+When you need to perform actions on one of these specific conditions, Emmett helps you with several callback decorators that you can use inside your models, corresponding to different moments before and after certain database operations. The methods you decorate using these helpers will be invoked automatically when the database operation is performed.
All the callbacks method should return `None` or `False` (not returning anything in python is the same of returning `None`) otherwise returning `True` will abort the current operation.
@@ -12,7 +12,7 @@ Let's see these decorators in detail.
before\_insert
--------------
-The `before_insert` decorator is called just before the insertion of a new record will be performed. The methods decorated with this helper should accept just one parameter that will be the dictionary mapping the fields and the values to be inserted in the table.
+The `before_insert` callback is called just before the insertion of a new record will be performed. The methods decorated with this helper should accept just one parameter that will be the dictionary mapping the fields and the values to be inserted in the table.
Here is a quick example:
```python
@@ -36,7 +36,7 @@ Now, if you insert a new record, you will see the printed values:
after\_insert
-------------
-The `after_insert` decorator is called just after the insertion of a new record happened. The methods decorated with this helper should accept the dictionary mapping the fields and the values that had been used for the insertion as the first parameter, and the id of the newly created record as the second one.
+The `after_insert` callback is called just after the insertion of a new record happened. The methods decorated with this helper should accept the dictionary mapping the fields and the values that had been used for the insertion as the first parameter, and the id of the newly created record as the second one.
Here is a quick example:
```python
@@ -77,7 +77,7 @@ class Profile(Model):
before\_update
--------------
-As the `before_insert` callbacks is called just before a record insertion, the `before_update` one is called just before a set of records is updated. The methods decorated with this helper should accept the database set on which the update operation will be performed, and the dictionary mapping the fields and the values to use for the update as the second one.
+As the `before_insert` callback gets called just before a record insertion, the `before_update` one is called just before a set of records is updated. The methods decorated with this helper should accept the database set on which the update operation will be performed, and the dictionary mapping the fields and the values to use for the update as the second one.
Here is a quick example:
```python
@@ -103,7 +103,7 @@ Notice that, since the first parameter is a database set, you can have more than
after\_update
-------------
-The `after_update` decorator is called just after the update of the set of records has happened. As for the `before_update` decorator, the methods decorated with this helper should accept the database set on which the update operation was performed as the first parameter, and the dictionary mapping the fields and the values used for the update as the second one.
+The `after_update` callback is called just after the update of the set of records has happened. As for the `before_update` decorator, the methods decorated with this helper should accept the database set on which the update operation was performed as the first parameter, and the dictionary mapping the fields and the values used for the update as the second one.
Here is a quick example:
```python
@@ -127,7 +127,7 @@ Now, if you update a set of records, you will see the printed values:
before\_delete
--------------
-The `before_delete` decorator is called just before the deletion of a set of records will be performed. The methods decorated with this helper should accept just one parameter that will be the database set on which the delete operation should be performed.
+The `before_delete` callback is called just before the deletion of a set of records will be performed. The methods decorated with this helper should accept just one parameter that will be the database set on which the delete operation should be performed.
Here is a quick example:
```python
@@ -151,7 +151,7 @@ Now, if you delete a set of records, you will see the printed values:
after\_delete
-------------
-The `after_delete` decorator is called just after the deletion of a set of records has happened. As for the `before_delete` decorator, the methods decorated with this helper should accept just one parameter that will be the database set on which the delete operation was performed.
+The `after_delete` callback is called just after the deletion of a set of records has happened. As for the `before_delete` decorator, the methods decorated with this helper should accept just one parameter that will be the database set on which the delete operation was performed.
Here is a quick example:
```python
@@ -174,38 +174,216 @@ Now, if you delete a set of records, you will see the printed values:
Notice that in the `after_delete` callbacks you will have the database set parameter, but the records corresponding to the query have been just deleted and won't be accessible anymore.
-Skip update callbacks
----------------------
+before\_save
+------------
-Sometimes you would need to skip the invocation of the update callbacks, for example when you want to mutually *touch* related entities during the update of one of the sides of the relation. In these cases, you can use the `update_naive` method on the database sets that won't trigger the callbacks invocation.
+*New in version 2.4*
+
+The `before_save` callback is invoked just before the execution of the `save` operation on the relevant record. The methods decorated with this helper should accept just one parameter that will be the record getting saved.
+Here is a quick example:
+
+```python
+class Product(Model):
+ name = Field.string()
+ price = Field.float(default=0.0)
+
+
+class CartElement(Model):
+ belongs_to("product")
+
+ quantity = Field.int(default=1)
+ price_denorm = Field.float(default=0.0)
+
+ @before_save
+ def _rebuild_price(self, row):
+ row.price_denorm = row.quantity * row.product.price
+```
+
+> **Note**: `save` triggers both `before_save` and the relevant insert or update callbacks. During the operation `before_save` will be invoked before the `before_insert` or `before_update` callbacks.
+
+after\_save
+-----------
+
+*New in version 2.4*
+
+The `after_save` callback is invoked just after the execution of the `save` operation on the relevant record. The methods decorated with this helper should accept just one parameter that will be the saved record.
+Here is a quick example:
+
+```python
+class User(Model):
+ email = Field()
+
+ @after_save
+ def _send_welcome_email(self, row):
+ # if it's a new user, send a welcome email
+ if row.has_changed_value("id"):
+ send_welcome_email(row.email)
+```
+
+> **Note**: `save` triggers both `after_save` and the relevant insert or update callbacks. During the operation `after_save` will be invoked after the `after_insert` or `after_update` callbacks.
+
+before\_destroy
+---------------
+
+*New in version 2.4*
+
+The `before_destroy` callback is invoked just before the execution of the `destroy` operation on the relevant record. The methods decorated with this helper should accept just one parameter that will be the record getting destroyed.
+Here is a quick example:
+
+```python
+class Product(Model):
+ name = Field.string()
+ price = Field.float(default=0.0)
+
+
+class CartElement(Model):
+ belongs_to("product")
+
+ quantity = Field.int(default=1)
+ price_denorm = Field.float(default=0.0)
+
+ @before_destroy
+ def _clear_element(self, row):
+ row.quantity = 0
+ row.price_denorm = 0
+```
+
+> **Note**: `destroy` triggers both `before_destroy` and `before_delete` callbacks. During the operation `before_destroy` will be invoked before the `before_delete` callback.
+
+after\_destroy
+--------------
+
+*New in version 2.4*
+
+The `after_destroy` callback is invoked just after the execution of the `destroy` operation on the relevant record. The methods decorated with this helper should accept just one parameter that will be the destroyed record.
+Here is a quick example:
+
+```python
+class Cart(Model):
+ has_many({"elements": "CartElement"})
+ updated_at = Field.datetime(default=now, update=now)
+
+
+class CartElement(Model):
+ belongs_to("cart")
+
+ @after_destroy
+ def _update_cart(self, row):
+ row.cart.save()
+```
+
+> **Note**: `destroy` triggers both `after_destroy` and `after_delete` callbacks. During the operation `after_destroy` will be invoked after the `after_delete` callback.
+
+before\_commit and after\_commit
+--------------------------------
+
+*New in version 2.4*
+
+Emmett also provides callbacks to watch `commit` events on [transactions](./connecting#transactions). Due to their nature, these callbacks behave differently from the other ones, and thus we need to make some observations:
+
+- code encapsulated in these callbacks **should not make any database operation**, as it might break the current transaction stack
+- these callbacks will be invoked in bulk once the transaction is about to be committed; thus the operation callback and the commit callback probably won't be called one right after the other, and the commit callback will receive all the operations that happened during the transaction, not just the last one
+
+> **Note:** commit callbacks get triggered only on the top transaction, not in the nested ones (savepoints).
+
+The methods decorated with these helpers should accept two parameters: the operation type and the operation context:
+
+```python
+@after_commit
+def my_method(self, op_type, ctx):
+ ...
+```
+
+The operation type is one of the values provided by the `TransactionOps` enum, and will be one of the following:
+
+- insert
+- update
+- delete
+- save
+- destroy
+
+Now, since `before_commit` and `after_commit`, as we saw, catch all the operations happening on the relevant model, these decorators offer additional filtering so you can watch only the relevant events. To listen for particular operations only, you can use the `TransactionOps` enum in combination with the `operation` method:
+
+```python
+from emmett.orm import TransactionOps
+
+@after_commit.operation(TransactionOps.insert)
+def my_method(self, ctx):
+ ...
+```
+
+as you can see, filtered operation callbacks won't need the operation type parameter.
+
+The operation context is represented by an object with the following attributes:
+
+| attribute | description |
+| --- | --- |
+| values | fields and values involved |
+| return\_value | return value of the operation |
+| dbset | query set involved (for update and delete operations) |
+| row | row involved (for save and destroy operations) |
+| changes | row changes that occurred (for save and destroy operations) |
+
+Now, let's see all of this with some examples.
+
+We might want to send a welcome email to a newly registered user, and we want to be sure the operation committed:
+
+```python
+class User(Model):
+ email = Field()
+
+ @after_commit.operation(TransactionOps.insert)
+ def _send_welcome_email(self, ctx):
+ my_queue_system.send_welcome_email(ctx.return_value)
+```
+
+or we might track activities over the records:
+
+```python
+class Todo(Model):
+ belongs_to("owner")
+
+ description = Field.text()
+ completed_at = Field.datetime()
+
+ @after_commit.operation(TransactionOps.save)
+ def _store_save_activity(self, ctx):
+ activity_type = "creation" if "id" in ctx.changes else "edit"
+ my_queue_system.store_activity(activity_type, ctx.row, ctx.changes)
+
+ @after_commit.operation(TransactionOps.destroy)
+ def _store_destroy_activity(self, ctx):
+ my_queue_system.store_activity("deletion", ctx.row, ctx.changes)
+```
+
+Skip callbacks
+--------------
+
+*Changed in version 2.4*
+
+Sometimes you may need to skip the invocation of callbacks, for example when you want to mutually *touch* related entities during the update of one side of the relation. In these cases, you can use the `skip_callbacks` parameter of the method you're calling.
Let's see this with an example:
```python
class User(Model):
+ has_one('profile')
+
email = Field()
changed_at = Field.datetime()
- has_one('profile')
- @after_update
- def touch_profile(self, dbset, fields):
- row = dbset.select().first()
- self.db(
- db.Profile.user == row.id
- ).update_naive(
- changed_at=row.changed_at
- )
+ @after_save
+ def touch_profile(self, row):
+ profile = row.profile()
+ profile.changed_at = row.changed_at
+ profile.save(skip_callbacks=True)
class Profile(Model):
belongs_to('user')
language = Field()
changed_at = Field.datetime()
- @after_update
- def touch_user(self, dbset, fields):
- row = dbset.select().first()
- self.db(
- db.User.id == row.user
- ).update_naive(
- changed_at=row.changed_at
- )
+ @after_save
+ def touch_user(self, row):
+ row.user.changed_at = row.changed_at
+ row.user.save(skip_callbacks=True)
```
diff --git a/docs/orm/migrations.md b/docs/orm/migrations.md
index 97413179..63f8cd11 100644
--- a/docs/orm/migrations.md
+++ b/docs/orm/migrations.md
@@ -428,6 +428,35 @@ db1 = Database(app, app.config.db1)
db2 = Database(app, app.config.db2)
```
+Set migration status manually
+-----------------------------
+
+*New in version 2.4*
+
+Sometimes you need to manually set the current schema revision in your database, without actually applying the involved migration(s). Some examples might be:
+
+- you handled migrations using a different system in the past and want to start using Emmett
+- you want to rewrite migrations, or *condense* several of them into a single one
+
+In such cases, you can use the `set` command of the migrations engine:
+
+```
+$ emmett migrations status
+> Current revision(s) for sqlite://dummy.db
+8422706ae767
+
+$ emmett migrations history
+> Migrations history:
+8422706ae767 -> 69a284b840cf (head), Generated migration
+9d6518b3cdc2 -> 8422706ae767, Generated migration
+ -> 9d6518b3cdc2, First migration
+
+$ emmett migrations set -r 69a284b840cf
+> Setting revision to 69a284b840cf against sqlite://dummy.db
+Do you want to continue? [y/N]: y
+> Updating schema revision from 8422706ae767 to 69a284b840cf
+> Succesfully set revision to 69a284b840cf: Generated migration
+```
DBMS support
------------
diff --git a/docs/orm/models.md b/docs/orm/models.md
index 6099e9c5..7fd269e3 100644
--- a/docs/orm/models.md
+++ b/docs/orm/models.md
@@ -41,8 +41,8 @@ db.posts
> **Note:**
> Accessing `Model` refers to the model itself, while `db.Model` refers to the table instance you created with your model. While these two classes shares the fields of your models, so accessing `Model.fieldname` and `db.Model.fieldname` or `db.tablename.fieldname` will produce the same result, they have different properties and methods, and you should remember this difference.
-
### Tables naming
+
Under default behavior, Emmett will create the table using the name of the class and making it plural, so that the class `Post` will create the table *posts*, `Comment` will create table *comments* and so on.
If you want to customize the name of the table, you can use the `tablename` attribute inside your model:
@@ -59,6 +59,7 @@ just ensure the name is valid for the DBMS you're using.
Fields
------
+
`Field` objects define your entity's properties, and will map the appropriate columns inside your tables, so in general you would write the name of the property and its type:
```python
@@ -75,25 +76,28 @@ Available type methods for Field definition are:
| bool | `bool` |
| int | `int` |
| float | `float` |
-| decimal(n,m) | `decimal.Decimal` |
+| decimal(precision,scale) | `decimal.Decimal` |
| date | `datetime.date` |
| time | `datetime.time` |
| datetime | `datetime.datetime` |
| password | `str` |
| upload | `str` |
-| list:string | `list` of `str` |
-| list:int | `list` of `int` |
-| json | `dict` or `list` |
-| jsonb | `dict` or `list` |
+| int\_list | `List[int]` |
+| string\_list | `List[str]` |
+| json | `Union[Dict,List]` |
+| jsonb | `Union[Dict,List]` |
+| geography(type,srid,dimension) | `str` |
+| geometry(type,srid,dimension) | `str` |
If you don't specify a type for the `Field` class, and create an instance directly, it will be set as *string* as default value.
Using the right field type ensure the right columns types inside your tables, and allows you to benefit from the default validation implemented by Emmett.
-> **Note:** some fields' types are engine specific, for instance `jsonb` field is valid only with PostgreSQL engine.
+> **Note:** some field types are engine specific: for instance, the `jsonb` field is valid only with the PostgreSQL engine, while the `geometry` and `geography` fields require spatial APIs.
Validation
----------
+
To implement a validation mechanism for your fields, you can use the `validation` parameter of the `Field` class, or the mapping `dict` with the name of the fields at the `validation` attribute inside your Model. Both method will produce the same result, just pick the one you prefer:
```python
@@ -131,6 +135,7 @@ While you can find the complete list of available validators in the [appropriate
> `{'allow': 'blank'}` or `{'allow': 'empty'}`
### Disable default validation
+
Sometimes you may want to disable the default validation implemented by Emmett. Depending on your needs, you have two different ways.
When you need to disable the default validation on a single `Field`, you can use the `auto_validation` parameter:
@@ -147,6 +152,7 @@ class MyModel(Model):
Default values
--------------
+
Emmett models have a `default_values` attribute that helps you to set the default value for the field on record insertions:
```python
@@ -165,6 +171,7 @@ The values defined in this way will be used on the insertion of new records in t
Update values
-------------
+
As for the `default_values` attribute we've seen before, `update_values` helps you to set the default value for the field on record updates:
```python
@@ -239,6 +246,7 @@ Emmett supports some advanced options on defining indexes, see the [advanced cha
Values representation
---------------------
+
Sometimes you need to give a better representation for the value of your entity, for example rendering dates or shows only a portion of a text field. In these cases, the `repr_values` attribute of your models will help:
```python
@@ -261,9 +269,11 @@ started = Field.datetime(representation=lambda row, value: prettydate(value))
Forms helpers
-------------
+
The `Model` attributes listed in this section are intended to be used for forms generation.
### Form labels
+
Labels are useful to produce good titles for your fields in forms:
```python
@@ -281,6 +291,7 @@ started = Field.datetime(label=T("Opening date:"))
```
### Form info
+
As for the labels, `form_info` attribute is useful to produce hints or helping blocks for your fields in forms:
```python
@@ -296,6 +307,7 @@ started = Field.datetime(info=T("some description here"))
```
### Widgets
+
Widgets are used to produce the relevant input part in the form produced from your model. Every `Field` object has a default widget depending on the type you defined, for example the *datetime* has an `` html tag of type *text*. When you need to customize the look of your input blocks in the form, you can use your own widgets and pass them to the model with the appropriate attribute:
```python
@@ -330,6 +342,7 @@ form_widgets = {
The setup helper
----------------
+
Sometimes you need to access your model attributes when defining other features, but, until now, we couldn't access the class or the instance itself. To avoid this problem, you can use the `setup` method of the model:
```python
@@ -340,9 +353,9 @@ def setup(self):
field = self.table.fieldname
```
-
Model methods
-------------
+
You can also define methods that will be available on the Model class itself. For instance, every Emmett model comes with some pre-defined methods, for example:
```python
@@ -350,12 +363,14 @@ MyModel.form()
```
will create the form for the entity defined in your model.
-Other methods pre-defined in Emmett are:
+Here is the list of all pre-defined methods in Emmett:
| method | description |
| --- | --- |
-| validate | validates the values passed as parameters (field=value) and return an `sdict` of errors (that would be empty if the validation passed) |
-| create | insert a new record with the values passed (field=value) if they pass the validation |
+| new | returns a new record instance with the specified parameters (field=value) |
+| create | inserts a new record with the specified parameters (field=value) if validation succeeds |
+| validate | validates the specified parameters (field=value) and returns an `sdict` of errors (empty if the validation passed) |
+| form | returns a new form instance for the current model |
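+
+A quick sketch of the first three methods in action (assuming a hypothetical `MyModel` with a required `name` field):
+
+```python
+record = MyModel.new(name="foo")    # record instance, not yet persisted
+errors = MyModel.validate(name="")  # sdict of validation errors, empty when valid
+rv = MyModel.create(name="foo")     # inserts the record if validation succeeds
+```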
But how can you define additional methods?
Let's say, for example that you want a shortcut in your `Notification` model to set all the records to be *read* for a specific user, without writing down the query manually every time:
@@ -372,6 +387,7 @@ class Notification(Model):
lambda n: n.user == user
).update(read=True)
```
+
now you can easily set user's notification as read:
```python
diff --git a/docs/orm/operations.md b/docs/orm/operations.md
index edc8b89d..5e4c941f 100644
--- a/docs/orm/operations.md
+++ b/docs/orm/operations.md
@@ -75,6 +75,48 @@ In fact, you can access the attributes of the record you just created:
We will see more about the `as_dict` method in the next paragraphs.
+### Spatial fields helpers
+
+Emmett provides some helpers for GIS fields (`geography` and `geometry` types) in order to simplify working with their values.
+
+Whenever you have a GIS column:
+
+```python
+class City(Model):
+ name = Field.string()
+ location = Field.geography("POINT")
+```
+
+you can use the helpers provided by the `emmett.orm.geo` module to produce field values:
+
+```python
+from emmett.orm import geo
+
+rv = City.create(
+ name="Hill Valley",
+ location=geo.Point(44, 12)
+)
+```
+
+Also, GIS field values are sub-classes of `str`, but they provide some additional attributes:
+
+```python
+>>> rv.id.location
+'POINT(44 12)'
+>>> rv.id.location.geometry
+'POINT'
+>>> rv.id.location.coordinates
+(44.0, 12.0)
+```
+
+On geometries representing collections, you also have the `groups` attribute:
+
+```python
+>>> mp = geo.MultiPoint((1, 1), (2, 2))
+>>> mp.groups
+('POINT(1.000000 1.000000)', 'POINT(2.000000 2.000000)')
+```
+
Making queries
--------------
@@ -274,6 +316,27 @@ db(Event.happens_at.year() == 1985)
### Engine specific operators
+#### GIS operators
+
+Emmett provides additional query operators specific to spatial extensions; engines providing these APIs include SpatiaLite and PostGIS. The following table describes Emmett's ORM methods:
+
+| operator | description |
+| --- | --- |
+| st\_asgeojson | returns a geometry as a GeoJSON element |
+| st\_astext | returns WKT representation of the geometry/geography |
+| st\_x | returns the X coordinate of a Point |
+| st\_y | returns the Y coordinate of a Point |
+| st\_distance | returns the distance between two geometry/geography values |
+| st\_simplify | returns a simplified version of a geometry (Douglas-Peucker) |
+| st\_simplifypreservetopology | returns a simplified and valid version of a geometry (Douglas-Peucker) |
+| st\_contains | returns true if no points of B lie in the exterior of A |
+| st\_equals | returns true if two geometries include the same set of points |
+| st\_intersects | returns true if two geometries intersect |
+| st\_overlaps | returns true if two geometries intersect and have the same dimension |
+| st\_touches | returns true if two geometries have at least one point in common, but their interiors do not intersect |
+| st\_within | returns true if no points of A lie in the exterior of B |
+| st\_dwithin | returns true if two geometries are within a given distance |
+
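+Assuming these operators are exposed as field methods within queries, here is a hypothetical sketch against the `City` model shown in the [spatial fields helpers](#spatial-fields-helpers) section above:
+
+```python
+from emmett.orm import geo
+
+# select cities whose location lies within 1000 units of a given point
+rows = db(City.location.st_dwithin(geo.Point(44, 12), 1000)).select()
+```
+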
#### PostgreSQL json operators
Emmett provides additional query operators specific to PostgreSQL engine. The following table describes the mapping between Emmett's ORM methods and the relevant PostgreSQL json/jsonb operators:
@@ -641,3 +704,149 @@ Here are two examples:
As you can see both of these methods return the number of record removed.
> **Note:** just like the `update_record`, the `delete_record` method requires you to select the `id` field in the rows.
+
+Using model objects
+-------------------
+
+Emmett also provides the ability to work directly with records, in addition to model-level operations.
+
+We will use the same `Event` model we presented in the above sections for the examples:
+
+```python
+class Event(Model):
+ name = Field(notnull=True)
+ location = Field(notnull=True)
+ participants = Field.int(default=0)
+ happens_at = Field.datetime()
+```
+
+Let's see all the available methods and steps in detail.
+
+### Model new method
+
+Every `Model` in Emmett provides a `new` method, which produces a clean record:
+
+```python
+event = Event.new(
+ name="Lightning",
+ location="Hill Valley"
+)
+event.happens_at = datetime(
+ 1955, 11, 12,
+ 22, 4, 0
+)
+```
+
+Records produced by the `new` method won't have primary key(s); their fields will be valued with defaults and the passed parameters.
+
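+For instance, a minimal sketch of what to expect (assuming an empty primary key on a non-persisted record):
+
+```python
+>>> event.id is None
+True
+>>> event.participants  # default value
+0
+```
+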
+### Record save method
+
+*New in version 2.4*
+
+Records produced with `Model.new` or selected with all the model fields will have a `save` method.
+
+The `save` method performs an `insert` or an `update` according to the record contents:
+
+```python
+# save() will produce an insert
+event = Event.new()
+event.save()
+# save() will produce an update
+event = Event.first()
+event.location = "New York"
+event.save()
+```
+
+> **Note:** unlike the `update` or `update_record` methods, where you specify which fields should be updated, the `save` method will overwrite all the fields with the current record contents
+
+The `save` method returns a boolean indicating whether the operation succeeded, unless you call `save(raise_on_error=True)`, which raises an exception on failure.
+
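+For example, a minimal sketch of both styles:
+
+```python
+event = Event.new(name="Lightning", location="Hill Valley")
+if not event.save():
+    ...  # handle the failed save
+# or let failures raise instead:
+event.save(raise_on_error=True)
+```
+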
+> **Note:** `save` will trigger both save callbacks and insert/update ones
+
+### Record destroy method
+
+*New in version 2.4*
+
+Records selected with all the model fields will have a `destroy` method.
+
+The `destroy` method performs a `delete` operation using the record's primary key(s):
+
+```python
+event = Event.first()
+event.destroy()
+```
+
+The `destroy` method returns a boolean indicating whether the operation succeeded, unless you call `destroy(raise_on_error=True)`, which raises an exception on failure.
+
+> **Note:** `destroy` will trigger both destroy callbacks and delete ones
+
+### Record refresh method
+
+*New in version 2.4*
+
+Records selected with all the model fields will have a `refresh` method.
+
+The `refresh` method performs a new selection of the record from the database and updates the current object accordingly:
+
+```python
+event = Event.first()
+event.refresh()
+```
+
+The `refresh` method returns a boolean indicating whether the operation succeeded.
+
+### Record changes
+
+*New in version 2.4*
+
+Records produced with `Model.new` or selected with all the model fields will track changes to their attributes between saves.
+
+Here we list attributes and methods provided by Emmett to deal with row changes:
+
+| name | type | description |
+| --- | --- | --- |
+| has\_changed | attribute | boolean stating whether the record has changed |
+| has\_changed\_value | method | returns a boolean stating whether the specified attribute has changed |
+| get\_value\_change | method | returns a tuple containing the original and new values for the specified attribute, or `None` if unchanged |
+| changes | attribute | returns a `sdict` with all the changed attributes and their values |
+
+And here is an example:
+
+```python
+>>> event = Event.first()
+>>> event.location = "New York"
+>>> event.has_changed
+True
+>>> event.has_changed_value("location")
+True
+>>> event.get_value_change("location")
+('Hill Valley', 'New York')
+>>> event.changes
+
+```
+
+### Record validations
+
+*New in version 2.4*
+
+Records produced with `Model.new` or selected with all the model fields will provide helpers for validation.
+
+Here we list attributes and methods provided by Emmett:
+
+| name | type | description |
+| --- | --- | --- |
+| is\_valid | attribute | boolean stating whether the record passes validations |
+| validation\_errors | attribute | returns a `sdict` with all the validation errors |
+
+And here is an example:
+
+```python
+>>> event = Event.new(name="Lightning", location="Hill Valley")
+>>> event.is_valid
+False
+>>> event.validation_errors
+
+>>> event.happens_at = datetime(1955, 11, 12, 22, 4, 0)
+>>> event.is_valid
+True
+```
diff --git a/docs/orm/virtuals.md b/docs/orm/virtuals.md
index 49eec8c9..dd965a97 100644
--- a/docs/orm/virtuals.md
+++ b/docs/orm/virtuals.md
@@ -7,7 +7,8 @@ Emmett provides different apis that can help you in these cases: let's see them
Computed fields
---------------
-*Changed in version 0.7*
+
+*Changed in version 2.4*
Sometimes you need some field values to be *computed* using other fields' values. Let's say, for example, that you have a table of items where you store the quantity and price for each of them. You often need the total value of the items you have in your store, and you don't want to compute this value every time in your application code.
@@ -21,19 +22,44 @@ class Item(Model):
quantity = Field.int()
total = Field.float()
- @compute('total')
+ @compute('total', watch=['price', 'quantity'])
def compute_total(self, fields):
return fields.price * fields.quantity
```
-As you can see, the `compute` decorator needs and accepts just one parameter: the name of the field where to store the result of the computation.
+As you can see, the `compute` decorator accepts the name of the field in which to store the result of the computation and an optional `watch` list of fields.
The function that performs the computation has to accept the operation fields as its first parameter, and it will be called both on inserts and updates.
> **Note:** `compute` decorated methods receives **only the fields' values involved in the operation**. This means that fields will contain only the values passed by the insert/update operation and the relative default values.
+### Operation fields and watch parameter
+
+Since computations are triggered on every update operation on your model, and such operations might involve several records, you might end up in conditions where the operation doesn't include all the fields required for the computation. For example, issuing this update instruction on the model above:
+
+```python
+Item.where(lambda i: i.quantity == 1).update(quantity=2)
+```
+
+would make it impossible to re-compute the `total` value for the involved records.
+
+The `watch` parameter is designed to avoid these conditions, since – under default behaviour – in an insert or update operation involving computations Emmett will:
+
+- execute all the computations without `watch` fields, ignoring the failing ones
+- execute all the computations whose `watch` fields are all included in the operation
+- raise an exception, preventing the operation from continuing, for those computations whose `watch` fields are only partially included
+
+Considering the example above:
+
+- the `total` computation will be executed for all the operations including both `price` and `quantity` fields
+- an operation with only one of the `price` or `quantity` fields cannot be executed
+- operations not involving `price` or `quantity` fields won't trigger the computation
+
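+A minimal sketch of these rules against the model above:
+
+```python
+# both watched fields involved: `total` gets re-computed
+Item.where(lambda i: i.quantity == 1).update(quantity=2, price=9.99)
+# only one watched field involved: raises, the operation is aborted
+Item.where(lambda i: i.quantity == 1).update(quantity=2)
+```
+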
+> **Note:** to handle complex cases where you need access to the single record's fields, we suggest using the records' `save` method and the relevant callbacks.
+
Virtual attributes
------------------
+
*Changed in version 1.0*
Virtual attributes are values returned by functions that will be injected to the involved rows every time you select them.
@@ -67,6 +93,7 @@ You can access the values as the common fields:
Virtual methods
---------------
+
*Changed in version 1.0*
Similarly to virtual attributes, these methods are helpers injected to the rows when you select them. Differently from virtual attributes, however, they will be methods indeed, and you should invoke them to access the value you're looking for.
diff --git a/docs/testing.md b/docs/testing.md
index 8d89e063..26b25962 100644
--- a/docs/testing.md
+++ b/docs/testing.md
@@ -1,5 +1,6 @@
Testing Emmett applications
===========================
+
*New in version 0.6*
> Untested code is broken code
@@ -144,3 +145,30 @@ and you also need to remember to commit or rollback changes:
db.commit()
db.rollback()
```
+
+Using migrations in tests
+-------------------------
+
+*New in version 2.0*
+
+Whenever you want to test database interactions within your applications, the relevant [migrations](./orm/migrations) should be applied before your tests run.
+
+Emmett's approach here is quite simple: it provides utilities to generate and apply a single composed runtime migration during tests. Let's see how it works with a *pytest* example:
+
+```python
+import pytest
+
+from emmett.orm.migrations.utils import generate_runtime_migration
+
+from myapp import db as _db
+
+
+@pytest.fixture(scope='function')
+def db():
+ migration = generate_runtime_migration(_db)
+ migration.up()
+ yield _db
+ migration.down()
+```
+
+As you can see, we called the `generate_runtime_migration` method with our application database instance, applied the generated migration before yielding the database instance, and reverted the migration immediately after. Every test function we write using this fixture will have a migrated database to test against.
diff --git a/docs/upgrading.md b/docs/upgrading.md
index d8142cec..c0134a06 100644
--- a/docs/upgrading.md
+++ b/docs/upgrading.md
@@ -13,6 +13,28 @@ Just as a remind, you can update Emmett using *pip*:
$ pip install -U emmett
```
+Version 2.4
+-----------
+
+The Emmett 2.4 release is highly focused on the included ORM.
+
+This release doesn't introduce any deprecations or breaking changes, but it does add some new features you might be interested in.
+
+### New features
+
+- The ability to use relative paths in [templates](./templates) `extend` and `include` blocks
+- GIS/spatial [fields](./orm/models#fields) and [operators](./orm/operations#engine-specific-operators) support in ORM
+- [Primary keys customisation](./orm/advanced#custom-primary-keys) support in ORM
+- A `watch` parameter to ORM [compute decorator](./orm/virtuals#computed-fields)
+- A [save method](./orm/operations#record-save-method) to rows and [relevant callbacks](./orm/callbacks#before_save) in ORM
+- A [destroy method](./orm/operations#record-destroy-method) to rows and [relevant callbacks](./orm/callbacks#before_destroy) in ORM
+- [Commit callbacks](./orm/callbacks#before_commit-and-after_commit) in ORM
+- [Changes tracking](./orm/operations#record-changes) to rows in ORM
+- The `set` command to [migrations](./orm/migrations) CLI
+- The ability to [skip callbacks](./orm/callbacks#skip-callbacks) in ORM
+
+Emmett 2.4 also introduces support for Python 3.10.
+
Version 2.3
-----------
diff --git a/emmett/__version__.py b/emmett/__version__.py
index ef6497d0..3d67cd6b 100644
--- a/emmett/__version__.py
+++ b/emmett/__version__.py
@@ -1 +1 @@
-__version__ = "2.3.2"
+__version__ = "2.4.0"
diff --git a/emmett/asgi/handlers.py b/emmett/asgi/handlers.py
index 996ffd6c..6a3373b7 100644
--- a/emmett/asgi/handlers.py
+++ b/emmett/asgi/handlers.py
@@ -14,13 +14,18 @@
import asyncio
import os
import re
+import time
from collections import OrderedDict
+from email.utils import formatdate
+from hashlib import md5
+from importlib import resources
from typing import Any, Awaitable, Callable, Optional, Tuple, Union
from ..ctx import RequestContext, WSContext, current
from ..debug import smart_traceback, debug_handler
-from ..http import HTTPResponse, HTTPFile, HTTP
+from ..http import HTTPBytes, HTTPResponse, HTTPFile, HTTP
+from ..libs.contenttype import contenttype
from ..utils import cachedprop
from ..wrappers.helpers import RequestCancelled
from ..wrappers.request import Request
@@ -160,6 +165,10 @@ class HTTPHandler(RequestHandler):
def _bind_router(self):
self.router = self.app._router_http
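+ #: internal assets metadata computed at startup (etag seed, last-modified header)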
+ self._internal_assets_md = (
+ str(int(time.time())),
+ formatdate(time.time(), usegmt=True)
+ )
def _configure_methods(self):
self.static_matcher = (
@@ -253,6 +262,21 @@ def _static_nolang_matcher(
async def _static_response(self, file_path: str) -> HTTPFile:
return HTTPFile(file_path)
+ async def _static_content(self, content: bytes, content_type: str) -> HTTPBytes:
+ content_len = str(len(content))
+ return HTTPBytes(
+ 200,
+ content,
+ headers={
+ 'content-type': content_type,
+ 'content-length': content_len,
+ 'last-modified': self._internal_assets_md[1],
+ 'etag': md5(
+ f"{self._internal_assets_md[0]}_{content_len}".encode("utf8")
+ ).hexdigest()
+ }
+ )
+
def _static_handler(
self,
scope: Scope,
@@ -263,11 +287,19 @@ def _static_handler(
#: handle internal assets
if path.startswith('/__emmett__'):
file_name = path[12:]
- static_file = os.path.join(
- os.path.dirname(__file__), '..', 'assets', file_name)
- if os.path.splitext(static_file)[1] == 'html':
+ if file_name.endswith(".html"):
return self._http_response(404)
- return self._static_response(static_file)
+ pkg = None
+ if '/' in file_name:
+ pkg, file_name = file_name.split('/', 1)
+ try:
+ file_contents = resources.read_binary(
+ f'emmett.assets.{pkg}' if pkg else 'emmett.assets',
+ file_name
+ )
+ except FileNotFoundError:
+ return self._http_response(404)
+ return self._static_content(file_contents, contenttype(file_name))
#: handle app assets
static_file, _ = self.static_matcher(path)
if static_file:
diff --git a/emmett/assets/__init__.py b/emmett/assets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/emmett/assets/debug/__init__.py b/emmett/assets/debug/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/emmett/cli.py b/emmett/cli.py
index 08d6855a..c85c2a78 100644
--- a/emmett/cli.py
+++ b/emmett/cli.py
@@ -428,14 +428,13 @@ def set_db_value(ctx, param, value):
@cli.group('migrations', short_help='Runs migration operations.')
@click.option(
- '--db', help='The db instance to use', callback=set_db_value,
- is_eager=True)
+ '--db', help='The db instance to use', callback=set_db_value, is_eager=True
+)
def migrations_cli(db):
pass
-@migrations_cli.command(
- 'status', short_help='Shows current database revision.')
+@migrations_cli.command('status', short_help='Shows current database revision.')
@click.option('--verbose', '-v', default=False, is_flag=True)
@pass_script_info
def migrations_status(info, verbose):
@@ -457,9 +456,12 @@ def migrations_history(info, range, verbose):
@migrations_cli.command(
- 'generate', short_help='Generate a new migration from application models.')
-@click.option('--message', '-m', default='Generated migration',
- help='The description for the new migration.')
+ 'generate', short_help='Generates a new migration from application models.'
+)
+@click.option(
+ '--message', '-m', default='Generated migration',
+ help='The description for the new migration.'
+)
@click.option('-head', default='head', help='The migration to generate from')
@pass_script_info
def migrations_generate(info, message, head):
@@ -469,9 +471,11 @@ def migrations_generate(info, message, head):
generate(app, dbs, message, head)
-@migrations_cli.command('new', short_help='Generate a new empty migration.')
-@click.option('--message', '-m', default='New migration',
- help='The description for the new migration.')
+@migrations_cli.command('new', short_help='Generates a new empty migration.')
+@click.option(
+ '--message', '-m', default='New migration',
+ help='The description for the new migration.'
+)
@click.option('-head', default='head', help='The migration to generate from')
@pass_script_info
def migrations_new(info, message, head):
@@ -517,6 +521,24 @@ def migrations_down(info, revision, dry_run):
down(app, dbs, revision, dry_run)
+@migrations_cli.command(
+ 'set', short_help='Overrides database revision with selected migration.'
+)
+@click.option('--revision', '-r', default='head', help='The migration to set.')
+@click.option(
+ '--auto-confirm',
+ default=False,
+ is_flag=True,
+ help='Skip asking confirmation.'
+)
+@pass_script_info
+def migrations_set(info, revision, auto_confirm):
+ from .orm.migrations.commands import set_revision
+ app = info.load_app()
+ dbs = info.load_db()
+ set_revision(app, dbs, revision, auto_confirm)
+
+
def main(as_module=False):
cli.main(prog_name="python -m emmett" if as_module else None)
diff --git a/emmett/forms.py b/emmett/forms.py
index 0d0ecf10..fe13e067 100644
--- a/emmett/forms.py
+++ b/emmett/forms.py
@@ -150,17 +150,19 @@ async def _load_input_files(self):
rv = sdict()
return rv
- async def _process(self):
+ def _validate_input(self):
+ for field in self.writable_fields:
+ value = self._get_input_val(field)
+ self._validate_value(field, value)
+
+ async def _process(self, write_defaults=True):
self._load_csrf()
self.input_params = await self._load_input_params()
self.input_files = await self._load_input_files()
# run processing if needed
if self._submitted:
self.processed = True
- # validate input
- for field in self.writable_fields:
- value = self._get_input_val(field)
- self._validate_value(field, value)
+ self._validate_input()
# custom validation
if not self.errors and callable(self.onvalidation):
self.onvalidation(self)
@@ -173,11 +175,16 @@ async def _process(self):
if self.csrf and not self.accepted:
self.formkey = current.session._csrf.gen_token()
# reset default values in form
- if not self.processed or (self.accepted and not self.keepvalues):
+ if (
+ write_defaults and (
+ not self.processed or (self.accepted and not self.keepvalues)
+ )
+ ):
for field in self.fields:
- default_value = field.default() if callable(field.default) \
- else field.default
- self.input_params[field.name] = default_value
+ self.input_params[field.name] = (
+ field.default() if callable(field.default) else
+ field.default
+ )
return self
def _render(self):
@@ -292,9 +299,9 @@ def __init__(
class ModelForm(BaseForm):
def __init__(
self,
- table: Table,
+ model: Type[Model],
record: Optional[Row] = None,
- record_id: Optional[int] = None,
+ record_id: Any = None,
fields: Union[Dict[str, List[str]], List[str]] = None,
exclude_fields: List[str] = [],
csrf: Union[str, bool] = "auto",
@@ -308,8 +315,12 @@ def __init__(
_method: str = "POST",
**attributes
):
- self.table = table
- self.record = record or table(record_id)
+ self.model = model._instance_()
+ self.table: Table = self.model.table
+ self.record = record or (
+ self.model.get(record_id) if record_id else
+ self.model.new()
+ )
#: build fields for form
fields_list_all = []
fields_list_writable = []
@@ -317,7 +328,7 @@ def __init__(
#: developer has selected specific fields
if not isinstance(fields, dict):
fields = {'writable': fields, 'readable': fields}
- for field in table:
+ for field in self.table:
if field.name not in fields['readable']:
continue
fields_list_all.append(field)
@@ -325,7 +336,7 @@ def __init__(
fields_list_writable.append(field)
else:
#: use table fields
- for field in table:
+ for field in self.table:
if field.name in exclude_fields:
continue
if not field.readable:
@@ -336,14 +347,11 @@ def __init__(
fields_list_all.append(field)
if field.writable:
fields_list_writable.append(field)
- # if field.type != 'id' and field.writable and \
- # field.name not in exclude_fields:
- # self.fields.append(field)
super().__init__(
fields=fields_list_all,
writable_fields=fields_list_writable,
csrf=csrf,
- id_prefix=table._tablename + "_",
+ id_prefix=self.table._tablename + "_",
formstyle=formstyle,
keepvalues=keepvalues,
onvalidation=onvalidation,
@@ -354,80 +362,80 @@ def __init__(
_method=_method
)
- def _validate_value(self, field, value):
- if field.type == "upload" and self.record:
+ def _get_id_value(self):
+ if len(self.model._fieldset_pk) > 1:
+ return tuple(self.record[pk] for pk in self.model.primary_keys)
+ return self.record[self.table._id.name]
+
+ def _validate_input(self):
+ record, fields = self.record.clone(), {
+ field.name: self._get_input_val(field)
+ for field in self.writable_fields
+ }
+ for field in filter(lambda f: f.type == "upload", self.writable_fields):
+ val = fields[field.name]
if (
- (value == b"" or value is None) and
+ (val == b"" or val is None) and
not self.input_params.get(field.name + "__del", False) and
self.record[field.name]
):
- return
- super()._validate_value(field, value)
+ fields.pop(field.name)
+ record.update(fields)
+ errors = record.validation_errors
+ for field in self.writable_fields:
+ if field.name in errors:
+ self.errors[field.name] = errors[field.name]
+ elif field.type == "upload":
+ self.files[field.name] = fields[field.name]
+ else:
+ self.params[field.name] = fields[field.name]
- async def _process(self):
+ async def _process(self, **kwargs):
#: send record id to validators if needed
current._dbvalidation_record_id_ = None
- if self.record:
- current._dbvalidation_record_id_ = self.record.id
+ if self.record._concrete:
+ current._dbvalidation_record_id_ = self._get_id_value()
#: load super `_process`
- # await Form._process(self)
- await super()._process()
- #: clear current and run additional operations for DAL
- del current._dbvalidation_record_id_
+ await super()._process(write_defaults=False)
+ #: additional record logic
if self.accepted:
- for field in self.writable_fields:
- #: handle uploads
- if field.type == "upload":
- upload = self.files[field.name]
- del_field = field.name + "__del"
- if not upload.filename:
- if self.input_params.get(del_field, False):
- self.params[field.name] = (
- self.table[field.name].default or ""
- )
- # TODO: we want to physically delete file?
- else:
- if self.record and self.record[field.name]:
- self.params[field.name] = self.record[field.name]
- continue
+ #: handle uploads
+ for field in filter(lambda f: f.type == "upload", self.writable_fields):
+ upload = self.files[field.name]
+ del_field = field.name + "__del"
+ if not upload.filename:
+ if self.input_params.get(del_field, False):
+ self.params[field.name] = self.table[field.name].default or ""
+ # TODO: do we want to physically delete file?
else:
- source_file, original_filename = (
- upload.stream, upload.filename
- )
- newfilename = field.store(
- source_file, original_filename, field.uploadfolder
- )
- if isinstance(field.uploadfield, str):
- self.params[field.uploadfield] = source_file.read()
- self.params[field.name] = newfilename
- #: add default values to hidden fields if needed
- if not self.record:
- fieldnames = [field.name for field in self.writable_fields]
- for field in self.table:
- if field.name not in fieldnames and field.compute is None:
- if field.default is not None:
- def_val = (
- field.default() if callable(field.default) else
- field.default
- )
- self.params[field.name] = def_val
- if self.record:
- self.record.update_record(**self.params)
- else:
- self.params.id = self.table.insert(**self.params)
+ if self.record._concrete and self.record[field.name]:
+ self.params[field.name] = self.record[field.name]
+ continue
+ else:
+ source_file, original_filename = upload.stream, upload.filename
+ newfilename = field.store(
+ source_file, original_filename, field.uploadfolder
+ )
+ if isinstance(field.uploadfield, str):
+ self.params[field.uploadfield] = source_file.read()
+ self.params[field.name] = newfilename
+ #: perform save
+ self.record.update(self.params)
+ if self.record.save():
+ self.params.id = self._get_id_value()
+ #: clear current from validation data
+ del current._dbvalidation_record_id_
+ #: cleanup inputs
if not self.processed or (self.accepted and not self.keepvalues):
for field in self.fields:
- if self.record:
- self.input_params[field.name] = self.record[field.name]
self.input_params[field.name] = field.formatter(
- self.input_params[field.name]
+ self.record[field.name]
)
- elif self.processed and not self.accepted and self.record:
+ elif self.processed and not self.accepted and self.record._concrete:
for field in self.writable_fields:
if field.type == "upload" and field.name not in self.params:
- self.input_params[field.name] = self.record[field.name]
self.input_params[field.name] = field.formatter(
- self.input_params[field.name]
+ self.record[field.name]
)
return self
@@ -716,7 +724,7 @@ def render(self):
def add_form_on_model(cls):
@wraps(cls)
def wrapped(model, *args, **kwargs):
- return cls(model.table, *args, **kwargs)
+ return cls(model, *args, **kwargs)
return wrapped
diff --git a/emmett/orm/__init__.py b/emmett/orm/__init__.py
index bf4d839e..8a8ffa65 100644
--- a/emmett/orm/__init__.py
+++ b/emmett/orm/__init__.py
@@ -1,12 +1,16 @@
from . import _patches
from .adapters import adapters as adapters_registry
from .base import Database
-from .objects import Field
+from .objects import Field, TransactionOps
from .models import Model
from .apis import (
belongs_to, refers_to, has_one, has_many,
compute, rowattr, rowmethod,
before_insert, before_update, before_delete,
+ before_save, before_destroy,
+ before_commit,
after_insert, after_update, after_delete,
+ after_save, after_destroy,
+ after_commit,
scope
)
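
The new exports make the save/destroy/commit lifecycle available next to the
existing insert/update/delete callbacks. A minimal sketch, assuming the
callback receives the row being saved, mirroring the existing callback style:

    from emmett.orm import Model, Field, before_save

    class Task(Model):
        title = Field()

        @before_save
        def _strip_title(self, row):
            row.title = (row.title or "").strip()
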
diff --git a/emmett/orm/adapters.py b/emmett/orm/adapters.py
index 885ea3fb..924c42d5 100644
--- a/emmett/orm/adapters.py
+++ b/emmett/orm/adapters.py
@@ -9,7 +9,10 @@
:license: BSD-3-Clause
"""
+import sys
+
from functools import wraps
+
from pydal.adapters import adapters
from pydal.adapters.mssql import (
MSSQL1,
@@ -27,9 +30,14 @@
PostgrePsycoNew,
PostgrePG8000New
)
+from pydal.helpers.classes import SQLALL
+from pydal.helpers.regex import REGEX_TABLE_DOT_FIELD
+from pydal.parsers import ParserMethodWrapper, for_type as _parser_for_type
+from pydal.representers import TReprMethodWrapper, for_type as _representer_for_type
from .engines import adapters
-from .objects import Field
+from .helpers import GeoFieldWrapper, typed_row_reference
+from .objects import Expression, Field, Row, IterRows
adapters._registry_.update({
@@ -56,11 +64,21 @@ def wrapped(*args, **kwargs):
def patch_adapter(adapter):
+ adapter.insert = _wrap_on_obj(insert, adapter)
+ adapter.iterselect = _wrap_on_obj(iterselect, adapter)
adapter.parse = _wrap_on_obj(parse, adapter)
- adapter._parse_expand_colnames = _wrap_on_obj(
- _parse_expand_colnames, adapter)
+ adapter.iterparse = _wrap_on_obj(iterparse, adapter)
+ adapter._parse_expand_colnames = _wrap_on_obj(_parse_expand_colnames, adapter)
adapter._parse = _wrap_on_obj(_parse, adapter)
+ adapter._expand_all_with_concrete_tables = _wrap_on_obj(
+ _expand_all_with_concrete_tables, adapter
+ )
+ adapter._select_wcols_inner = adapter._select_wcols
+ adapter._select_wcols = _wrap_on_obj(_select_wcols, adapter)
+ adapter._select_aux = _wrap_on_obj(_select_aux, adapter)
patch_dialect(adapter.dialect)
+ patch_parser(adapter.parser)
+ patch_representer(adapter.representer)
def patch_dialect(dialect):
@@ -69,20 +87,161 @@ def patch_dialect(dialect):
'firebird': _create_table_firebird
}
dialect.create_table = _wrap_on_obj(
- _create_table_map.get(dialect.adapter.dbengine, _create_table),
- dialect)
-
-
-def parse(adapter, rows, fields, colnames, blob_decode=True, cacheable=False):
+ _create_table_map.get(dialect.adapter.dbengine, _create_table), dialect
+ )
+ dialect.add_foreign_key_constraint = _wrap_on_obj(_add_fk_constraint, dialect)
+ dialect.drop_constraint = _wrap_on_obj(_drop_constraint, dialect)
+
+
+def patch_parser(parser):
+ parser.registered['reference'] = ParserMethodWrapper(
+ parser,
+ _parser_for_type('reference')(_parser_reference).f,
+ parser._before_registry_['reference']
+ )
+ parser.registered['geography'] = ParserMethodWrapper(
+ parser,
+ _parser_for_type('geography')(_parser_geo).f
+ )
+ parser.registered['geometry'] = ParserMethodWrapper(
+ parser,
+ _parser_for_type('geometry')(_parser_geo).f
+ )
+
+
+def patch_representer(representer):
+ representer.registered_t['reference'] = TReprMethodWrapper(
+ representer,
+ _representer_for_type('reference')(_representer_reference),
+ representer._tbefore_registry_['reference']
+ )
+
+
+def insert(adapter, table, fields):
+ query = adapter._insert(table, fields)
+ try:
+ adapter.execute(query)
+ except:
+ e = sys.exc_info()[1]
+ if hasattr(table, '_on_insert_error'):
+ return table._on_insert_error(table, fields, e)
+ raise e
+ if not table._id:
+ id = {
+ field.name: val for field, val in fields
+ if field.name in table._primarykey
+ } or None
+ elif table._id.type == 'id':
+ id = adapter.lastrowid(table)
+ else:
+ id = {field.name: val for field, val in fields}.get(table._id.name)
+ rid = typed_row_reference(id, table)
+ return rid
+
+
+def iterselect(adapter, query, fields, attributes):
+ colnames, sql = adapter._select_wcols(query, fields, **attributes)
+ return adapter.iterparse(sql, fields, colnames, **attributes)
+
+
+def _expand_all_with_concrete_tables(adapter, fields, tabledict):
+ new_fields, concrete_tables = [], []
+ for item in fields:
+ if isinstance(item, SQLALL):
+ new_fields += item._table
+ concrete_tables.append(item._table)
+ elif isinstance(item, str):
+ m = REGEX_TABLE_DOT_FIELD.match(item)
+ if m:
+ tablename, fieldname = m.groups()
+ new_fields.append(adapter.db[tablename][fieldname])
+ else:
+ new_fields.append(Expression(adapter.db, lambda item=item: item))
+ else:
+ new_fields.append(item)
+ # ## if no fields specified take them all from the requested tables
+ if not new_fields:
+ for table in tabledict.values():
+ for field in table:
+ new_fields.append(field)
+ concrete_tables.append(table)
+ return new_fields, concrete_tables
+
+
+def _select_wcols(
+ adapter,
+ query,
+ fields,
+ left=False,
+ join=False,
+ distinct=False,
+ orderby=False,
+ groupby=False,
+ having=False,
+ limitby=False,
+ orderby_on_limitby=True,
+ for_update=False,
+ outer_scoped=[],
+ **kwargs
+):
+ return adapter._select_wcols_inner(
+ query,
+ fields,
+ left=left,
+ join=join,
+ distinct=distinct,
+ orderby=orderby,
+ groupby=groupby,
+ having=having,
+ limitby=limitby,
+ orderby_on_limitby=orderby_on_limitby,
+ for_update=for_update,
+ outer_scoped=outer_scoped
+ )
+
+
+def _select_aux(adapter, sql, fields, attributes, colnames):
+ rows = adapter._select_aux_execute(sql)
+ if isinstance(rows, tuple):
+ rows = list(rows)
+ limitby = attributes.get('limitby', None) or (0,)
+ rows = adapter.rowslice(rows, limitby[0], None)
+ return adapter.parse(
+ rows,
+ fields,
+ colnames,
+ concrete_tables=attributes.get('_concrete_tables', [])
+ )
+
+
+def parse(adapter, rows, fields, colnames, **options):
fdata, tables = _parse_expand_colnames(adapter, fields)
new_rows = [
- _parse(adapter, row, fdata, tables, fields, colnames, blob_decode)
- for row in rows
+ _parse(
+ adapter,
+ row,
+ fdata,
+ tables,
+ options['concrete_tables'],
+ fields,
+ colnames,
+ options.get('blob_decode', True)
+ ) for row in rows
]
rowsobj = adapter.db.Rows(adapter.db, new_rows, colnames, rawrows=rows)
return rowsobj
+def iterparse(adapter, sql, fields, colnames, **options):
+ return IterRows(
+ adapter.db,
+ sql,
+ fields,
+ options.get('_concrete_tables', []),
+ colnames
+ )
+
+
def _parse_expand_colnames(adapter, fieldlist):
rv, tables = [], {}
for field in fieldlist:
@@ -98,8 +257,10 @@ def _parse_expand_colnames(adapter, fieldlist):
return rv, tables
-def _parse(adapter, row, fdata, tables, fields, colnames, blob_decode):
- new_row = _build_newrow_wtables(adapter, tables)
+def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_decode):
+ new_row, rows_cls, rows_accum = _build_newrow_wtables(
+ adapter, tables, concrete_tables
+ )
extras = adapter.db.Row()
#: let's loop over columns
for (idx, colname) in enumerate(colnames):
@@ -109,7 +270,7 @@ def _parse(adapter, row, fdata, tables, fields, colnames, blob_decode):
#: do we have a real column?
if fd:
(tablename, fieldname, table, field, ft, fit) = fd
- colset = new_row[tablename]
+ colset = rows_accum[tablename]
#: parse value
value = adapter.parse_value(value, fit, ft, blob_decode)
if field.filter_out:
@@ -118,23 +279,30 @@ def _parse(adapter, row, fdata, tables, fields, colnames, blob_decode):
#: otherwise we set the value in extras
else:
value = adapter.parse_value(
- value, fields[idx]._itype, fields[idx].type, blob_decode)
+ value, fields[idx]._itype, fields[idx].type, blob_decode
+ )
extras[colname] = value
new_column_name = adapter._regex_select_as_parser(colname)
if new_column_name is not None:
column_name = new_column_name.groups(0)
new_row[column_name[0]] = value
+ for key, val in rows_cls.items():
+ new_row[key] = val(rows_accum[key])
#: add extras if needed (eg. operations results)
if extras:
new_row['_extra'] = extras
return new_row
-def _build_newrow_wtables(adapter, tables):
- rv = adapter.db.Row()
+def _build_newrow_wtables(adapter, tables, concrete_tables):
+ row, cls_map, accum = adapter.db.Row(), {}, {}
for name, table in tables.items():
- rv[name] = table._model_._rowclass_()
- return rv
+ cls_map[name] = adapter.db.Row
+ accum[name] = {}
+ for table in concrete_tables:
+ cls_map[table._tablename] = table._model_._rowclass_
+ accum[table._tablename] = {}
+ return row, cls_map, accum
def _create_table(dialect, tablename, fields):
@@ -168,6 +336,56 @@ def _create_table_firebird(dialect, tablename, fields):
return rv
+def _add_fk_constraint(
+ dialect,
+ name,
+ table_local,
+ table_foreign,
+ columns_local,
+ columns_foreign,
+ on_delete
+):
+ return (
+ f"ALTER TABLE {dialect.quote(table_local)} "
+ f"ADD CONSTRAINT {dialect.quote(name)} "
+ f"FOREIGN KEY ({','.join([dialect.quote(v) for v in columns_local])}) "
+ f"REFERENCES {dialect.quote(table_foreign)}"
+ f"({','.join([dialect.quote(v) for v in columns_foreign])}) "
+ f"ON DELETE {on_delete};"
+ )
+
+
+def _drop_constraint(dialect, name, table):
+ return f"ALTER TABLE {dialect.quote(table)} DROP CONSTRAINT {dialect.quote(name)};"
+
+
+def _parser_reference(parser, value, referee):
+ if '.' not in referee:
+ value = typed_row_reference(value, parser.adapter.db[referee])
+ return value
+
+
+def _parser_geo(parser, value):
+ return GeoFieldWrapper(value)
+
+
+def _representer_reference(representer, value, referenced):
+ rtname, _, rfname = referenced.partition('.')
+ rtable = representer.adapter.db[rtname]
+ if not rfname and rtable._id:
+ rfname = rtable._id.name
+ if not rfname:
+ return value
+ rtype = rtable[rfname].type
+ if isinstance(value, Row) and getattr(value, "_concrete", False):
+ value = value[(value._model.primary_keys or ["id"])[0]]
+ if rtype in ('id', 'integer'):
+ return str(int(value))
+ if rtype == 'string':
+ return str(value)
+ return representer.adapter.represent(value, rtype)
+
+
def _initialize(adapter, *args, **kwargs):
adapter._find_work_folder()
adapter._connection_manager.configure(
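
With the patched insert above, db.<table>.insert() now returns a typed row
reference rather than a plain integer: it still compares and casts like the
primary key value, but attribute access lazily fetches the referenced record.
An illustrative sketch, assuming a conventional id-backed tasks table:

    rid = db.tasks.insert(title="write docs")
    int(rid)    # the integer primary key, as before
    rid.title   # triggers a lazy select of the inserted record
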
diff --git a/emmett/orm/apis.py b/emmett/orm/apis.py
index cfab25a0..b0c934ac 100644
--- a/emmett/orm/apis.py
+++ b/emmett/orm/apis.py
@@ -10,6 +10,9 @@
"""
from collections import OrderedDict
+from typing import List
+
+from .errors import MissingFieldsForCompute
from .helpers import Reference, Callback
@@ -48,8 +51,9 @@ def refobj(self):
class compute(object):
_inst_count_ = 0
- def __init__(self, field_name):
+ def __init__(self, field_name: str, watch: List[str] = []):
self.field_name = field_name
+ self.watch_fields = set(watch)
self._inst_count_ = compute._inst_count_
compute._inst_count_ += 1
@@ -57,6 +61,19 @@ def __call__(self, f):
self.f = f
return self
+ def compute(self, model, op_row):
+ if self.watch_fields:
+ row_keyset = set(op_row.keys())
+ if row_keyset & self.watch_fields:
+ if not self.watch_fields.issubset(row_keyset):
+ raise MissingFieldsForCompute(
+ f"Compute field '{self.field_name}' missing required "
+ f"({','.join(self.watch_fields - row_keyset)})"
+ )
+ else:
+ return
+ return self.f(model, op_row)
+
class rowattr(object):
_inst_count_ = 0
@@ -99,6 +116,40 @@ def after_delete(f):
return Callback(f, '_after_delete')
+def before_save(f):
+ return Callback(f, '_before_save')
+
+
+def after_save(f):
+ return Callback(f, '_after_save')
+
+
+def before_destroy(f):
+ return Callback(f, '_before_destroy')
+
+
+def after_destroy(f):
+ return Callback(f, '_after_destroy')
+
+
+def before_commit(f):
+ return Callback(f, '_before_commit')
+
+
+def after_commit(f):
+ return Callback(f, '_after_commit')
+
+
+def _commit_callback_op(kind, op):
+ def _deco(f):
+ return Callback(f, f'_{kind}_commit_{op}')
+ return _deco
+
+
+before_commit.operation = lambda op: _commit_callback_op('before', op)
+after_commit.operation = lambda op: _commit_callback_op('after', op)
+
+
class scope(object):
def __init__(self, name):
self.name = name
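
The watch argument added to compute ties a computation to its input columns:
operations touching only part of the watched set raise MissingFieldsForCompute
instead of silently computing from incomplete data, while operations touching
none of them skip the computation entirely. A minimal sketch:

    from emmett.orm import Model, Field, compute

    class Invoice(Model):
        net = Field.float()
        tax = Field.float()
        gross = Field.float()

        @compute('gross', watch=['net', 'tax'])
        def _compute_gross(self, row):
            return row.net + row.tax
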
diff --git a/emmett/orm/base.py b/emmett/orm/base.py
index 5e51385c..006f2a3c 100644
--- a/emmett/orm/base.py
+++ b/emmett/orm/base.py
@@ -197,7 +197,8 @@ def define_models(self, *models):
args = dict(
migrate=obj.migrate,
format=obj.format,
- table_class=Table
+ table_class=Table,
+ primarykey=obj.primary_keys or ['id']
)
model.table = self.define_table(
obj.tablename, *obj.fields, **args
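
define_models now forwards the model's declared primary keys to the table
definition, defaulting to the usual id column. A minimal sketch of a model
relying on the primary_keys attribute read above:

    from emmett.orm import Model, Field

    class Country(Model):
        primary_keys = ["code"]

        code = Field()
        name = Field()
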
diff --git a/emmett/orm/engines/postgres.py b/emmett/orm/engines/postgres.py
index b1804a51..dcc47f88 100644
--- a/emmett/orm/engines/postgres.py
+++ b/emmett/orm/engines/postgres.py
@@ -123,8 +123,7 @@ def _jsonb(self, value):
return serializers.json(value)
-@adapters.register_for('postgres')
-class PostgresAdapter(PostgreBoolean):
+class PostgresAdapterMixin:
def _load_dependencies(self):
super()._load_dependencies()
self.dialect = JSONBPostgreDialect(self)
@@ -137,32 +136,39 @@ def _config_json(self):
def _mock_reconnect(self):
pass
+ def _insert(self, table, fields):
+ self._last_insert = None
+ if fields:
+ retval = None
+ if getattr(table, "_id", None):
+ self._last_insert = (table._id, 1)
+ retval = table._id._rname
+ return self.dialect.insert(
+ table._rname,
+ ','.join(el[0]._rname for el in fields),
+ ','.join(self.expand(v, f.type) for f, v in fields),
+ retval
+ )
+ return self.dialect.insert_empty(table._rname)
+
+ def lastrowid(self, table):
+ if self._last_insert:
+ return self.cursor.fetchone()[0]
+ sequence_name = table._sequence_name
+ self.execute("SELECT currval(%s);" % self.adapt(sequence_name))
+ return self.cursor.fetchone()[0]
-@adapters.register_for('postgres:psycopg2')
-class PostgresPsycoPG2Adapter(PostgrePsycoBoolean):
- def _load_dependencies(self):
- super()._load_dependencies()
- self.dialect = JSONBPostgreDialect(self)
- self.parser = JSONBPostgreParser(self)
- self.representer = JSONBPostgreRepresenter(self)
- def _config_json(self):
- pass
-
- def _mock_reconnect(self):
- pass
+@adapters.register_for('postgres')
+class PostgresAdapter(PostgresAdapterMixin, PostgreBoolean):
+ pass
-@adapters.register_for('postgres:pg8000')
-class PostgresPG8000Adapter(PostgrePG8000Boolean):
- def _load_dependencies(self):
- super()._load_dependencies()
- self.dialect = JSONBPostgreDialect(self)
- self.parser = JSONBPostgreParser(self)
- self.representer = JSONBPostgreRepresenter(self)
+@adapters.register_for('postgres:psycopg2')
+class PostgresPsycoPG2Adapter(PostgresAdapterMixin, PostgrePsycoBoolean):
+ pass
- def _config_json(self):
- pass
- def _mock_reconnect(self):
- pass
+@adapters.register_for('postgres:pg8000')
+class PostgresPG8000Adapter(PostgresAdapterMixin, PostgrePG8000Boolean):
+ pass
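
Besides deduplicating the three adapters into one mixin, _insert now asks the
dialect for a RETURNING clause whenever the table has an id column, so
lastrowid() reads the fresh key straight from the cursor. A rough sketch of
the effect, with hypothetical table and sequence names:

    # generated for an id-backed table:
    #   INSERT INTO "tasks" ("title") VALUES ('docs') RETURNING "id";
    # lastrowid() then reads cursor.fetchone()[0] instead of issuing
    #   SELECT currval('tasks_id_seq');
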
diff --git a/emmett/orm/engines/sqlite.py b/emmett/orm/engines/sqlite.py
index d33ec045..a979ce76 100644
--- a/emmett/orm/engines/sqlite.py
+++ b/emmett/orm/engines/sqlite.py
@@ -17,9 +17,24 @@
@adapters.register_for('sqlite', 'sqlite:memory')
class SQLite(_SQLite):
def _initialize_(self, do_connect):
- super(SQLite, self)._initialize_(do_connect)
+ super()._initialize_(do_connect)
self.driver_args['isolation_level'] = None
def begin(self, lock_type=None):
statement = 'BEGIN %s;' % lock_type if lock_type else 'BEGIN;'
self.execute(statement)
+
+ def delete(self, table, query):
+ deleted = (
+ [x[table._id.name] for x in self.db(query).select(table._id)]
+ if table._id else []
+ )
+ counter = super(_SQLite, self).delete(table, query)
+ if table._id and counter:
+ for field in table._referenced_by:
+ if (
+ field.type == 'reference ' + table._dalname and
+ field.ondelete == 'CASCADE'
+ ):
+ self.db(field.belongs(deleted)).delete()
+ return counter
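
The SQLite delete override emulates ON DELETE CASCADE at the ORM level by
collecting the ids about to be removed and then deleting the rows referencing
them. An illustrative sketch, assuming default table naming and the default
CASCADE setting on belongs_to references:

    from emmett.orm import Model, Field, belongs_to

    class User(Model):
        name = Field()

    class Thing(Model):
        belongs_to('user')  # reference field; ondelete defaults to CASCADE

    db(db.users.id == 1).delete()  # now also removes that user's things
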
diff --git a/emmett/orm/errors.py b/emmett/orm/errors.py
index d3077a51..691bf876 100644
--- a/emmett/orm/errors.py
+++ b/emmett/orm/errors.py
@@ -13,3 +13,27 @@
class MaxConnectionsExceeded(RuntimeError):
def __init__(self):
super().__init__('Exceeded maximum connections')
+
+
+class MissingFieldsForCompute(RuntimeError):
+ ...
+
+
+class SaveException(RuntimeError):
+ ...
+
+
+class InsertFailureOnSave(SaveException):
+ ...
+
+
+class UpdateFailureOnSave(SaveException):
+ ...
+
+
+class DestroyException(RuntimeError):
+ ...
+
+
+class ValidationError(RuntimeError):
+ ...
diff --git a/emmett/orm/geo.py b/emmett/orm/geo.py
new file mode 100644
index 00000000..8ecb3e6a
--- /dev/null
+++ b/emmett/orm/geo.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+"""
+ emmett.orm.geo
+ --------------
+
+ Provides geographic facilities.
+
+ :copyright: 2014 Giovanni Barillari
+ :license: BSD-3-Clause
+"""
+
+from .helpers import GeoFieldWrapper
+
+
+def Point(x, y):
+ return GeoFieldWrapper("POINT(%f %f)" % (x, y))
+
+
+def Line(*coordinates):
+ return GeoFieldWrapper(
+ "LINESTRING(%s)" % ','.join("%f %f" % point for point in coordinates)
+ )
+
+
+def Polygon(*coordinates_groups):
+ try:
+ if not isinstance(coordinates_groups[0][0], (tuple, list)):
+ coordinates_groups = (coordinates_groups,)
+ except Exception:
+ pass
+ return GeoFieldWrapper(
+ "POLYGON(%s)" % (
+ ",".join([
+ "(%s)" % ",".join("%f %f" % point for point in group)
+ for group in coordinates_groups
+ ])
+ )
+ )
+
+
+def MultiPoint(*points):
+ return GeoFieldWrapper(
+ "MULTIPOINT(%s)" % (
+ ",".join([
+ "(%f %f)" % point for point in points
+ ])
+ )
+ )
+
+
+def MultiLine(*lines):
+ return GeoFieldWrapper(
+ "MULTILINESTRING(%s)" % (
+ ",".join([
+ "(%s)" % ",".join("%f %f" % point for point in line)
+ for line in lines
+ ])
+ )
+ )
+
+
+def MultiPolygon(*polygons):
+ return GeoFieldWrapper(
+ "MULTIPOLYGON(%s)" % (
+ ",".join([
+ "(%s)" % (
+ ",".join([
+ "(%s)" % ",".join("%f %f" % point for point in group)
+ for group in polygon
+ ])
+ ) for polygon in polygons
+ ])
+ )
+ )
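
The new geo module builds WKT strings wrapped in GeoFieldWrapper, so the
returned values behave as plain strings while exposing the parsed data. For
example:

    from emmett.orm import geo

    pt = geo.Point(1.5, 2.5)
    str(pt)   # 'POINT(1.500000 2.500000)'

    ln = geo.Line((0, 0), (1, 1))
    str(ln)   # 'LINESTRING(0.000000 0.000000,1.000000 1.000000)'
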
diff --git a/emmett/orm/helpers.py b/emmett/orm/helpers.py
index 171fd2f0..a3fbead5 100644
--- a/emmett/orm/helpers.py
+++ b/emmett/orm/helpers.py
@@ -9,17 +9,233 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
+import operator
import re
import time
-from functools import wraps
+from functools import reduce, wraps
+from typing import TYPE_CHECKING, Any, Callable
+
from pydal._globals import THREAD_LOCAL
-from pydal.helpers.classes import Reference as _IDReference, ExecutionHandler
+from pydal.helpers.classes import ExecutionHandler
from pydal.objects import Field as _Field
from ..datastructures import sdict
from ..utils import cachedprop
+if TYPE_CHECKING:
+ from .objects import Table
+
+
+class RowReferenceMeta:
+ __slots__ = ['table', 'pk', 'caster']
+
+ def __init__(self, table: Table, caster: Callable[[Any], Any]):
+ self.table = table
+ self.pk = table._id.name
+ self.caster = caster
+
+ def fetch(self, val):
+ return self.table._db(self.table._id == self.caster(val)).select(
+ limitby=(0, 1),
+ orderby_on_limitby=False
+ ).first()
+
+
+class RowReferenceMultiMeta:
+ __slots__ = ['table', 'pks', 'pks_idx', 'caster', 'casters']
+ _casters = {'integer': int, 'string': str}
+
+ def __init__(self, table: Table) -> None:
+ self.table = table
+ self.pks = list(table._primarykey)
+ self.pks_idx = {key: idx for idx, key in enumerate(self.pks)}
+ self.caster = tuple
+ self.casters = {pk: self._casters[table[pk].type] for pk in self.pks}
+
+ def fetch(self, val):
+ query = reduce(
+ operator.and_, [
+ self.table[pk] == self.casters[pk](self.caster.__getitem__(val, idx))
+ for pk, idx in self.pks_idx.items()
+ ]
+ )
+ return self.table._db(query).select(
+ limitby=(0, 1),
+ orderby_on_limitby=False
+ ).first()
+
+
+class RowReferenceMixin:
+ def _allocate_(self):
+ if not self._refrecord:
+ self._refrecord = self._refmeta.fetch(self)
+ if not self._refrecord:
+ raise RuntimeError(
+ "Using a recursive select but encountered a broken " +
+ "reference: %s %r" % (self._table, self)
+ )
+
+ def __getattr__(self, key: str) -> Any:
+ if key == self._refmeta.pk:
+ return self._refmeta.caster(self)
+ if key in self._refmeta.table:
+ self._allocate_()
+ if self._refrecord:
+ return self._refrecord.get(key, None)
+ return None
+
+    def get(self, key: str, default: Any = None) -> Any:
+        value = self.__getattr__(key)
+        return default if value is None else value
+
+ def __setattr__(self, key: str, value: Any):
+ if key.startswith('_'):
+ self._refmeta.caster.__setattr__(self, key, value)
+ return
+ self._allocate_()
+ self._refrecord[key] = value
+
+ def __getitem__(self, key):
+ if key == self._refmeta.pk:
+ return self._refmeta.caster(self)
+ self._allocate_()
+ return self._refrecord.get(key, None)
+
+ def __setitem__(self, key, value):
+ self._allocate_()
+ self._refrecord[key] = value
+
+ def __pure__(self):
+ return self._refmeta.caster(self)
+
+ def __repr__(self) -> str:
+ return repr(self._refmeta.caster(self))
+
+
+class RowReferenceInt(RowReferenceMixin, int):
+ def __new__(cls, id, table: Table, *args: Any, **kwargs: Any):
+ rv = super().__new__(cls, id, *args, **kwargs)
+ int.__setattr__(rv, '_refmeta', RowReferenceMeta(table, int))
+ int.__setattr__(rv, '_refrecord', None)
+ return rv
+
+
+class RowReferenceStr(RowReferenceMixin, str):
+ def __new__(cls, id, table: Table, *args: Any, **kwargs: Any):
+ rv = super().__new__(cls, id, *args, **kwargs)
+ str.__setattr__(rv, '_refmeta', RowReferenceMeta(table, str))
+ str.__setattr__(rv, '_refrecord', None)
+ return rv
+
+
+class RowReferenceMulti(RowReferenceMixin, tuple):
+ def __new__(cls, id, table: Table, *args: Any, **kwargs: Any):
+ tupid = tuple(id[key] for key in table._primarykey)
+ rv = super().__new__(cls, tupid, *args, **kwargs)
+ tuple.__setattr__(rv, '_refmeta', RowReferenceMultiMeta(table))
+ tuple.__setattr__(rv, '_refrecord', None)
+ return rv
+
+ def __getattr__(self, key: str) -> Any:
+ if key in self._refmeta.pks:
+ return self._refmeta.casters[key](
+ tuple.__getitem__(self, self._refmeta.pks_idx[key])
+ )
+ if key in self._refmeta.table:
+ self._allocate_()
+ if self._refrecord:
+ return self._refrecord.get(key, None)
+ return None
+
+ def __getitem__(self, key):
+ if key in self._refmeta.pks:
+ return self._refmeta.casters[key](
+ tuple.__getitem__(self, self._refmeta.pks_idx[key])
+ )
+ self._allocate_()
+ return self._refrecord.get(key, None)
+
+
+class GeoFieldWrapper(str):
+ _rule_parens = re.compile(r"^(\(+)(?:.+)$")
+ _json_geom_map = {
+ "POINT": "Point",
+ "LINESTRING": "LineString",
+ "POLYGON": "Polygon",
+ "MULTIPOINT": "MultiPoint",
+ "MULTILINESTRING": "MultiLineString",
+ "MULTIPOLYGON": "MultiPolygon"
+ }
+
+ def __new__(cls, value, *args: Any, **kwargs: Any):
+ geometry, raw_coords = value.strip()[:-1].split("(", 1)
+ rv = super().__new__(cls, value, *args, **kwargs)
+ coords = cls._parse_coords_block(raw_coords)
+ str.__setattr__(rv, '_geometry', geometry.strip())
+ str.__setattr__(rv, '_coordinates', coords)
+ return rv
+
+ @classmethod
+ def _parse_coords_block(cls, v):
+ groups = []
+ parens_match = cls._rule_parens.match(v)
+ parens = parens_match.group(1) if parens_match else ''
+ if parens:
+ for element in v.split(parens):
+ if not element:
+ continue
+ element = element.strip()
+ shift = -2 if element.endswith(",") else -1
+ groups.append(f"{parens}{element}"[1:shift])
+ if not groups:
+ return cls._parse_coords_group(v)
+ return tuple(
+ cls._parse_coords_block(group) for group in groups
+ )
+
+ @staticmethod
+ def _parse_coords_group(v):
+ accum = []
+ for element in v.split(","):
+ accum.append(tuple(float(v) for v in element.split(" ")))
+ return tuple(accum) if len(accum) > 1 else accum[0]
+
+ def _repr_coords(self, val=None):
+ val = val or self._coordinates
+ if isinstance(val[0], tuple):
+ accum = []
+ for el in val:
+ inner, plevel = self._repr_coords(el)
+ inner = f"({inner})" if not plevel else inner
+ accum.append(inner)
+ return ",".join(accum), False
+ return "%f %f" % val, True
+
+ @property
+ def geometry(self):
+ return self._geometry
+
+ @property
+ def coordinates(self):
+ return self._coordinates
+
+ @property
+ def groups(self):
+ if not self._geometry.startswith("MULTI"):
+ return tuple()
+ return tuple(
+ self.__class__(f"{self._geometry[5:]}({self._repr_coords(coords)[0]})")
+ for coords in self._coordinates
+ )
+
+ def __json__(self):
+ return {
+ "type": self._json_geom_map[self._geometry],
+ "coordinates": self._coordinates
+ }
+
class Reference(object):
def __init__(self, *args, **params):
@@ -78,8 +294,11 @@ def table_name(self):
return self.model_instance.tablename
@property
- def field_instance(self):
- return self.table[self.field]
+ def fields_instances(self):
+ return tuple(
+ self.table[field]
+ for field in self.model_instance._belongs_fks_[self.reverse].local_fields
+ )
class RelationBuilder(object):
@@ -88,7 +307,10 @@ def __init__(self, ref, model_instance):
self.model = model_instance
def _make_refid(self, row):
- return row.id if row is not None else self.model.id
+ pks = self.model.primary_keys or ["id"]
+ if row:
+ return tuple(row[pk] for pk in pks)
+ return tuple(self.model.table[pk] for pk in pks)
def _extra_scopes(self, ref, model_instance=None):
model_instance = model_instance or ref.model_instance
@@ -112,16 +334,32 @@ def _patch_query_with_scopes_on(self, ref, query, model_name):
return query
def _get_belongs(self, modelname, value):
- return self.model.db[modelname]._model_._belongs_ref_.get(value)
+ return self.model.db[modelname]._model_._belongs_fks_.get(value)
def belongs_query(self):
- return (self.model.table[self.ref[1]] == self.model.db[self.ref[0]].id)
+ return reduce(
+ operator.and_, [
+ self.model.table[local] == self.model.db[self.ref.model][foreign]
+ for local, foreign in self.ref.coupled_fields
+ ]
+ )
@staticmethod
def many_query(ref, rid):
- if ref.cast and isinstance(rid, _Field):
- rid = rid.cast(ref.cast)
- return ref.model_instance.table[ref.field] == rid
+ components = rid
+ if ref.cast:
+ components = []
+ for element in rid:
+            if isinstance(element, _Field):
+ components.append(element.cast(ref.cast))
+ else:
+ components.append(element)
+ return reduce(
+ operator.and_, [
+ field == components[idx]
+ for idx, field in enumerate(ref.fields_instances)
+ ]
+ )
def _many(self, ref, rid):
return ref.dbset.where(
@@ -153,20 +391,35 @@ def via(self, row=None):
#: join table way
last_belongs = step_model
last_via = via
- _query = (db[belongs_model].id == db[step_model][rname])
- sel_field = db[belongs_model].ALL
- step_model = belongs_model
+ _query = reduce(
+ operator.and_, [
+ (
+ db[belongs_model.model][foreign] ==
+ db[step_model][local]
+ ) for local, foreign in belongs_model.coupled_fields
+ ]
+ )
+ sel_field = db[belongs_model.model].ALL
+ step_model = belongs_model.model
else:
#: shortcut way
last_belongs = None
rname = via.field or via.name
midrel = db[step_model]._model_._hasmany_ref_[rname]
- _query = self._many(midrel, db[step_model].id)
+ _query = self._many(
+ midrel, [
+ db[step_model][step_field]
+ for step_field in (
+ db[step_model]._model_.primary_keys or ["id"]
+ )
+ ]
+ )
step_model = midrel.table_name
sel_field = db[step_model].ALL
query = query & _query
query = via.dbset.where(
- self._patch_query_with_scopes_on(via, query, step_model)).query
+ self._patch_query_with_scopes_on(via, query, step_model)
+ ).query
return query, sel_field, sname, rid, last_belongs, last_via
@@ -187,18 +440,6 @@ def __call__(self):
return None
-class JoinedIDReference(_IDReference):
- @classmethod
- def _from_record(cls, record, table=None):
- rv = cls(record.id)
- rv._table = table
- rv._record = record
- return rv
-
- def as_dict(self, datetime_to_str=False, custom_types=None):
- return self._record.as_dict()
-
-
class TimingHandler(ExecutionHandler):
def _timings(self):
THREAD_LOCAL._emtdal_timings_ = getattr(
@@ -285,3 +526,30 @@ def wrap_virtual_on_model(model, virtual):
def wrapped(row, *args, **kwargs):
return virtual(model, row, *args, **kwargs)
return wrapped
+
+
+def typed_row_reference(id: Any, table: Table):
+ field_type = table._id.type if table._id else None
+ return {
+ 'id': RowReferenceInt,
+ 'integer': RowReferenceInt,
+ 'string': RowReferenceStr,
+ None: RowReferenceMulti
+ }[field_type](id, table)
+
+
+def typed_row_reference_from_record(record: Any, model: Any):
+ field_type = model.table._id.type if model.table._id else None
+ refcls = {
+ 'id': RowReferenceInt,
+ 'integer': RowReferenceInt,
+ 'string': RowReferenceStr,
+ None: RowReferenceMulti
+ }[field_type]
+ if len(model._fieldset_pk) > 1:
+ id = {pk: record[pk] for pk in model._fieldset_pk}
+ else:
+ id = record[tuple(model._fieldset_pk)[0]]
+ rv = refcls(id, model.table)
+ rv._refrecord = record
+ return rv
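
Two of the helpers above in action: GeoFieldWrapper parses WKT on
construction and exposes the pieces, while typed_row_reference picks the
reference class matching the table's primary key type (db.tasks standing in
for any id-keyed table):

    from emmett.orm.helpers import GeoFieldWrapper, typed_row_reference

    g = GeoFieldWrapper("POINT(1.000000 2.000000)")
    g.geometry       # 'POINT'
    g.coordinates    # (1.0, 2.0)
    g.__json__()     # {'type': 'Point', 'coordinates': (1.0, 2.0)}

    ref = typed_row_reference(1, db.tasks)  # a RowReferenceInt
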
diff --git a/emmett/orm/migrations/base.py b/emmett/orm/migrations/base.py
index 06192d24..62719df2 100644
--- a/emmett/orm/migrations/base.py
+++ b/emmett/orm/migrations/base.py
@@ -9,92 +9,94 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Dict, Type
+
from ...datastructures import sdict
-from .. import Model, Field
+from .. import Database, Model, Field
from .engine import MetaEngine, Engine
from .helpers import WrappedOperation, _feasible_as_dbms_default
+if TYPE_CHECKING:
+ from .operations import Operation
+
class Schema(Model):
tablename = "emmett_schema"
version = Field()
-class Migration(object):
- _registered_ops_ = {}
+class Migration:
+ _registered_ops_: Dict[str, Type[Operation]] = {}
@classmethod
- def register_operation(cls, name):
- def wrap(op_cls):
+ def register_operation(
+ cls,
+ name: str
+ ) -> Callable[[Type[Operation]], Type[Operation]]:
+ def wrap(op_cls: Type[Operation]) -> Type[Operation]:
cls._registered_ops_[name] = op_cls
return op_cls
return wrap
- def __init__(self, app, db, is_meta=False):
+ def __init__(self, app: Any, db: Database, is_meta: bool = False):
self.db = db
if is_meta:
self.engine = MetaEngine(db)
else:
self.engine = Engine(db)
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> WrappedOperation:
registered = self._registered_ops_.get(name)
if registered is not None:
return WrappedOperation(registered, name, self.engine)
- else:
- raise NotImplementedError
+ raise NotImplementedError
class Column(sdict):
- def __init__(self, name, type='string', unique=False, notnull=False,
- **kwargs):
+ def __init__(
+ self,
+ name: str,
+ type: str = 'string',
+ unique: bool = False,
+ notnull: bool = False,
+ **kwargs: Any
+ ):
self.name = name
self.type = type
self.unique = unique
self.notnull = notnull
for key, val in kwargs.items():
self[key] = val
- self.length = self.length or 255
-
- def _build_fks(self, db, tablename):
- if self.type.startswith('reference'):
- referenced = self.type[10:].strip()
- try:
- rtablename, rfieldname = referenced.split('.')
- except Exception:
- rtablename = referenced
- rfieldname = 'id'
- if not rtablename:
- rtablename = tablename
- rtable = db[rtablename]
- rfield = rtable[rfieldname]
- if getattr(rtable, '_primarykey', None) and rfieldname in \
- rtable._primarykey or rfield.unique:
- if not rfield.unique and len(rtable._primarykey) > 1:
- # self.tfk = [pk for pk in rtable._primarykey]
- raise NotImplementedError(
- 'Column of type reference pointing to multiple ' +
- 'columns are currently not supported.'
- )
- else:
- self.fk = True
+ self.length: int = self.length or 255
+
+ def _fk_type(self, db: Database, tablename: str):
+ if self.name not in db[tablename]._model_._belongs_ref_:
+ return
+ ref = db[tablename]._model_._belongs_ref_[self.name]
+ if ref.ftype != 'id':
+ self.type = ref.ftype
+ self.length = db[ref.model][ref.fk].length
+ self.on_delete = None
@classmethod
- def from_field(cls, field):
+ def from_field(cls, field: Field) -> Column:
rv = cls(
field.name,
- field.type,
+ field._pydal_types.get(field._type, field._type),
field.unique,
field.notnull,
length=field.length,
- ondelete=field.ondelete
+ ondelete=field.ondelete,
+ **field._ormkw
)
if _feasible_as_dbms_default(field.default):
rv.default = field.default
- rv._build_fks(field.db, field.tablename)
+ rv._fk_type(field.db, field.tablename)
return rv
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%s)" % (
self.__class__.__name__,
", ".join(["%s=%r" % (k, v) for k, v in self.items()])
diff --git a/emmett/orm/migrations/commands.py b/emmett/orm/migrations/commands.py
index 2f855f9b..471e685b 100644
--- a/emmett/orm/migrations/commands.py
+++ b/emmett/orm/migrations/commands.py
@@ -9,19 +9,23 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
+from typing import Any, List
+
import click
from ...datastructures import sdict
-from .base import Schema, Column
+from .base import Database, Schema, Column
from .helpers import DryRunDatabase, make_migration_id, to_tuple
from .operations import MigrationOp, UpgradeOps, DowngradeOps
from .scripts import ScriptDir
-class Command(object):
- def __init__(self, app, dals):
+class Command:
+ def __init__(self, app: Any, dals: List[Database]):
self.app = app
- self.envs = []
+ self.envs: List[sdict] = []
self._load_envs(dals)
def _load_envs(self, dals):
@@ -360,6 +364,43 @@ def down(self, rev_id, dry_run=False):
)
raise
+ def set(self, rev_id, auto_confirm=False):
+ for ctx in self.envs:
+ self.load_schema(ctx)
+ current_revision = ctx._current_revision_
+ target_revision = ctx.scriptdir.get_revision(rev_id)
+ if not target_revision:
+ click.secho("> No matching revision found", fg="red")
+ return
+ click.echo(
+ " ".join([
+ click.style("> Setting revision to", fg="yellow"),
+ click.style(target_revision.revision, bold=True, fg="yellow"),
+ click.style("against", fg="yellow"),
+ click.style(ctx.db._uri, bold=True, fg="yellow")
+ ])
+ )
+ if not auto_confirm:
+ if not click.confirm("Do you want to continue?"):
+ click.echo("Aborting")
+ return
+ with ctx.db.connection():
+ self._store_current_revision_(
+ ctx, current_revision, target_revision.revision
+ )
+ click.echo(
+ "".join([
+ click.style(
+ "> Succesfully set revision to ",
+ fg="green"
+ ),
+ click.style(
+ target_revision.revision, fg="cyan", bold=True
+ ),
+ click.style(f": {target_revision.doc}", fg="green")
+ ])
+ )
+
def generate(app, dals, message, head):
Command(app, dals).generate(message, head)
@@ -391,3 +432,7 @@ def up(app, dals, revision, dry_run):
def down(app, dals, revision, dry_run):
Command(app, dals).down(revision, dry_run)
+
+
+def set_revision(app, dals, revision, auto_confirm):
+ Command(app, dals).set(revision, auto_confirm)
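
The new set command stamps the stored schema revision without running any
migration, which is useful to realign the bookkeeping after manual
interventions. A programmatic sketch, with a hypothetical revision id:

    set_revision(app, [db], "000fa1", auto_confirm=True)
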
diff --git a/emmett/orm/migrations/engine.py b/emmett/orm/migrations/engine.py
index 4221cda7..c52f2df4 100644
--- a/emmett/orm/migrations/engine.py
+++ b/emmett/orm/migrations/engine.py
@@ -9,15 +9,23 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
from ...datastructures import sdict
+if TYPE_CHECKING:
+ from .base import Column, Database
+ from .generation import MetaData
-class MetaEngine(object):
- def __init__(self, db):
+
+class MetaEngine:
+ def __init__(self, db: MetaData):
self.db = db
- def create_table(self, name, columns, **kwargs):
- self.db.create_table(name, columns)
+ def create_table(self, name, columns, primary_keys, **kwargs):
+ self.db.create_table(name, columns, primary_keys, **kwargs)
def drop_table(self, name):
self.db.drop_table(name)
@@ -39,28 +47,50 @@ def create_index(self, name, table_name, fields, expr, unique, **kw):
def drop_index(self, name, table_name):
self.db.drop_index(table_name, name)
+ def create_foreign_key_constraint(
+ self,
+ name,
+ table_name,
+ column_names,
+ foreign_table_name,
+ foreign_keys,
+ on_delete
+ ):
+ self.db.create_foreign_key_constraint(
+ table_name,
+ name,
+ column_names,
+ foreign_table_name,
+ foreign_keys,
+ on_delete
+ )
+
+ def drop_foreign_key_constraint(self, name, table_name):
+ self.db.drop_foreign_key_constraint(table_name, name)
+
@staticmethod
- def _parse_column_changes(changes):
+ def _parse_column_changes(
+ changes: List[Tuple[str, str, str, Dict[str, Any], Any, Any]]
+ ) -> Dict[str, List[Any]]:
rv = {}
for change in changes:
if change[0] == "modify_type":
- rv['type'] = [
- change[4], change[5], change[3]['existing_length']
- ]
+ rv['type'] = [change[4], change[5], change[3]['existing_length']]
elif change[0] == "modify_length":
- rv['length'] = [
- change[4], change[5], change[3]['existing_type']
- ]
+ rv['length'] = [change[4], change[5], change[3]['existing_type']]
elif change[0] == "modify_notnull":
rv['notnull'] = [change[4], change[5]]
elif change[0] == "modify_default":
- rv['default'] = [
- change[4], change[5], change[3]['existing_type']
- ]
+ rv['default'] = [change[4], change[5], change[3]['existing_type']]
+ else:
+ rv[change[0].split("modify_")[-1]] = [change[4], change[5], change[3]]
return rv
class Engine(MetaEngine):
+ def __init__(self, db: Database):
+ self.db = db
+
@property
def adapter(self):
return self.db._adapter
@@ -73,12 +103,8 @@ def _log_and_exec(self, sql):
self.db.logger.debug("executing SQL:\n%s" % sql)
self.adapter.execute(sql)
- def create_table(self, name, columns, **kwargs):
- params = {}
- for key in ['primary_keys', 'id_col']:
- if kwargs.get(key) is not None:
- params[key] = kwargs[key]
- sql_list = self._new_table_sql(name, columns, **params)
+ def create_table(self, name, columns, primary_keys, **kwargs):
+ sql_list = self._new_table_sql(name, columns, primary_keys, **kwargs)
for sql in sql_list:
self._log_and_exec(sql)
@@ -98,7 +124,10 @@ def drop_column(self, tablename, colname):
def alter_column(self, table_name, column_name, changes):
sql = self._alter_column_sql(
- table_name, column_name, self._parse_column_changes(changes))
+ table_name,
+ column_name,
+ self._parse_column_changes(changes)
+ )
if sql is not None:
self._log_and_exec(sql)
@@ -106,8 +135,7 @@ def create_index(self, name, table_name, fields, expr, unique, **kw):
adapt_t = sdict(_rname=self.dialect.quote(table_name))
components = [self.dialect.quote(field) for field in fields]
components += expr
- sql = self.dialect.create_index(
- name, adapt_t, components, unique, **kw)
+ sql = self.dialect.create_index(name, adapt_t, components, unique, **kw)
self._log_and_exec(sql)
def drop_index(self, name, table_name):
@@ -115,7 +143,25 @@ def drop_index(self, name, table_name):
sql = self.dialect.drop_index(name, adapt_t)
self._log_and_exec(sql)
- def _gen_reference(self, tablename, column, tfks):
+ def create_foreign_key_constraint(
+ self,
+ name: str,
+ table_name: str,
+ column_names: List[str],
+ foreign_table_name: str,
+ foreign_keys: List[str],
+ on_delete: str
+ ):
+ sql = self.dialect.add_foreign_key_constraint(
+ name, table_name, foreign_table_name, column_names, foreign_keys, on_delete
+ )
+ self._log_and_exec(sql)
+
+ def drop_foreign_key_constraint(self, name, table_name):
+ sql = self.dialect.drop_constraint(name, table_name)
+ self._log_and_exec(sql)
+
+ def _gen_reference(self, tablename, column):
referenced = column.type[10:].strip()
constraint_name = self.dialect.constraint_name(tablename, column.name)
try:
@@ -125,127 +171,125 @@ def _gen_reference(self, tablename, column, tfks):
rfieldname = 'id'
if not rtablename:
rtablename = tablename
- if column.fk or column.tfk:
- csql = self.adapter.types[column.type[:9]] % \
- dict(length=column.length)
- if column.fk:
- csql = csql + self.adapter.types['reference FK'] % dict(
- constraint_name=self.dialect.quote(constraint_name),
- foreign_key='%s (%s)' % (
- self.dialect.quote(rtablename),
- self.dialect.quote(rfieldname)),
- table_name=self.dialect.quote(tablename),
- field_name=self.dialect.quote(column.name),
- on_delete_action=column.ondelete)
- if column.tfk:
- # TODO
- raise NotImplementedError(
- 'Migrating tables containing multiple columns references '
- 'is currently not supported.'
- )
- else:
- csql_info = dict(
- index_name=self.dialect.quote(column.name + '__idx'),
- field_name=self.dialect.quote(column.name),
- constraint_name=self.dialect.quote(constraint_name),
- foreign_key='%s (%s)' % (
- self.dialect.quote(rtablename),
- self.dialect.quote(rfieldname)),
- on_delete_action=column.ondelete)
- csql_info['null'] = ' NOT NULL' if column.notnull else \
- self.dialect.allow_null
- csql_info['unique'] = ' UNIQUE' if column.unique else ''
- csql = self.adapter.types['reference'] % csql_info
+ csql_info = dict(
+ index_name=self.dialect.quote(column.name + '__idx'),
+ field_name=self.dialect.quote(column.name),
+ constraint_name=self.dialect.quote(constraint_name),
+ foreign_key='%s (%s)' % (
+ self.dialect.quote(rtablename),
+ self.dialect.quote(rfieldname)
+ ),
+            on_delete_action=column.ondelete
+        )
+ csql_info['null'] = (
+ ' NOT NULL' if column.notnull else
+ self.dialect.allow_null
+ )
+ csql_info['unique'] = ' UNIQUE' if column.unique else ''
+ csql = self.adapter.types['reference'] % csql_info
return csql
def _gen_primary_key(self, fields, primary_keys=[]):
if primary_keys:
- fields.append(self.dialect.primary_key(
- ', '.join([
- self.dialect.quote(pk) for pk in primary_keys])))
+ fields.append(
+ self.dialect.primary_key(
+ ', '.join([self.dialect.quote(pk) for pk in primary_keys])
+ )
+ )
- def _gen_geo(self, tablename, column):
+ def _gen_geo(self, column_type, geometry_type, srid, dimension):
if not hasattr(self.adapter, 'srid'):
raise RuntimeError('Adapter does not support geometry')
- geotype, parms = column.type[:-1].split('(')
- if geotype not in self.adapter.types:
+ if column_type not in self.adapter.types:
raise SyntaxError(
- 'Field: unknown field type: %s for %s' %
- (column.type, column.name))
- if self.adaper.dbengine == 'postgres' and geotype == 'geometry':
- # TODO
- raise NotImplementedError(
- 'Migration with PostgreSQL and %s columns are not supported.' %
- column.type
+ f'Field: unknown field type: {column_type}'
)
- return self.adapter.types[geotype]
-
- def _new_column_sql(self, tablename, column, tfks):
+ return "{ctype}({gtype},{srid},{dimension})".format(
+ ctype=self.adapter.types[column_type],
+ gtype=geometry_type,
+ srid=srid or self.adapter.srid,
+ dimension=dimension or 2
+ )
+
+ def _new_column_sql(
+ self,
+ tablename: str,
+ column: Column,
+ primary_key: bool = False
+ ) -> str:
if column.type.startswith('reference'):
- csql = self._gen_reference(tablename, column, tfks)
+ csql = self._gen_reference(tablename, column)
elif column.type.startswith('list:reference'):
csql = self.adapter.types[column.type[:14]]
elif column.type.startswith('decimal'):
precision, scale = map(int, column.type[8:-1].split(','))
- csql = self.adapter.types[column.type[:7]] % \
- dict(precision=precision, scale=scale)
+ csql = self.adapter.types[column.type[:7]] % dict(
+ precision=precision, scale=scale
+ )
elif column.type.startswith('geo'):
- csql = self._gen_geo()
+ csql = self._gen_geo(
+ column.type,
+ column.geometry_type,
+ column.srid,
+ column.dimension
+ )
elif column.type not in self.adapter.types:
raise SyntaxError(
- 'Field: unknown field type: %s for %s' %
- (column.type, column.name))
+                f'Field: unknown field type: {column.type} for {column.name}'
+ )
else:
- csql = self.adapter.types[column.type] % \
- {'length': column.length}
+ csql = self.adapter.types[column.type] % {'length': column.length}
if self.adapter.dbengine not in ('firebird', 'informix', 'oracle'):
- cprops = "%(notnull)s%(default)s%(unique)s%(qualifier)s"
+ cprops = "%(notnull)s%(default)s%(unique)s%(pk)s%(qualifier)s"
else:
- cprops = "%(default)s%(notnull)s%(unique)s%(qualifier)s"
+ cprops = "%(default)s%(notnull)s%(unique)s%(pk)s%(qualifier)s"
if not column.type.startswith(('id', 'reference')):
csql += cprops % {
- 'notnull': ' NOT NULL' if column.notnull
- else self.dialect.allow_null,
- 'default': ' DEFAULT %s' %
- self.adapter.represent(column.default, column.type)
- if column.default is not None else '',
+ 'notnull': ' NOT NULL' if column.notnull else self.dialect.allow_null,
+ 'default': (
+ ' DEFAULT %s' % self.adapter.represent(column.default, column.type)
+ if column.default is not None else ''
+ ),
'unique': ' UNIQUE' if column.unique else '',
- 'qualifier': ' %s' % column.custom_qualifier
- if column.custom_qualifier else ''
+ 'pk': ' PRIMARY KEY' if primary_key else '',
+ 'qualifier': (
+ ' %s' % column.custom_qualifier if column.custom_qualifier else ''
+ )
}
- # if column.notnull:
- # csql += ' NOT NULL'
- # else:
- # csql += self.adapter.ALLOW_NULL()
- # if column.unique:
- # csql += ' UNIQUE'
- # if column.custom_qualifier:
- # csql += ' %s' % column.custom_qualifier
- # if column.notnull and column.default is not None:
- # not_null = self.adapter.NOT_NULL(column.default, column.type)
- # csql = csql.replace('NOT NULL', not_null)
return csql
- def _new_table_sql(self, tablename, columns, primary_keys=[], id_col='id'):
+ def _new_table_sql(
+ self,
+ tablename: str,
+ columns: List[Column],
+ primary_keys: List[str] = [],
+        id_col: str = 'id'
+    ) -> List[str]:
# TODO:
- # - postgres geometry
# - SQLCustomType
+ composed_primary_key = len(primary_keys) > 1
fields = []
- tfks = {}
- for sortable, column in enumerate(columns, start=1):
- csql = self._new_column_sql(tablename, column, tfks)
+ for column in columns:
+ csql = self._new_column_sql(
+ tablename,
+ column,
+ primary_key=(
+ column.name in primary_keys if not composed_primary_key else False
+ )
+ )
fields.append('%s %s' % (self.dialect.quote(column.name), csql))
# backend-specific extensions to fields
if self.adapter.dbengine == 'mysql':
if not primary_keys:
primary_keys.append(id_col)
+ elif not composed_primary_key:
+ primary_keys.clear()
self._gen_primary_key(fields, primary_keys)
fields = ',\n '.join(fields)
return self.dialect.create_table(tablename, fields)
def _add_column_sql(self, tablename, column):
- csql = self._new_column_sql(tablename, column, {})
+ csql = self._new_column_sql(tablename, column)
return 'ALTER TABLE %(tname)s ADD %(cname)s %(sql)s;' % {
'tname': self.dialect.quote(tablename),
'cname': self.dialect.quote(column.name),
@@ -261,6 +305,12 @@ def _drop_column_sql(self, table_name, column_name):
self.dialect.quote(table_name), self.dialect.quote(column_name))
def _represent_changes(self, changes, field):
+ geo_attrs = ("geometry_type", "srid", "dimension")
+ geo_changes, geo_data = {}, {}
+ for key in set(changes.keys()) & set(geo_attrs):
+ geo_changes[key] = changes.pop(key)
+        geo_data.update(geo_changes[key][2])
+
if 'default' in changes and changes['default'][1] is not None:
ftype = changes['default'][2] or field.type
if 'type' in changes:
@@ -279,7 +329,14 @@ def _represent_changes(self, changes, field):
csql = self.adapter.types[coltype[:7]] % \
dict(precision=precision, scale=scale)
elif coltype.startswith('geo'):
- csql = self._gen_geo()
+ gen_attrs = []
+ for key in geo_attrs:
+ val = (
+ geo_changes.get(f"{key}", (None, None))[1] or
+ geo_data[f"existing_{key}"]
+ )
+ gen_attrs.append(val)
+            csql = self._gen_geo(coltype, *gen_attrs)
else:
csql = self.adapter.types[coltype] % {
'length': changes['type'][2] or field.length
@@ -288,10 +345,18 @@ def _represent_changes(self, changes, field):
elif 'length' in changes:
change = changes.pop('length')
ftype = change[2] or field.type
- changes['type'] = [
- None,
- self.adapter.types[ftype] % {'length': change[1]}
- ]
+ changes['type'] = [None, self.adapter.types[ftype] % {'length': change[1]}]
+ elif geo_changes:
+ coltype = geo_data["existing_type"] or field.type
+ gen_attrs = []
+ for key in geo_attrs:
+ val = (
+ geo_changes.get(f"{key}", (None, None))[1] or
+ geo_data[f"existing_{key}"]
+ )
+ gen_attrs.append(val)
+ changes['type'] = [None, self._gen_geo(coltype, *gen_attrs)]
+
def _alter_column_sql(self, table_name, column_name, changes):
sql = 'ALTER TABLE %(tname)s ALTER COLUMN %(cname)s %(changes)s;'
@@ -312,8 +377,8 @@ def _alter_column_sql(self, table_name, column_name, changes):
sql_changes.append(change_sql[change_val[1]])
elif isinstance(change_sql, list):
sql_changes.append(
- change_sql[0] % change_val[1] if change_val[1] is not None
- else change_sql[1]
+ change_sql[0] % change_val[1] if change_val[1] is not None else
+ change_sql[1]
)
else:
sql_changes.append(change_sql % change_val[1])
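
create_foreign_key_constraint above delegates to the dialect helper patched
in adapters.py. A sketch of the statement it renders, assuming double-quote
identifier quoting and hypothetical table/constraint names:

    # ALTER TABLE "tasks" ADD CONSTRAINT "tasks_user_fk"
    # FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE;
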
diff --git a/emmett/orm/migrations/generation.py b/emmett/orm/migrations/generation.py
index 865d65e4..54e91ecd 100644
--- a/emmett/orm/migrations/generation.py
+++ b/emmett/orm/migrations/generation.py
@@ -13,46 +13,81 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
from collections import OrderedDict
+from typing import Any, Dict, List, Optional
+from ..._shortcuts import hashlib_sha1
from ...datastructures import OrderedSet
-from .base import Column
-from .helpers import Dispatcher, DEFAULT_VALUE, _feasible_as_dbms_default
-from .operations import UpgradeOps, CreateTableOp, DropTableOp, \
- AddColumnOp, DropColumnOp, AlterColumnOp, CreateIndexOp, DropIndexOp
-
-
-class MetaTable(object):
- def __init__(self, name, columns=[], indexes=[], **kw):
+from ..objects import Table
+from .base import Column, Database
+from .helpers import Dispatcher, DEFAULT_VALUE
+from .operations import (
+ AddColumnOp,
+ AlterColumnOp,
+ CreateForeignKeyConstraintOp,
+ CreateIndexOp,
+ CreateTableOp,
+ DropColumnOp,
+ DropForeignKeyConstraintOp,
+ DropIndexOp,
+ DropTableOp,
+ MigrationOp,
+ OpContainer,
+ Operation,
+ UpgradeOps
+)
+from .scripts import ScriptDir
+
+
+class MetaTable:
+ def __init__(
+ self,
+ name: str,
+ columns: List[Column] = [],
+ primary_keys: List[str] = [],
+ **kw: Any
+ ):
self.name = name
self.columns = OrderedDict()
for column in columns:
self.columns[column.name] = column
- self.indexes = {}
+ self.primary_keys = primary_keys
+ self.indexes: Dict[str, MetaIndex] = {}
+ self.foreign_keys: Dict[str, MetaForeignKey] = {}
self.kw = kw
@property
- def fields(self):
+ def fields(self) -> List[str]:
return list(self.columns)
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> Column:
return self.columns[name]
- def __setitem__(self, name, value):
+ def __setitem__(self, name: str, value: Column):
self.columns[name] = value
- def __delitem__(self, name):
+ def __delitem__(self, name: str):
del self.columns[name]
- def __repr__(self):
+ def __repr__(self) -> str:
return "Table(%r, %s)" % (
self.name,
", ".join(["%s" % column for column in self.columns.values()])
)
-class MetaIndex(object):
- def __init__(self, table_name, name, fields, expressions, unique, **kw):
+class MetaIndex:
+ def __init__(
+ self,
+ table_name: str,
+ name: str,
+ fields: List[str],
+ expressions: List[str],
+ unique: bool,
+ **kw: Any
+ ):
self.table_name = table_name
self.name = name
self.fields = fields
@@ -61,10 +96,10 @@ def __init__(self, table_name, name, fields, expressions, unique, **kw):
self.kw = kw
@property
- def where(self):
+ def where(self) -> Optional[str]:
return self.kw.get('where')
- def __repr__(self):
+ def __repr__(self) -> str:
opts = [('expressions', self.expressions), ('unique', self.unique)]
for key, val in self.kw.items():
opts.append((key, val))
@@ -74,56 +109,131 @@ def __repr__(self):
)
-class MetaData(object):
- def __init__(self):
- self.tables = {}
+class MetaForeignKey:
+ def __init__(
+ self,
+ table_name: str,
+ name: str,
+ column_names: List[str],
+ foreign_table_name: str,
+ foreign_keys: List[str],
+ on_delete: str,
+ **kw
+ ):
+ self.table_name = table_name
+ self.name = name
+ self.column_names = column_names
+ self.foreign_table_name = foreign_table_name
+ self.foreign_keys = foreign_keys
+ self.on_delete = on_delete
+ self.kw = kw
+
+ @property
+ def _hash(self) -> str:
+ return hashlib_sha1(
+ f"{self.table_name}:{self.name}:{self.on_delete}:"
+ f"{repr(sorted(self.column_names))}:{repr(sorted(self.foreign_keys))}"
+ ).hexdigest()
+
+ def __eq__(self, obj: Any) -> bool:
+ if isinstance(obj, MetaForeignKey):
+ return self._hash == obj._hash
+ return False
+
+ def __repr__(self) -> str:
+ return "ForeignKey(%r, %r, %r, %r, %r, on_delete=%r)" % (
+ self.name,
+ self.table_name,
+ self.foreign_table_name,
+ self.column_names,
+ self.foreign_keys,
+ self.on_delete
+ )
+
- def create_table(self, name, columns, **kw):
- self.tables[name] = MetaTable(name, columns)
+class MetaData:
+ def __init__(self):
+ self.tables: Dict[str, MetaTable] = {}
+
+ def create_table(
+ self,
+ name: str,
+ columns: List[Column],
+ primary_keys: List[str],
+ **kw: Any
+ ):
+ self.tables[name] = MetaTable(name, columns, primary_keys, **kw)
- def drop_table(self, name):
+ def drop_table(self, name: str):
del self.tables[name]
- def add_column(self, table, column):
+ def add_column(self, table: str, column: Column):
self.tables[table][column.name] = column
- def drop_column(self, table, column):
+ def drop_column(self, table: str, column: str):
del self.tables[table][column]
- def change_column(self, table_name, column_name, changes):
+ def change_column(self, table_name: str, column_name: str, changes: Dict[str, Any]):
self.tables[table_name][column_name].update(**changes)
def create_index(
- self, table_name, index_name, fields, expressions, unique, **kw
+ self,
+ table_name: str,
+ index_name: str,
+ fields: List[str],
+ expressions: List[str],
+ unique: bool,
+ **kw: Any
):
self.tables[table_name].indexes[index_name] = MetaIndex(
table_name, index_name, fields, expressions, unique, **kw
)
- def drop_index(self, table_name, index_name):
+ def drop_index(self, table_name: str, index_name: str):
del self.tables[table_name].indexes[index_name]
+ def create_foreign_key_constraint(
+ self,
+ table_name: str,
+ constraint_name: str,
+ column_names: List[str],
+ foreign_table_name: str,
+ foreign_keys: List[str],
+ on_delete: str
+ ):
+ self.tables[table_name].foreign_keys[constraint_name] = MetaForeignKey(
+ table_name,
+ constraint_name,
+ column_names,
+ foreign_table_name,
+ foreign_keys,
+ on_delete
+ )
+
+ def drop_foreign_key_constraint(self, table_name: str, constraint_name: str):
+ del self.tables[table_name].foreign_keys[constraint_name]
+
-class Comparator(object):
- def __init__(self, db, meta):
+class Comparator:
+ def __init__(self, db: Database, meta: MetaData):
self.db = db
self.meta = meta
- def make_ops(self):
- self.ops = []
+ def make_ops(self) -> List[Operation]:
+ self.ops: List[Operation] = []
self.tables()
return self.ops
- def _build_metatable(self, dbtable):
- columns = []
- for field in list(dbtable):
- columns.append(Column.from_field(field))
+ def _build_metatable(self, dbtable: Table):
return MetaTable(
dbtable._tablename,
- columns
+ [
+ Column.from_field(field) for field in list(dbtable)
+ ],
+ primary_keys=list(dbtable._primary_keys)
)
- def _build_metaindex(self, dbtable, index_name):
+ def _build_metaindex(self, dbtable: Table, index_name: str) -> MetaIndex:
model = dbtable._model_
dbindex = model._indexes_[index_name]
kw = {}
@@ -131,13 +241,27 @@ def _build_metaindex(self, dbtable, index_name):
if 'where' in dbindex:
kw['where'] = str(dbindex['where'])
rv = MetaIndex(
- model.tablename, index_name,
+ model.tablename,
+ index_name,
[field for field in dbindex['fields']],
[str(expr) for expr in dbindex['expressions']],
- dbindex['unique'], **kw
+ dbindex['unique'],
+ **kw
)
return rv
+ def _build_metafk(self, dbtable: Table, fk_name: str) -> MetaForeignKey:
+ model = dbtable._model_
+ dbfk = model._foreign_keys_[fk_name]
+ return MetaForeignKey(
+ model.tablename,
+ fk_name,
+ dbfk['fields_local'],
+ dbfk['table'],
+ dbfk['fields_foreign'],
+ dbfk['on_delete']
+ )
+
def tables(self):
db_table_names = OrderedSet([t._tablename for t in self.db])
meta_table_names = OrderedSet(list(self.meta.tables))
@@ -146,30 +270,34 @@ def tables(self):
meta_table = self._build_metatable(self.db[table_name])
self.ops.append(CreateTableOp.from_table(meta_table))
self.indexes_and_uniques(self.db[table_name], meta_table)
+ self.foreign_keys(self.db[table_name], meta_table)
#: removed tables
for table_name in meta_table_names.difference(db_table_names):
#: remove table indexes too
metatable = self.meta.tables[table_name]
- for idx_name, idx in metatable.indexes.items():
+ for idx in metatable.indexes.values():
self.ops.append(DropIndexOp.from_index(idx))
#: remove table
- self.ops.append(
- DropTableOp.from_table(self.meta.tables[table_name]))
+ self.ops.append(DropTableOp.from_table(self.meta.tables[table_name]))
#: existing tables
for table_name in meta_table_names.intersection(db_table_names):
- self.columns(
- self.db[table_name], self.meta.tables[table_name])
- self.table(
- self.db[table_name], self.meta.tables[table_name])
+ self.columns(self.db[table_name], self.meta.tables[table_name])
+ self.table(self.db[table_name], self.meta.tables[table_name])
- def table(self, dbtable, metatable):
+ def table(self, dbtable: Table, metatable: MetaTable):
self.indexes_and_uniques(dbtable, metatable)
self.foreign_keys(dbtable, metatable)
- def indexes_and_uniques(self, dbtable, metatable, ops_stack=None):
+ def indexes_and_uniques(
+ self,
+ dbtable: Table,
+ metatable: MetaTable,
+ ops_stack: Optional[List[Operation]] = None
+ ):
ops = ops_stack if ops_stack is not None else self.ops
db_index_names = OrderedSet(
- [idxname for idxname in dbtable._model_._indexes_.keys()])
+ [idxname for idxname in dbtable._model_._indexes_.keys()]
+ )
meta_index_names = OrderedSet(list(metatable.indexes))
#: removed indexes
for index_name in meta_index_names.difference(db_index_names):
@@ -178,7 +306,9 @@ def indexes_and_uniques(self, dbtable, metatable, ops_stack=None):
for index_name in db_index_names.difference(meta_index_names):
ops.append(
CreateIndexOp.from_index(
- self._build_metaindex(dbtable, index_name)))
+ self._build_metaindex(dbtable, index_name)
+ )
+ )
#: existing indexes
for index_name in meta_index_names.intersection(db_index_names):
metaindex = metatable.indexes[index_name]
@@ -191,69 +321,105 @@ def indexes_and_uniques(self, dbtable, metatable, ops_stack=None):
ops.append(CreateIndexOp.from_index(dbindex))
# TODO: uniques
- def foreign_keys(self, dbtable, metatable, ops_stack=None):
- # TODO
- pass
-
- def columns(self, dbtable, metatable):
+ def foreign_keys(
+ self,
+ dbtable: Table,
+ metatable: MetaTable,
+ ops_stack: Optional[List[Operation]] = None
+ ):
+ ops = ops_stack if ops_stack is not None else self.ops
+ db_fk_names = OrderedSet(
+ [fkname for fkname in dbtable._model_._foreign_keys_.keys()]
+ )
+ meta_fk_names = OrderedSet(list(metatable.foreign_keys))
+ #: removed fks
+ for fk_name in meta_fk_names.difference(db_fk_names):
+ ops.append(
+ DropForeignKeyConstraintOp.from_foreign_key(
+ metatable.foreign_keys[fk_name]
+ )
+ )
+ #: new fks
+ for fk_name in db_fk_names.difference(meta_fk_names):
+ ops.append(
+ CreateForeignKeyConstraintOp.from_foreign_key(
+ self._build_metafk(dbtable, fk_name)
+ )
+ )
+ #: existing fks
+ for fk_name in meta_fk_names.intersection(db_fk_names):
+ metafk = metatable.foreign_keys[fk_name]
+ dbfk = self._build_metafk(dbtable, fk_name)
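+ # a changed constraint is recreated (drop + create): most backends
+ # cannot alter an existing foreign key in place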
+ if metafk != dbfk:
+ ops.append(DropForeignKeyConstraintOp.from_foreign_key(metafk))
+ ops.append(CreateForeignKeyConstraintOp.from_foreign_key(dbfk))
+
+ def columns(self, dbtable: Table, metatable: MetaTable):
db_column_names = OrderedSet([fname for fname in dbtable.fields])
meta_column_names = OrderedSet(metatable.fields)
#: new columns
for column_name in db_column_names.difference(meta_column_names):
self.ops.append(AddColumnOp.from_column_and_tablename(
- dbtable._tablename, Column.from_field(dbtable[column_name])))
+ dbtable._tablename, Column.from_field(dbtable[column_name])
+ ))
#: existing columns
for column_name in meta_column_names.intersection(db_column_names):
self.ops.append(AlterColumnOp(dbtable._tablename, column_name))
self.column(
- dbtable[column_name], metatable.columns[column_name])
+ Column.from_field(dbtable[column_name]),
+ metatable.columns[column_name]
+ )
if not self.ops[-1].has_changes():
self.ops.pop()
#: removed columns
for column_name in meta_column_names.difference(db_column_names):
self.ops.append(
DropColumnOp.from_column_and_tablename(
- dbtable._tablename, metatable.columns[column_name]))
+ dbtable._tablename, metatable.columns[column_name]
+ )
+ )
- def column(self, dbcolumn, metacolumn):
+ def column(self, dbcolumn: Column, metacolumn: Column):
self.notnulls(dbcolumn, metacolumn)
self.types(dbcolumn, metacolumn)
self.lengths(dbcolumn, metacolumn)
self.defaults(dbcolumn, metacolumn)
- def types(self, dbcolumn, metacolumn):
+ def types(self, dbcolumn: Column, metacolumn: Column):
self.ops[-1].existing_type = metacolumn.type
if dbcolumn.type != metacolumn.type:
self.ops[-1].modify_type = dbcolumn.type
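+ # geometry columns carry extra attributes; mirror them into the op's
+ # kw dict as existing_*/modify_* pairs so the renderer can emit them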
+ if dbcolumn.geometry_type and metacolumn.geometry_type:
+ for key in ("geometry_type", "srid", "dimension"):
+ self.ops[-1].kw[f"existing_{key}"] = metacolumn[key]
+ if dbcolumn[key] != metacolumn[key]:
+ self.ops[-1].kw[f"modify_{key}"] = dbcolumn[key]
- def lengths(self, dbcolumn, metacolumn):
+ def lengths(self, dbcolumn: Column, metacolumn: Column):
self.ops[-1].existing_length = metacolumn.length
if any(
field.type == "string" for field in [dbcolumn, metacolumn]
) and dbcolumn.length != metacolumn.length:
self.ops[-1].modify_length = dbcolumn.length
- def notnulls(self, dbcolumn, metacolumn):
+ def notnulls(self, dbcolumn: Column, metacolumn: Column):
self.ops[-1].existing_notnull = metacolumn.notnull
if dbcolumn.notnull != metacolumn.notnull:
self.ops[-1].modify_notnull = dbcolumn.notnull
- def defaults(self, dbcolumn, metacolumn):
- oldv, newv = metacolumn.default, dbcolumn.default
- self.ops[-1].existing_default = oldv
- if newv != oldv:
- if not all(callable(v) for v in [oldv, newv]):
- if _feasible_as_dbms_default(newv):
- self.ops[-1].modify_default = newv
+ def defaults(self, dbcolumn: Column, metacolumn: Column):
+ self.ops[-1].existing_default = metacolumn.default
+ if dbcolumn.default != metacolumn.default:
+ self.ops[-1].modify_default = dbcolumn.default
@classmethod
- def compare(cls, db, meta):
+ def compare(cls, db: Database, meta: MetaData) -> UpgradeOps:
ops = cls(db, meta).make_ops()
return UpgradeOps(ops)
-class Generator(object):
- def __init__(self, db, scriptdir, head):
+class Generator:
+ def __init__(self, db: Database, scriptdir: ScriptDir, head: str):
self.db = db
self.scriptdir = scriptdir
self.head = head
@@ -265,23 +431,29 @@ def _load_head_to_meta(self):
list(self.scriptdir.walk_revisions("base", self.head))
):
migration = revision.migration_class(
- None, self.meta, is_meta=True)
+ None, self.meta, is_meta=True
+ )
migration.up()
- def generate(self):
+ def generate(self) -> UpgradeOps:
return Comparator.compare(self.db, self.meta)
@classmethod
- def generate_from(cls, dal, scriptdir, head):
+ def generate_from(
+ cls,
+ dal: Database,
+ scriptdir: ScriptDir,
+ head: str
+ ) -> UpgradeOps:
return cls(dal, scriptdir, head).generate()
-class Renderer(object):
- def render_op(self, op):
+class Renderer:
+ def render_op(self, op: Operation) -> str:
op_renderer = renderers.dispatch(op)
return op_renderer(op)
- def render_opcontainer(self, op_container):
+ def render_opcontainer(self, op_container: OpContainer) -> List[str]:
rv = []
if not op_container.ops:
rv.append("pass")
@@ -291,17 +463,19 @@ def render_opcontainer(self, op_container):
return rv
@classmethod
- def render_migration(cls, migration_op):
+ def render_migration(cls, migration_op: MigrationOp):
r = cls()
- return r.render_opcontainer(migration_op.upgrade_ops), \
+ return (
+ r.render_opcontainer(migration_op.upgrade_ops),
r.render_opcontainer(migration_op.downgrade_ops)
+ )
renderers = Dispatcher()
@renderers.dispatch_for(CreateTableOp)
-def _add_table(op):
+def _add_table(op: CreateTableOp) -> str:
table = op.to_table()
args = [
@@ -321,10 +495,13 @@ def _add_table(op):
else:
args = (',\n' + indent).join(args)
- text = ("self.create_table(\n" + indent + "%(tablename)r,\n" + indent +
- "%(args)s") % {
+ text = (
+ "self.create_table(\n" + indent + "%(tablename)r,\n" + indent + "%(args)s,\n" +
+ indent + "primary_keys=%(primary_keys)r"
+ ) % {
'tablename': op.table_name,
- 'args': args
+ 'args': args,
+ 'primary_keys': table.primary_keys
}
for k in sorted(op.kw):
text += ",\n" + indent + "%s=%r" % (k.replace(" ", "_"), op.kw[k])
@@ -333,7 +510,7 @@ def _add_table(op):
@renderers.dispatch_for(DropTableOp)
-def _drop_table(op):
+def _drop_table(op: DropTableOp) -> str:
text = "self.drop_table(%(tname)r" % {
"tname": op.table_name
}
@@ -341,7 +518,7 @@ def _drop_table(op):
return text
-def _render_column(column):
+def _render_column(column: Column) -> str:
opts = []
if column.default is not None:
@@ -357,6 +534,10 @@ def _render_column(column):
opts.append(("length", column.length))
elif column.type.startswith('reference'):
opts.append(("ondelete", column.ondelete))
+ elif column.type.startswith("geo"):
+ for key in ("geometry_type", "srid", "dimension"):
+ if column[key] is not None:
+ opts.append((key, column[key]))
kw_str = ""
if opts:
@@ -370,7 +551,7 @@ def _render_column(column):
@renderers.dispatch_for(AddColumnOp)
-def _add_column(op):
+def _add_column(op: AddColumnOp) -> str:
return "self.add_column(%(tname)r, %(column)s)" % {
"tname": op.table_name,
"column": _render_column(op.column)
@@ -378,7 +559,7 @@ def _add_column(op):
@renderers.dispatch_for(DropColumnOp)
-def _drop_column(op):
+def _drop_column(op: DropColumnOp) -> str:
return "self.drop_column(%(tname)r, %(cname)r)" % {
"tname": op.table_name,
"cname": op.column_name
@@ -386,7 +567,7 @@ def _drop_column(op):
@renderers.dispatch_for(AlterColumnOp)
-def _alter_column(op):
+def _alter_column(op: AlterColumnOp) -> str:
indent = " " * 12
text = "self.alter_column(%(tname)r, %(cname)r" % {
'tname': op.table_name,
@@ -408,13 +589,16 @@ def _alter_column(op):
text += ",\n%sexisting_notnull=%r" % (indent, op.existing_notnull)
if op.modify_default is DEFAULT_VALUE and op.existing_default:
text += ",\n%sexisting_default=%s" % (indent, op.existing_default)
+ for key, val in op.kw.items():
+ if key.startswith("existing_") or key.startswith("modify_"):
+ text += ",\n%s%s=%r" % (indent, key, val)
text += ")"
return text
@renderers.dispatch_for(CreateIndexOp)
-def _add_index(op):
+def _add_index(op: CreateIndexOp) -> str:
kw_str = ""
if op.kw:
kw_str = ", %s" % ", ".join(
@@ -428,8 +612,36 @@ def _add_index(op):
@renderers.dispatch_for(DropIndexOp)
-def _drop_index(op):
+def _drop_index(op: DropIndexOp) -> str:
return "self.drop_index(%(iname)r, %(tname)r)" % {
"tname": op.table_name,
"iname": op.index_name
}
+
+
+@renderers.dispatch_for(CreateForeignKeyConstraintOp)
+def _add_fk_constraint(op: CreateForeignKeyConstraintOp) -> str:
+ kw_str = ""
+ if op.kw:
+ kw_str = ", %s" % ", ".join(
+ ["%s=%r" % (key, val) for key, val in op.kw.items()]
+ )
+ return "self.create_foreign_key(%s%s)" % (
+ "%r, %r, %r, %r, %r, on_delete=%r" % (
+ op.constraint_name,
+ op.table_name,
+ op.foreign_table_name,
+ op.column_names,
+ op.foreign_keys,
+ op.on_delete
+ ),
+ kw_str
+ )
+
+
+@renderers.dispatch_for(DropForeignKeyConstraintOp)
+def _drop_fk_constraint(op: DropForeignKeyConstraintOp) -> str:
+ return "self.drop_foreign_key(%(cname)r, %(tname)r)" % {
+ "tname": op.table_name,
+ "cname": op.constraint_name
+ }
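+
+
+# for illustration (hypothetical table and constraint names), the rendered
+# operations would look like:
+#   self.create_foreign_key('orders_ecnt__fk__users_user', 'orders', 'users',
+#       ['user'], ['id'], on_delete='CASCADE')
+#   self.drop_foreign_key('orders_ecnt__fk__users_user', 'orders')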
diff --git a/emmett/orm/migrations/helpers.py b/emmett/orm/migrations/helpers.py
index dd97a577..0f804b7b 100644
--- a/emmett/orm/migrations/helpers.py
+++ b/emmett/orm/migrations/helpers.py
@@ -9,11 +9,21 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
from collections.abc import Iterable
from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Type
from uuid import uuid4
+from pydal.adapters.base import BaseAdapter
+
from ...datastructures import _unique_list
+from .base import Database
+
+if TYPE_CHECKING:
+ from .engine import MetaEngine
+ from .operations import Operation
DEFAULT_VALUE = lambda: None
@@ -23,62 +33,64 @@ def make_migration_id():
return uuid4().hex[-12:]
-class WrappedOperation(object):
- def __init__(self, op_class, name, engine):
+class WrappedOperation:
+ def __init__(self, op_class: Type[Operation], name: str, engine: MetaEngine):
self.op_class = op_class
self.name = name
self.engine = engine
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
op = getattr(self.op_class, self.name)(*args, **kwargs)
op._env_load_(self.engine)
return op.run()
-class Dispatcher(object):
+class Dispatcher:
def __init__(self):
- self._registry = {}
+ self._registry: Dict[Type[Operation], Callable[[Operation], str]] = {}
- def dispatch_for(self, target):
- def wrap(fn):
+ def dispatch_for(
+ self,
+ target: Type[Operation]
+ ) -> Callable[[Callable[[Operation], str]], Callable[[Operation], str]]:
+ def wrap(fn: Callable[[Operation], str]) -> Callable[[Operation], str]:
self._registry[target] = fn
return fn
return wrap
- def dispatch(self, obj):
+ def dispatch(self, obj: Operation):
targets = type(obj).__mro__
for target in targets:
if target in self._registry:
return self._registry[target]
- else:
- raise ValueError("no dispatch function for object: %s" % obj)
+ raise ValueError(f"no dispatch function for object: {obj}")
class DryRunAdapter:
- def __init__(self, adapter, logger):
+ def __init__(self, adapter: BaseAdapter, logger: Any):
self.adapter = adapter
self.__dlogger = logger
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self.adapter, name)
- def execute(self, sql):
+ def execute(self, sql: str):
self.__dlogger(sql)
class DryRunDatabase:
- def __init__(self, db, logger):
+ def __init__(self, db: Database, logger: Any):
self.db = db
self._adapter = DryRunAdapter(db._adapter, logger)
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self.db, name)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> Any:
return self.db[key]
@contextmanager
- def connection(self, *args, **kwargs):
+ def connection(self, *args: Any, **kwargs: Any) -> Generator[None, None, None]:
yield None
diff --git a/emmett/orm/migrations/operations.py b/emmett/orm/migrations/operations.py
index 4ea421db..b2558c6b 100644
--- a/emmett/orm/migrations/operations.py
+++ b/emmett/orm/migrations/operations.py
@@ -13,25 +13,37 @@
:license: BSD-3-Clause
"""
+from __future__ import annotations
+
import re
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
from .base import Migration, Column
from .helpers import DEFAULT_VALUE
+if TYPE_CHECKING:
+ from .engine import MetaEngine
+ from .generation import MetaTable, MetaIndex, MetaForeignKey
-class Operation(object):
- def _env_load_(self, engine):
+
+class Operation:
+ def _env_load_(self, engine: MetaEngine):
self.engine = engine
+ def reverse(self) -> Operation:
+ raise NotImplementedError
+
def run(self):
pass
class OpContainer(Operation):
#: represent a sequence of operations
- def __init__(self, ops=()):
+ def __init__(self, ops: List[Operation] = []):
self.ops = ops
- def is_empty(self):
+ def is_empty(self) -> bool:
return not self.ops
def as_diffs(self):
@@ -49,50 +61,51 @@ def _ops_as_diffs(cls, migrations):
class ModifyTableOps(OpContainer):
#: a sequence of operations that all apply to a single Table
- def __init__(self, table_name, ops):
- super(ModifyTableOps, self).__init__(ops)
+ def __init__(self, table_name: str, ops: List[Operation]):
+ super().__init__(ops)
self.table_name = table_name
- def reverse(self):
+ def reverse(self) -> ModifyTableOps:
return ModifyTableOps(
self.table_name,
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- ))
+ ops=list(reversed([op.reverse() for op in self.ops]))
)
class UpgradeOps(OpContainer):
#: contains a sequence of operations that would apply during upgrade
- def __init__(self, ops=(), upgrade_token="upgrades"):
- super(UpgradeOps, self).__init__(ops=ops)
+ def __init__(self, ops: List[Operation] = [], upgrade_token: str = "upgrades"):
+ super().__init__(ops=ops)
self.upgrade_token = upgrade_token
- def reverse(self):
+ def reverse(self) -> DowngradeOps:
return DowngradeOps(
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- ))
+ ops=list(reversed([op.reverse() for op in self.ops]))
)
class DowngradeOps(OpContainer):
#: contains a sequence of operations that would apply during downgrade
- def __init__(self, ops=(), downgrade_token="downgrades"):
- super(DowngradeOps, self).__init__(ops=ops)
+ def __init__(self, ops: List[Operation] = [], downgrade_token: str = "downgrades"):
+ super().__init__(ops=ops)
self.downgrade_token = downgrade_token
def reverse(self):
return UpgradeOps(
- ops=list(reversed(
- [op.reverse() for op in self.ops]
- ))
+ ops=list(reversed([op.reverse() for op in self.ops]))
)
class MigrationOp(Operation):
- def __init__(self, rev_id, upgrade_ops, downgrade_ops, message=None,
- head=None, splice=None):
+ def __init__(
+ self,
+ rev_id: str,
+ upgrade_ops: UpgradeOps,
+ downgrade_ops: DowngradeOps,
+ message: Optional[str] = None,
+ head: Optional[str] = None,
+ splice: Any = None
+ ):
self.rev_id = rev_id
self.message = message
self.head = head
@@ -103,73 +116,98 @@ def __init__(self, rev_id, upgrade_ops, downgrade_ops, message=None,
@Migration.register_operation("create_table")
class CreateTableOp(Operation):
- def __init__(self, table_name, columns, _orig_table=None, **kw):
+ def __init__(
+ self,
+ table_name: str,
+ columns: List[Column],
+ primary_keys: List[str] = [],
+ _orig_table: Optional[MetaTable] = None,
+ **kw: Any
+ ):
self.table_name = table_name
self.columns = columns
+ self.primary_keys = primary_keys
self.kw = kw
self._orig_table = _orig_table
- def reverse(self):
+ def reverse(self) -> DropTableOp:
return DropTableOp.from_table(self.to_table())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, MetaTable]:
return ("add_table", self.to_table())
@classmethod
- def from_table(cls, table):
+ def from_table(cls, table: MetaTable) -> CreateTableOp:
return cls(
table.name,
[table[colname] for colname in table.fields],
+ list(table.primary_keys),
_orig_table=table
)
- def to_table(self, migration_context=None):
+ def to_table(self, migration_context: Any = None) -> MetaTable:
if self._orig_table is not None:
return self._orig_table
from .generation import MetaTable
return MetaTable(
- self.table_name, self.columns, **self.kw
+ self.table_name,
+ self.columns,
+ self.primary_keys,
+ **self.kw
)
@classmethod
- def create_table(cls, table_name, *columns, **kw):
+ def create_table(
+ cls,
+ table_name: str,
+ *columns: Column,
+ **kw: Any
+ ) -> CreateTableOp:
return cls(table_name, columns, **kw)
def run(self):
- self.engine.create_table(self.table_name, self.columns, **self.kw)
+ self.engine.create_table(
+ self.table_name, self.columns, self.primary_keys, **self.kw
+ )
@Migration.register_operation("drop_table")
class DropTableOp(Operation):
- def __init__(self, table_name, table_kw=None, _orig_table=None):
+ def __init__(
+ self,
+ table_name: str,
+ table_kw: Optional[Dict[str, Any]] = None,
+ _orig_table: Optional[MetaTable] = None
+ ):
self.table_name = table_name
self.table_kw = table_kw or {}
self._orig_table = _orig_table
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, MetaTable]:
return ("remove_table", self.to_table())
- def reverse(self):
+ def reverse(self) -> CreateTableOp:
if self._orig_table is None:
raise ValueError(
- "operation is not reversible; "
- "original table is not present")
+ "operation is not reversible; original table is not present"
+ )
return CreateTableOp.from_table(self._orig_table)
@classmethod
- def from_table(cls, table):
+ def from_table(cls, table: MetaTable) -> DropTableOp:
return cls(table.name, _orig_table=table)
- def to_table(self):
+ def to_table(self) -> MetaTable:
if self._orig_table is not None:
return self._orig_table
from .generation import MetaTable
return MetaTable(
self.table_name,
- **self.table_kw)
+ **self.table_kw
+ )
@classmethod
- def drop_table(cls, table_name, **kw):
+ def drop_table(cls, table_name: str, **kw: Any) -> DropTableOp:
return cls(table_name, table_kw=kw)
def run(self):
@@ -177,48 +215,45 @@ def run(self):
class AlterTableOp(Operation):
- def __init__(self, table_name):
+ def __init__(self, table_name: str):
self.table_name = table_name
@Migration.register_operation("rename_table")
class RenameTableOp(AlterTableOp):
- def __init__(self, old_table_name, new_table_name):
- super(RenameTableOp, self).__init__(old_table_name)
+ def __init__(self, old_table_name: str, new_table_name: str):
+ super().__init__(old_table_name)
self.new_table_name = new_table_name
@classmethod
- def rename_table(cls, old_table_name, new_table_name):
+ def rename_table(cls, old_table_name: str, new_table_name: str) -> RenameTableOp:
return cls(old_table_name, new_table_name)
def run(self):
- raise NotImplementedError(
- 'Table renaming is currently not supported.'
- )
+ raise NotImplementedError('Table renaming is currently not supported.')
@Migration.register_operation("add_column")
class AddColumnOp(AlterTableOp):
- def __init__(self, table_name, column):
- super(AddColumnOp, self).__init__(table_name)
+ def __init__(self, table_name: str, column: Column):
+ super().__init__(table_name)
self.column = column
- def reverse(self):
- return DropColumnOp.from_column_and_tablename(
- self.table_name, self.column)
+ def reverse(self) -> DropColumnOp:
+ return DropColumnOp.from_column_and_tablename(self.table_name, self.column)
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, str, Column]:
return ("add_column", self.table_name, self.column)
- def to_column(self):
+ def to_column(self) -> Column:
return self.column
@classmethod
- def from_column_and_tablename(cls, tname, col):
+ def from_column_and_tablename(cls, tname: str, col: Column) -> AddColumnOp:
return cls(tname, col)
@classmethod
- def add_column(cls, table_name, column):
+ def add_column(cls, table_name: str, column: Column) -> AddColumnOp:
return cls(table_name, column)
def run(self):
@@ -227,35 +262,42 @@ def run(self):
@Migration.register_operation("drop_column")
class DropColumnOp(AlterTableOp):
- def __init__(self, table_name, column_name, _orig_column=None, **kw):
- super(DropColumnOp, self).__init__(table_name)
+ def __init__(
+ self,
+ table_name: str,
+ column_name: str,
+ _orig_column: Optional[Column] = None,
+ **kw: Any
+ ):
+ super().__init__(table_name)
self.column_name = column_name
self.kw = kw
self._orig_column = _orig_column
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, str, Column]:
return ("remove_column", self.table_name, self.to_column())
- def reverse(self):
+ def reverse(self) -> AddColumnOp:
if self._orig_column is None:
raise ValueError(
- "operation is not reversible; "
- "original column is not present")
+ "operation is not reversible; original column is not present"
+ )
return AddColumnOp.from_column_and_tablename(
- self.table_name, self._orig_column)
+ self.table_name, self._orig_column
+ )
@classmethod
- def from_column_and_tablename(cls, tname, col):
+ def from_column_and_tablename(cls, tname: str, col: Column) -> DropColumnOp:
return cls(tname, col.name, _orig_column=col)
- def to_column(self):
+ def to_column(self) -> Column:
if self._orig_column is not None:
return self._orig_column
return Column(self.column_name, **self.kw)
@classmethod
- def drop_column(cls, table_name, column_name, **kw):
+ def drop_column(cls, table_name: str, column_name: str, **kw: Any) -> DropColumnOp:
return cls(table_name, column_name, **kw)
def run(self):
@@ -265,20 +307,21 @@ def run(self):
@Migration.register_operation("alter_column")
class AlterColumnOp(AlterTableOp):
def __init__(
- self, table_name, column_name,
- existing_type=None,
- existing_length=None,
- existing_default=None,
- existing_notnull=None,
- modify_notnull=None,
- modify_default=DEFAULT_VALUE,
- modify_name=None,
- modify_type=None,
- modify_length=None,
- **kw
-
+ self,
+ table_name: str,
+ column_name: str,
+ existing_type: Optional[str] = None,
+ existing_length: Optional[int] = None,
+ existing_default: Any = None,
+ existing_notnull: Optional[bool] = None,
+ modify_notnull: Optional[bool] = None,
+ modify_default: Any = DEFAULT_VALUE,
+ modify_name: Optional[str] = None,
+ modify_type: Optional[str] = None,
+ modify_length: Optional[int] = None,
+ **kw: Any
):
- super(AlterColumnOp, self).__init__(table_name)
+ super().__init__(table_name)
self.column_name = column_name
self.existing_type = existing_type
self.existing_length = existing_length
@@ -291,17 +334,25 @@ def __init__(
self.modify_length = modify_length
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]:
col_diff = []
tname, cname = self.table_name, self.column_name
if self.modify_type is not None:
col_diff.append(
(
- "modify_type", tname, cname, {
+ "modify_type",
+ tname,
+ cname,
+ {
"existing_length": self.existing_length,
"existing_notnull": self.existing_notnull,
- "existing_default": self.existing_default},
+ "existing_default": self.existing_default,
+ **{
+ nkey: nval for nkey, nval in self.kw.items()
+ if nkey.startswith('existing_')
+ }
+ },
self.existing_type,
self.modify_type
)
@@ -310,10 +361,14 @@ def to_diff_tuple(self):
if self.modify_length is not None:
col_diff.append(
(
- "modify_length", tname, cname, {
+ "modify_length",
+ tname,
+ cname,
+ {
"existing_type": self.existing_type,
"existing_notnull": self.existing_notnull,
- "existing_default": self.existing_default},
+ "existing_default": self.existing_default
+ },
self.existing_length,
self.modify_length
)
@@ -322,9 +377,13 @@ def to_diff_tuple(self):
if self.modify_notnull is not None:
col_diff.append(
(
- "modify_notnull", tname, cname, {
+ "modify_notnull",
+ tname,
+ cname,
+ {
"existing_type": self.existing_type,
- "existing_default": self.existing_default},
+ "existing_default": self.existing_default
+ },
self.existing_notnull,
self.modify_notnull
)
@@ -333,17 +392,41 @@ def to_diff_tuple(self):
if self.modify_default is not DEFAULT_VALUE:
col_diff.append(
(
- "modify_default", tname, cname, {
+ "modify_default",
+ tname,
+ cname,
+ {
"existing_notnull": self.existing_notnull,
- "existing_type": self.existing_type},
+ "existing_type": self.existing_type
+ },
self.existing_default,
self.modify_default
)
)
+ for key, val in self.kw.items():
+ if key.startswith("modify_"):
+ attr = key.split("modify_")[-1]
+ col_diff.append(
+ (
+ key,
+ tname,
+ cname,
+ {
+ "existing_type": self.existing_type,
+ **{
+ nkey: nval for nkey, nval in self.kw.items()
+ if nkey.startswith('existing_')
+ }
+ },
+ self.kw.get(f"existing_{attr}"),
+ val
+ )
+ )
+
return col_diff
- def has_changes(self):
+ def has_changes(self) -> bool:
hc = (
self.modify_notnull is not None or
self.modify_default is not DEFAULT_VALUE or
@@ -357,7 +440,7 @@ def has_changes(self):
return True
return False
- def reverse(self):
+ def reverse(self) -> AlterColumnOp:
kw = self.kw.copy()
kw['existing_type'] = self.existing_type
kw['existing_length'] = self.existing_length
@@ -389,20 +472,23 @@ def reverse(self):
@classmethod
def alter_column(
- cls, table_name, column_name,
- notnull=None,
- default=DEFAULT_VALUE,
- new_column_name=None,
- type=None,
- length=None,
- existing_type=None,
- existing_length=None,
- existing_default=None,
- existing_notnull=None,
- **kw
- ):
+ cls,
+ table_name: str,
+ column_name: str,
+ notnull: Optional[bool] = None,
+ default: Any = DEFAULT_VALUE,
+ new_column_name: Optional[str] = None,
+ type: Optional[str] = None,
+ length: Optional[int] = None,
+ existing_type: Optional[str] = None,
+ existing_length: Optional[int] = None,
+ existing_default: Any = None,
+ existing_notnull: Optional[bool] = None,
+ **kw: Any
+ ) -> AlterColumnOp:
return cls(
- table_name, column_name,
+ table_name,
+ column_name,
existing_type=existing_type,
existing_length=existing_length,
existing_default=existing_default,
@@ -417,14 +503,21 @@ def alter_column(
def run(self):
self.engine.alter_column(
- self.table_name, self.column_name, self.to_diff_tuple())
+ self.table_name, self.column_name, self.to_diff_tuple()
+ )
@Migration.register_operation("create_index")
class CreateIndexOp(Operation):
def __init__(
- self, index_name, table_name, fields=[], expressions=[], unique=False,
- _orig_index=None, **kw
+ self,
+ index_name: str,
+ table_name: str,
+ fields: List[str] = [],
+ expressions: List[str] = [],
+ unique: bool = False,
+ _orig_index: Optional[MetaIndex] = None,
+ **kw: Any
):
self.index_name = index_name
self.table_name = table_name
@@ -434,75 +527,238 @@ def __init__(
self.kw = kw
self._orig_index = _orig_index
- def reverse(self):
+ def reverse(self) -> DropIndexOp:
return DropIndexOp.from_index(self.to_index())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, MetaIndex]:
return ("create_index", self.to_index())
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: MetaIndex) -> CreateIndexOp:
return cls(
index.name, index.table_name, index.fields, index.expressions,
index.unique, _orig_index=index, **index.kw
)
- def to_index(self):
+ def to_index(self) -> MetaIndex:
if self._orig_index is not None:
return self._orig_index
from .generation import MetaIndex
return MetaIndex(
- self.table_name, self.index_name, self.fields, self.expressions,
- self.unique, **self.kw)
+ self.table_name,
+ self.index_name,
+ self.fields,
+ self.expressions,
+ self.unique,
+ **self.kw
+ )
@classmethod
def create_index(
- cls, index_name, table_name, fields=[], expressions=[], unique=False,
- **kw
- ):
+ cls,
+ index_name: str,
+ table_name: str,
+ fields: List[str] = [],
+ expressions: List[str] = [],
+ unique: bool = False,
+ **kw: Any
+ ) -> CreateIndexOp:
return cls(index_name, table_name, fields, expressions, unique, **kw)
def run(self):
self.engine.create_index(
- self.index_name, self.table_name, self.fields, self.expressions,
- self.unique, **self.kw)
+ self.index_name,
+ self.table_name,
+ self.fields,
+ self.expressions,
+ self.unique,
+ **self.kw
+ )
@Migration.register_operation("drop_index")
class DropIndexOp(Operation):
- def __init__(self, index_name, table_name=None, _orig_index=None):
+ def __init__(
+ self,
+ index_name: str,
+ table_name: Optional[str] = None,
+ _orig_index: Optional[MetaIndex] = None
+ ):
self.index_name = index_name
self.table_name = table_name
self._orig_index = _orig_index
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, MetaIndex]:
return ("remove_index", self.to_index())
- def reverse(self):
+ def reverse(self) -> CreateIndexOp:
if self._orig_index is None:
raise ValueError(
- "operation is not reversible; "
- "original index is not present")
+ "operation is not reversible; original index is not present"
+ )
return CreateIndexOp.from_index(self._orig_index)
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: MetaIndex) -> DropIndexOp:
return cls(index.name, index.table_name, index)
- def to_index(self):
+ def to_index(self) -> MetaIndex:
if self._orig_index is not None:
return self._orig_index
from .generation import MetaIndex
return MetaIndex(self.table_name, self.index_name, [], [], False)
@classmethod
- def drop_index(cls, index_name, table_name):
+ def drop_index(cls, index_name: str, table_name: str) -> DropIndexOp:
return cls(index_name, table_name)
def run(self):
self.engine.drop_index(self.index_name, self.table_name)
+@Migration.register_operation("create_foreign_key")
+class CreateForeignKeyConstraintOp(AlterTableOp):
+ def __init__(
+ self,
+ name: str,
+ table_name: str,
+ foreign_table_name: str,
+ column_names: List[str],
+ foreign_keys: List[str],
+ on_delete: str,
+ _orig_fk: Optional[MetaForeignKey] = None,
+ **kw: Any
+ ):
+ super().__init__(table_name)
+ self.constraint_name = name
+ self.foreign_table_name = foreign_table_name
+ self.column_names = column_names
+ self.foreign_keys = foreign_keys
+ self.on_delete = on_delete
+ self.kw = kw
+ self._orig_fk = _orig_fk
+ if len(self.column_names) != len(self.foreign_keys):
+ raise SyntaxError("local and foreign columns number should match")
+
+ def reverse(self) -> DropForeignKeyConstraintOp:
+ return DropForeignKeyConstraintOp.from_foreign_key(self.to_foreign_key())
+
+ def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]:
+ return ("create_fk_constraint", self.to_foreign_key())
+
+ @classmethod
+ def from_foreign_key(
+ cls,
+ foreign_key: MetaForeignKey
+ ) -> CreateForeignKeyConstraintOp:
+ return cls(
+ foreign_key.name,
+ foreign_key.table_name,
+ foreign_key.foreign_table_name,
+ foreign_key.column_names,
+ foreign_key.foreign_keys,
+ foreign_key.on_delete,
+ _orig_fk=foreign_key
+ )
+
+ def to_foreign_key(self) -> MetaForeignKey:
+ if self._orig_fk is not None:
+ return self._orig_fk
+
+ from .generation import MetaForeignKey
+ return MetaForeignKey(
+ self.table_name,
+ self.constraint_name,
+ self.column_names,
+ self.foreign_table_name,
+ self.foreign_keys,
+ self.on_delete
+ )
+
+ @classmethod
+ def create_foreign_key(
+ cls,
+ name: str,
+ table_name: str,
+ foreign_table_name: str,
+ column_names: List[str],
+ foreign_keys: List[str],
+ on_delete: str
+ ) -> CreateForeignKeyConstraintOp:
+ return cls(
+ name=name,
+ table_name=table_name,
+ foreign_table_name=foreign_table_name,
+ column_names=column_names,
+ foreign_keys=foreign_keys,
+ on_delete=on_delete
+ )
+
+ def run(self):
+ self.engine.create_foreign_key_constraint(
+ self.constraint_name,
+ self.table_name,
+ self.column_names,
+ self.foreign_table_name,
+ self.foreign_keys,
+ self.on_delete
+ )
+
+
+@Migration.register_operation("drop_foreign_key")
+class DropForeignKeyConstraintOp(AlterTableOp):
+ def __init__(
+ self,
+ name: str,
+ table_name: str,
+ _orig_fk: Optional[MetaForeignKey] = None,
+ **kw: Any
+ ):
+ super().__init__(table_name)
+ self.constraint_name = name
+ self.kw = kw
+ self._orig_fk = _orig_fk
+
+ def reverse(self) -> CreateForeignKeyConstraintOp:
+ if self._orig_fk is None:
+ raise ValueError(
+ "operation is not reversible; original constraint is not present"
+ )
+ return CreateForeignKeyConstraintOp.from_foreign_key(self._orig_fk)
+
+ def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]:
+ return ("drop_fk_constraint", self.to_foreign_key())
+
+ @classmethod
+ def from_foreign_key(
+ cls,
+ foreign_key: MetaForeignKey
+ ) -> DropForeignKeyConstraintOp:
+ return cls(
+ foreign_key.name,
+ foreign_key.table_name,
+ _orig_fk=foreign_key
+ )
+
+ def to_foreign_key(self) -> MetaForeignKey:
+ if self._orig_fk is not None:
+ return self._orig_fk
+
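+ # without the original constraint a stub is enough: dropping only
+ # needs the constraint name and table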
+ from .generation import MetaForeignKey
+ return MetaForeignKey(self.table_name, self.constraint_name, [], '', [], '')
+
+ @classmethod
+ def drop_foreign_key(
+ cls,
+ name: str,
+ table_name: str
+ ) -> DropForeignKeyConstraintOp:
+ return cls(name, table_name)
+
+ def run(self):
+ self.engine.drop_foreign_key_constraint(self.constraint_name, self.table_name)
+
+
# @Migration.register_operation("execute")
# class ExecuteSQLOp(Operation):
# def __init__(self, sqltext):
diff --git a/emmett/orm/migrations/scripts.py b/emmett/orm/migrations/scripts.py
index 5a57a79c..eee7b6fb 100644
--- a/emmett/orm/migrations/scripts.py
+++ b/emmett/orm/migrations/scripts.py
@@ -19,9 +19,12 @@
from contextlib import contextmanager
from datetime import datetime
+from importlib import resources
+
from renoir import Renoir
from ...html import asis
+from . import __name__ as __pkg__
from .base import Migration
from .exceptions import (
RangeNotAncestorError, MultipleHeads, ResolutionError, RevisionError
@@ -146,7 +149,8 @@ def _rev_filename(self, revid, message, creation_date):
return filename
def _generate_template(self, filename, ctx):
- rendered = self.templater.render('migration.tmpl', ctx)
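+ # reading the template through importlib.resources (rather than the
+ # templater's search path) presumably keeps generation working when
+ # the package is installed zipped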
+ tmpl_source = resources.read_text(__pkg__, 'migration.tmpl')
+ rendered = self.templater._render(source=tmpl_source, context=ctx)
with open(os.path.join(self.path, filename), 'w') as f:
f.write(rendered)
diff --git a/emmett/orm/models.py b/emmett/orm/models.py
index 9a303cc4..9d771de0 100644
--- a/emmett/orm/models.py
+++ b/emmett/orm/models.py
@@ -9,19 +9,40 @@
:license: BSD-3-Clause
"""
+import operator
import types
from collections import OrderedDict
+from functools import reduce
from ..datastructures import sdict
from ..utils import cachedprop
from .apis import (
- compute, rowattr, rowmethod, scope, belongs_to, refers_to, has_one,
+ compute,
+ rowattr,
+ rowmethod,
+ scope,
+ belongs_to,
+ refers_to,
+ has_one,
has_many
)
+from .errors import (
+ InsertFailureOnSave,
+ SaveException,
+ UpdateFailureOnSave,
+ ValidationError,
+ DestroyException
+)
from .helpers import (
- Callback, ReferenceData, make_tablename, camelize, decamelize,
- wrap_scope_on_model, wrap_virtual_on_model
+ Callback,
+ ReferenceData,
+ make_tablename,
+ camelize,
+ decamelize,
+ wrap_scope_on_model,
+ wrap_virtual_on_model,
+ RowReferenceMulti
)
from .objects import Field, Row
from .wrappers import HasOneWrap, HasManyWrap, HasManyViaWrap
@@ -29,7 +50,7 @@
class MetaModel(type):
_inheritable_dict_attrs_ = [
- 'indexes', 'validation', ('fields_rw', {'id': False}),
+ 'indexes', 'validation', ('fields_rw', {'id': False}), 'foreign_keys',
'default_values', 'update_values', 'repr_values',
'form_labels', 'form_info', 'form_widgets'
]
@@ -112,11 +133,13 @@ def __new__(cls, name, bases, attrs):
belongs=OrderedDict(), refers=OrderedDict(),
hasone=OrderedDict(), hasmany=OrderedDict()
)
+ super_vfields = OrderedDict()
for base in reversed(new_class.__mro__[1:]):
if hasattr(base, '_declared_fields_'):
all_fields.update(base._declared_fields_)
if hasattr(base, '_declared_virtuals_'):
all_vfields.update(base._declared_virtuals_)
+ super_vfields.update(base._declared_virtuals_)
if hasattr(base, '_declared_computations_'):
all_computations.update(base._declared_computations_)
if hasattr(base, '_declared_callbacks_'):
@@ -145,6 +168,8 @@ def __new__(cls, name, bases, attrs):
new_class._all_refers_ref_ = all_relations.refers
new_class._all_hasone_ref_ = all_relations.hasone
new_class._all_hasmany_ref_ = all_relations.hasmany
+ #: store 'super' attributes on class
+ new_class._super_virtuals_ = super_vfields
return new_class
@@ -213,6 +238,8 @@ def __init__(self):
self.migrate = self.config.get('migrate', self.db._migrate)
if not hasattr(self, 'format'):
self.format = None
+ if not hasattr(self, 'primary_keys'):
+ self.primary_keys = []
@property
def config(self):
@@ -229,21 +256,25 @@ def __parse_relation_via(self, via):
return rv
def __parse_belongs_relation(self, item, on_delete):
- rv = sdict(on_delete=on_delete)
+ rv = sdict(fk=None, on_delete=on_delete, compound=None)
if isinstance(item, dict):
rv.name = list(item)[0]
rdata = item[rv.name]
+ target = None
if isinstance(rdata, dict):
if "target" in rdata:
- rv.model = rdata["target"]
- else:
- rv.model = camelize(rv.name)
+ target = rdata["target"]
if "on_delete" in rdata:
rv.on_delete = rdata["on_delete"]
else:
- rv.model = rdata
- if rv.model == "self":
- rv.model = self.__class__.__name__
+ target = rdata
+ if not target:
+ target = camelize(rv.name)
+ if "." in target:
+ target, rv.fk = target.split(".")
+ if target == "self":
+ target = self.__class__.__name__
+ rv.model = target
else:
rv.name = item
rv.model = camelize(item)
@@ -254,13 +285,19 @@ def __build_relation_modelname(self, name, relation, singularize):
if singularize:
relation.model = relation.model[:-1]
- def __build_relation_fieldname(self, relation):
+ def __build_relation_fieldnames(self, relation):
splitted = relation.model.split('.')
relation.model = splitted[0]
if len(splitted) > 1:
- relation.field = splitted[1]
+ relation.fields = [splitted[1]]
else:
- relation.field = decamelize(self.__class__.__name__)
+ if len(self.primary_keys) > 1:
+ relation.fields = [
+ f"{decamelize(self.__class__.__name__)}_{pk}"
+ for pk in self.primary_keys
+ ]
+ else:
+ relation.fields = [decamelize(self.__class__.__name__)]
def __parse_relation_dict(self, rel, singularize):
if 'scope' in rel.model:
@@ -283,9 +320,16 @@ def __parse_many_relation(self, item, singularize=True):
rv.model = item[rv.name]
if isinstance(rv.model, dict):
if 'method' in rv.model:
- rv.field = rv.model.get(
- 'field', decamelize(self.__class__.__name__)
- )
+ if 'field' in rv.model:
+ rv.fields = [rv.model['field']]
+ else:
+ if len(self.primary_keys) > 1:
+ rv.fields = [
+ f"{decamelize(self.__class__.__name__)}_{pk}"
+ for pk in self.primary_keys
+ ]
+ else:
+ rv.fields = [decamelize(self.__class__.__name__)]
rv.cast = rv.model.get('cast')
rv.method = rv.model['method']
del rv.model
@@ -295,27 +339,43 @@ def __parse_many_relation(self, item, singularize=True):
rv.name = item
self.__build_relation_modelname(item, rv, singularize)
if rv.model:
- if not rv.field:
- self.__build_relation_fieldname(rv)
+ if not rv.fields:
+ self.__build_relation_fieldnames(rv)
if rv.model == "self":
rv.model = self.__class__.__name__
+ if not rv.via:
+ rv.reverse = (
+ rv.fields[0] if len(rv.fields) == 1 else
+ decamelize(self.__class__.__name__)
+ )
return rv
def _define_props_(self):
#: create pydal's Field elements
self.fields = []
- idfield = Field('id')._make_field('id', self)
- setattr(self.__class__, 'id', idfield)
- self.fields.append(idfield)
+ if not self.primary_keys and 'id' not in self._all_fields_:
+ idfield = Field('id')._make_field('id', model=self)
+ setattr(self.__class__, 'id', idfield)
+ self.fields.append(idfield)
for name, obj in self._all_fields_.items():
if obj.modelname is not None:
- obj = Field(obj._type, *obj._args, **obj._kwargs)
+ obj = Field(obj._type, *obj._args, _kw=obj._ormkw, **obj._kwargs)
setattr(self.__class__, name, obj)
self.fields.append(obj._make_field(name, self))
+ def __find_matching_fk_definition(self, fields, rmodel):
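+ # returns the name of the user-defined `foreign_keys` entry whose
+ # foreign fields cover the target model's compound primary key, if any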
+ match = None
+ if not set(fields).issubset(set(rmodel.primary_keys)):
+ return match
+ for key, val in self.foreign_keys.items():
+ if set(val["foreign_fields"]) == set(rmodel.primary_keys):
+ match = key
+ break
+ return match
+
def _define_relations_(self):
- _ftype_builder = lambda v: 'reference {}'.format(v)
self._virtual_relations_ = OrderedDict()
+ self._compound_relations_ = {}
bad_args_error = (
"belongs_to, has_one and has_many "
"only accept strings or dicts as arguments"
@@ -324,36 +384,106 @@ def _define_relations_(self):
_references = []
_reference_keys = ['_all_belongs_ref_', '_all_refers_ref_']
belongs_references = {}
+ belongs_fks = {}
for key in _reference_keys:
if hasattr(self, key):
_references.append(list(getattr(self, key).values()))
else:
_references.append([])
- isbelongs, ondelete = True, 'cascade'
+ is_belongs, ondelete = True, 'cascade'
for _references_obj in _references:
for item in _references_obj:
if not isinstance(item, (str, dict)):
raise RuntimeError(bad_args_error)
reference = self.__parse_belongs_relation(item, ondelete)
- if reference.model != self.__class__.__name__:
- tablename = self.db[reference.model]._tablename
+ reference.is_refers = not is_belongs
+ refmodel = self.db[reference.model]._model_
+ ref_multi_pk = len(refmodel._fieldset_pk) > 1
+ fk_def_key, fks_data, multi_fk = None, {}, []
+ if ref_multi_pk and reference.fk:
+ fk_def_key = self.__find_matching_fk_definition(
+ [reference.fk], refmodel
+ )
+ if not fk_def_key:
+ raise SyntaxError(
+ f"{self.__class__.__name__}.{reference.name} relation "
+ "targets a compound primary key table. A matching foreign "
+ "key needs to be defined into `foreign_keys`"
+ )
+ fks_data = self.foreign_keys[fk_def_key]
+ elif ref_multi_pk and not reference.fk:
+ multi_fk = list(refmodel.primary_keys)
+ elif not reference.fk:
+ reference.fk = refmodel.table._id.name
+ if multi_fk:
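+ # implicit reference to a compound-pk model: expand the relation into
+ # one synthetic field per target primary key, named `{name}_{pk}`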
+ references = []
+ fks_data["fields"] = []
+ fks_data["foreign_fields"] = []
+ for fk in multi_fk:
+ refclone = sdict(reference)
+ refclone.fk = fk
+ refclone.ftype = refmodel.table[refclone.fk].type
+ refclone.name = f"{refclone.name}_{refclone.fk}"
+ refclone.compound = reference.name
+ references.append(refclone)
+ fks_data["fields"].append(refclone.name)
+ fks_data["foreign_fields"].append(refclone.fk)
+ belongs_fks[reference.name] = sdict(
+ model=reference.model,
+ name=reference.name,
+ local_fields=fks_data["fields"],
+ foreign_fields=fks_data["foreign_fields"],
+ coupled_fields=[
+ (local, fks_data["foreign_fields"][idx])
+ for idx, local in enumerate(fks_data["fields"])
+ ],
+ is_refers=reference.is_refers
+ )
+ self._compound_relations_[reference.name] = sdict(
+ model=reference.model,
+ local_fields=belongs_fks[reference.name].local_fields,
+ foreign_fields=belongs_fks[reference.name].foreign_fields,
+ coupled_fields=belongs_fks[reference.name].coupled_fields,
+ )
else:
- tablename = self.tablename
- fieldobj = Field(
- _ftype_builder(tablename),
- ondelete=reference.on_delete,
- _isrefers=not isbelongs
- )
- setattr(self.__class__, reference.name, fieldobj)
- self.fields.append(
- getattr(self, reference.name)._make_field(
- reference.name, self
+ reference.ftype = refmodel.table[reference.fk].type
+ references = [reference]
+ belongs_fks[reference.name] = sdict(
+ model=reference.model,
+ name=reference.name,
+ local_fields=[reference.name],
+ foreign_fields=[reference.fk],
+ coupled_fields=[(reference.name, reference.fk)],
+ is_refers=reference.is_refers
)
- )
- belongs_references[reference.name] = reference.model
- isbelongs = False
+ if not fk_def_key and fks_data:
+ self.foreign_keys[reference.name] = self.foreign_keys.get(
+ reference.name
+ ) or fks_data
+ for reference in references:
+ if reference.model != self.__class__.__name__:
+ tablename = self.db[reference.model]._tablename
+ else:
+ tablename = self.tablename
+ fieldobj = Field(
+ (
+ f"reference {tablename}" if not ref_multi_pk else
+ f"reference {tablename}.{reference.fk}"
+ ),
+ ondelete=reference.on_delete,
+ _isrefers=not is_belongs
+ )
+ setattr(self.__class__, reference.name, fieldobj)
+ self.fields.append(
+ getattr(self, reference.name)._make_field(
+ reference.name, self
+ )
+ )
+ belongs_references[reference.name] = reference
+ is_belongs = False
ondelete = 'nullify'
setattr(self.__class__, '_belongs_ref_', belongs_references)
+ setattr(self.__class__, '_belongs_fks_', belongs_fks)
#: has_one are mapped with rowattr
hasone_references = {}
if hasattr(self, '_all_hasone_ref_'):
@@ -385,14 +515,65 @@ def _define_relations_(self):
)(wrapper(reference))
hasmany_references[reference.name] = reference
setattr(self.__class__, '_hasmany_ref_', hasmany_references)
+ self.__define_fks()
+
+ def __define_fks(self):
+ self._foreign_keys_ = {}
+ implicit_defs = {}
+ grouped_rels = {}
+ for rname, rel in self._belongs_ref_.items():
+ rmodel = self.db[rel.model]._model_
+ if not rmodel.primary_keys and rmodel.table._id.type == 'id':
+ continue
+ if len(rmodel._fieldset_pk) > 1:
+ match = self.__find_matching_fk_definition([rel.fk], rmodel)
+ if not match:
+ raise SyntaxError(
+ f"{self.__class__.__name__}.{rname} relation targets a "
+ "compound primary key table. A matching foreign key "
+ "needs to be defined into `foreign_keys`."
+ )
+ trels = grouped_rels[rmodel.tablename] = grouped_rels.get(
+ rmodel.tablename, {
+ 'rels': {},
+ 'on_delete': self.foreign_keys[match].get(
+ "on_delete", "cascade"
+ )
+ }
+ )
+ trels['rels'][rname] = rel
+ else:
+ # NOTE: needed since pyDAL doesn't support id/reference types other than int
+ implicit_defs[rname] = {
+ 'table': rmodel.tablename,
+ 'fields_local': [rname],
+ 'fields_foreign': [rel.fk],
+ 'on_delete': Field._internal_delete[rel.on_delete]
+ }
+ for rname, rel in implicit_defs.items():
+ constraint_name = self.__create_fk_constraint_name(
+ rel['table'], *rel['fields_local']
+ )
+ self._foreign_keys_[constraint_name] = {**rel}
+ for tname, rels in grouped_rels.items():
+ constraint_name = self.__create_fk_constraint_name(
+ tname, *[rel.name for rel in rels['rels'].values()]
+ )
+ self._foreign_keys_[constraint_name] = {
+ 'table': tname,
+ 'fields_local': [rel.name for rel in rels['rels'].values()],
+ 'fields_foreign': [rel.fk for rel in rels['rels'].values()],
+ 'on_delete': Field._internal_delete[rels['on_delete']]
+ }
def _define_virtuals_(self):
self._all_rowattrs_ = {}
self._all_rowmethods_ = {}
+ self._super_rowmethods_ = {}
err = 'rowattr or rowmethod cannot have the name of an existent field!'
field_names = [field.name for field in self.fields]
for attr in ['_virtual_relations_', '_all_virtuals_']:
- for _, obj in getattr(self, attr, {}).items():
+ for obj in getattr(self, attr, {}).values():
if obj.field_name in field_names:
raise RuntimeError(err)
wrapped = wrap_virtual_on_model(self, obj.f)
@@ -403,14 +584,70 @@ def _define_virtuals_(self):
self._all_rowattrs_[obj.field_name] = wrapped
f = Field.Virtual(obj.field_name, wrapped)
self.fields.append(f)
+ for obj in self._super_virtuals_.values():
+ wrapped = wrap_virtual_on_model(self, obj.f)
+ if not isinstance(obj, rowmethod):
+ continue
+ self._super_rowmethods_[obj.field_name] = wrapped
+
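+ # persistence helpers: after a successful insert the returned key
+ # value(s) are copied back onto the row and `_concrete` marks it as
+ # database-backed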
+ def _set_row_persistence_id(self, row, ret):
+ row.id = ret.id
+ object.__setattr__(row, '_concrete', True)
+
+ def _set_row_persistence_pk(self, row, ret):
+ row[self.primary_keys[0]] = ret[self.primary_keys[0]]
+ object.__setattr__(row, '_concrete', True)
+
+ def _set_row_persistence_pks(self, row, ret):
+ for field_name in self.primary_keys:
+ row[field_name] = ret[field_name]
+ object.__setattr__(row, '_concrete', True)
+
+ def _unset_row_persistence(self, row):
+ for field_name in self._fieldset_pk:
+ row[field_name] = None
+ object.__setattr__(row, '_concrete', False)
def _build_rowclass_(self):
+ #: build helpers for rows
+ self._fieldset_pk = set(self.primary_keys or ['id'])
+ save_excluded_fields = (
+ set(
+ field.name for field in self.fields if
+ getattr(field, "type", None) == "id"
+ ) |
+ set(self._all_rowattrs_.keys()) |
+ set(self._all_rowmethods_.keys())
+ )
+ self._fieldset_initable = set([
+ field.name for field in self.fields
+ ]) - save_excluded_fields
+ self._fieldset_editable = set([
+ field.name for field in self.fields
+ ]) - save_excluded_fields - self._fieldset_pk
+ self._fieldset_all = self._fieldset_initable | self._fieldset_pk
+ self._fieldset_update = set([
+ field.name for field in self.fields
+ if getattr(field, "update", None) is not None
+ ]) & self._fieldset_editable
+ if not self.primary_keys:
+ self._set_row_persistence = self._set_row_persistence_id
+ elif len(self.primary_keys) == 1:
+ self._set_row_persistence = self._set_row_persistence_pk
+ else:
+ self._set_row_persistence = self._set_row_persistence_pks
+ #: create dynamic row class
clsname = self.__class__.__name__ + "Row"
attrs = {
k: cachedprop(v, name=k) for k, v in self._all_rowattrs_.items()
}
attrs.update(self._all_rowmethods_)
- self._rowclass_ = type(clsname, (Row,), attrs)
+ attrs.update(_model=self)
+ attrs.update({
+ key: property(_relation_mapper_getter_(key), _relation_mapper_setter_(key))
+ for key in self._compound_relations_.keys()
+ })
+ self._rowclass_ = type(clsname, (StructuredRow,), attrs)
globals()[clsname] = self._rowclass_
def _define_(self):
@@ -423,6 +660,7 @@ def _define_(self):
self.__define_computations()
self.__define_callbacks()
self.__define_scopes()
+ self.__define_query_helpers()
self.__define_form_utils()
self.setup()
@@ -442,6 +680,8 @@ def __define_validation(self):
def __define_access(self):
for field, value in self.fields_rw.items():
+ if field == 'id' and field not in self.table:
+ continue
if isinstance(value, (tuple, list)):
readable, writable = value
else:
@@ -468,15 +708,32 @@ def __define_computations(self):
for obj in self._all_computations_.values():
if obj.field_name not in field_names:
raise RuntimeError(err)
- # TODO add check virtuals
self.table[obj.field_name].compute = (
- lambda row, obj=obj, self=self: obj.f(self, row)
+ lambda row, obj=obj, self=self: obj.compute(self, row)
)
def __define_callbacks(self):
for obj in self._all_callbacks_.values():
for t in obj.t:
- if t in ["_before_insert", "_before_delete", "_after_delete"]:
+ if t in [
+ "_before_insert",
+ "_before_delete",
+ "_after_delete",
+ "_before_save",
+ "_after_save",
+ "_before_destroy",
+ "_after_destroy",
+ "_before_commit_insert",
+ "_before_commit_update",
+ "_before_commit_delete",
+ "_before_commit_save",
+ "_before_commit_destroy",
+ "_after_commit_insert",
+ "_after_commit_update",
+ "_after_commit_delete",
+ "_after_commit_save",
+ "_after_commit_destroy"
+ ]:
getattr(self.table, t).append(
lambda a, obj=obj, self=self: obj.f(self, a)
)
@@ -495,14 +752,20 @@ def __define_scopes(self):
classmethod(wrap_scope_on_model(obj.f))
)
- def __prepend_table_on_index_name(self, name):
- return '%s_widx__%s' % (self.tablename, name)
+ def __prepend_table_name(self, name, ns):
+ return '%s_%s__%s' % (self.tablename, ns, name)
def __create_index_name(self, *values):
components = []
for value in values:
components.append(value.replace('_', ''))
- return self.__prepend_table_on_index_name("_".join(components))
+ return self.__prepend_table_name("_".join(components), 'widx')
+
+ def __create_fk_constraint_name(self, *values):
+ components = []
+ for value in values:
+ components.append(value.replace('_', ''))
+ return self.__prepend_table_name("fk__" + "_".join(components), 'ecnt')
def __parse_index_dict(self, value):
rv = {}
@@ -528,6 +791,15 @@ def __parse_index_dict(self, value):
def __define_indexes(self):
self._indexes_ = {}
+ #: auto-define unique indexes for fields marked as unique
+ for field in self.fields:
+ if getattr(field, 'unique', False):
+ idx_name = self.__prepend_table_name(f'{field.name}_unique', 'widx')
+ idx_dict = self.__parse_index_dict(
+ {'fields': [field.name], 'unique': True}
+ )
+ self._indexes_[idx_name] = idx_dict
+ #: parse user-defined indexes
for key, value in self.indexes.items():
if isinstance(value, bool):
if not value:
@@ -539,12 +811,46 @@ def __define_indexes(self):
idx_name = self.__create_index_name(*key)
idx_dict = {'fields': key, 'expressions': [], 'unique': False}
elif isinstance(value, dict):
- idx_name = self.__prepend_table_on_index_name(key)
+ idx_name = self.__prepend_table_name(key, 'widx')
idx_dict = self.__parse_index_dict(value)
else:
raise SyntaxError('Values in indexes dict should be booleans or dicts')
self._indexes_[idx_name] = idx_dict
+ def _row_record_query_id(self, row):
+ return self.table.id == row.id
+
+ def _row_record_query_pk(self, row):
+ return self.table[self.primary_keys[0]] == row[self.primary_keys[0]]
+
+ def _row_record_query_pks(self, row):
+ return reduce(
+ operator.and_, [self.table[pk] == row[pk] for pk in self.primary_keys]
+ )
+
+ def __define_query_helpers(self):
+ if not self.primary_keys:
+ self._query_id = self.table.id != None
+ self._query_row = self._row_record_query_id
+ self._order_by_id_asc = self.table.id
+ self._order_by_id_desc = ~self.table.id
+ elif len(self.primary_keys) == 1:
+ self._query_id = self.table[self.primary_keys[0]] != None
+ self._query_row = self._row_record_query_pk
+ self._order_by_id_asc = self.table[self.primary_keys[0]]
+ self._order_by_id_desc = ~self.table[self.primary_keys[0]]
+ else:
+ self._query_id = reduce(
+ operator.and_, [self.table[key] != None for key in self.primary_keys]
+ )
+ self._query_row = self._row_record_query_pks
+ self._order_by_id_asc = reduce(
+ operator.or_, [self.table[key] for key in self.primary_keys]
+ )
+ self._order_by_id_desc = reduce(
+ operator.or_, [~self.table[key] for key in self.primary_keys]
+ )
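+ # note: in pyDAL `|` chains orderby expressions, so the reduces above
+ # produce "ORDER BY pk1, pk2, ..." over compound primary keys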
+
def __define_form_utils(self):
#: labels
for field, value in self.form_labels.items():
@@ -559,40 +865,64 @@ def __define_form_utils(self):
def setup(self):
pass
+ def get_rowmethod(self, name: str):
+ return self._all_rowmethods_[name]
+
+ def super_rowmethod(self, name: str):
+ return self._super_rowmethods_[name]
+
@classmethod
def _instance_(cls):
return cls.table._model_
@classmethod
def new(cls, **attributes):
- row = cls._instance_()._rowclass_()
- for field in cls.table.fields:
- val = attributes.get(field, cls.table[field].default)
+ inst = cls._instance_()
+ attrset = set(attributes.keys())
+ rowattrs = {}
+ for field in inst._fieldset_initable & attrset:
+ rowattrs[field] = attributes[field]
+ for field in inst._fieldset_initable - attrset:
+ val = cls.table[field].default
if callable(val):
val = val()
- row[field] = val
- return row
+ rowattrs[field] = val
+ for field in (inst.primary_keys or ["id"]):
+ if inst.table[field].type == "id":
+ rowattrs[field] = None
+ for field in set(inst._compound_relations_.keys()) & attrset:
+ reldata = inst._compound_relations_[field]
+ for local_field, foreign_field in reldata.coupled_fields:
+ rowattrs[local_field] = attributes[field][foreign_field]
+ return inst._rowclass_(rowattrs, __concrete=False)
@classmethod
- def create(cls, *args, **kwargs):
+ def create(cls, *args, skip_callbacks=False, **kwargs):
+ inst = cls._instance_()
if args:
if isinstance(args[0], (dict, sdict)):
for key in list(args[0]):
kwargs[key] = args[0][key]
- return cls.table.validate_and_insert(**kwargs)
+ for field in set(inst._compound_relations_.keys()) & set(kwargs.keys()):
+ reldata = inst._compound_relations_[field]
+ for local_field, foreign_field in reldata.coupled_fields:
+ kwargs[local_field] = kwargs[field][foreign_field]
+ return cls.table.validate_and_insert(skip_callbacks=skip_callbacks, **kwargs)
@classmethod
- def validate(cls, row):
- row = sdict(row)
- errors = sdict()
- for field in cls.table.fields:
- default = getattr(cls.table[field], 'default')
+ def validate(cls, row, write_values: bool = False):
+ inst, errors = cls._instance_(), sdict()
+ for field_name in inst._fieldset_all:
+ field = inst.table[field_name]
+ default = getattr(field, 'default')
if callable(default):
default = default()
- value = row.get(field, default)
- _, error = cls.table[field].validate(value)
+ value = row.get(field_name, default)
+ new_value, error = field.validate(value)
if error:
- errors[field] = error
+ errors[field_name] = error
+ elif new_value is not None and write_values:
+ row[field_name] = new_value
return errors
@classmethod
@@ -603,37 +933,251 @@ def where(cls, cond):
@classmethod
def all(cls):
- return cls.db.where(cls.table, model=cls)
+ return cls.db.where(cls._instance_()._query_id, model=cls)
@classmethod
def first(cls):
- return cls.all().select(orderby=cls.id, limitby=(0, 1)).first()
+ return cls.all().select(
+ orderby=cls._instance_()._order_by_id_asc,
+ limitby=(0, 1)
+ ).first()
@classmethod
def last(cls):
- return cls.all().select(orderby=~cls.id, limitby=(0, 1)).first()
+ return cls.all().select(
+ orderby=cls._instance_()._order_by_id_desc,
+ limitby=(0, 1)
+ ).first()
@classmethod
def get(cls, *args, **kwargs):
- if len(args) == 1:
- return cls.table[args[0]]
+ if args:
+ inst = cls._instance_()
+ if len(args) == 1:
+ if isinstance(args[0], tuple):
+ args = args[0]
+ elif isinstance(args[0], dict) and not kwargs:
+ return cls.table(**args[0])
+ if len(args) != len(inst._fieldset_pk):
+ raise SyntaxError(
+ f"{cls.__name__}.get requires the same number of arguments "
+ "as its primary keys"
+ )
+ pks = inst.primary_keys or ["id"]
+ return cls.table(
+ **{pks[idx]: val for idx, val in enumerate(args)}
+ )
return cls.table(**kwargs)
@rowmethod('update_record')
- def _update_record(self, row, **fields):
+ def _update_record(self, row, skip_callbacks=False, **fields):
newfields = fields or dict(row)
- for fieldname in list(newfields.keys()):
- if (
- fieldname not in self.table.fields or
- self.table[fieldname].type == 'id'
- ):
- del newfields[fieldname]
- self.db(self.table._id == row.id, ignore_common_filters=True).update(
- **newfields
- )
- row.update(newfields)
+ for field_name in set(newfields.keys()) - self._fieldset_editable:
+ del newfields[field_name]
+ res = self.db(
+ self._query_row(row), ignore_common_filters=True
+ ).update(skip_callbacks=skip_callbacks, **newfields)
+ if res:
+ row.update(self.get(**{key: row[key] for key in self._fieldset_pk}))
return row
@rowmethod('delete_record')
- def _delete_record(self, row):
- return self.db(self.db[self.tablename]._id == row.id).delete()
+ def _delete_record(self, row, skip_callbacks=False):
+ return self.db(self._query_row(row)).delete(skip_callbacks=skip_callbacks)
+
+ @rowmethod('refresh')
+ def _row_refresh(self, row) -> bool:
+ if not row._concrete:
+ return False
+ last = self.db(self._query_row(row)).select(
+ limitby=(0, 1),
+ orderby_on_limitby=False
+ ).first()
+ if not last:
+ return False
+ row.update(last)
+ row._changes.clear()
+ return True
+
+ @rowmethod('save')
+ def _row_save(
+ self,
+ row,
+ raise_on_error: bool = False,
+ skip_callbacks: bool = False
+ ) -> bool:
+ if row._concrete:
+ if set(row._changes.keys()) & self._fieldset_pk:
+ if raise_on_error:
+ raise SaveException(
+ 'Cannot save a record with altered primary key(s)'
+ )
+ return False
+ for field_name in self._fieldset_update:
+ val = self.table[field_name].update
+ if callable(val):
+ val = val()
+ row[field_name] = val
+ errors = self.validate(row, write_values=True)
+ if errors:
+ if raise_on_error:
+ raise ValidationError
+ return False
+ if row._concrete:
+ res = self.db(
+ self._query_row(row), ignore_common_filters=True
+ )._update_from_save(self, row, skip_callbacks=skip_callbacks)
+ if not res:
+ if raise_on_error:
+ raise UpdateFailureOnSave
+ return False
+ else:
+ self.table._insert_from_save(row, skip_callbacks=skip_callbacks)
+ if not row._concrete:
+ if raise_on_error:
+ raise InsertFailureOnSave
+ return False
+ row._changes.clear()
+ return True
+
+ @rowmethod('destroy')
+ def _row_destroy(
+ self,
+ row,
+ raise_on_error: bool = False,
+ skip_callbacks: bool = False
+ ) -> bool:
+ if not row._concrete:
+ return False
+ res = self.db(
+ self._query_row(row), ignore_common_filters=True
+ )._delete_from_destroy(self, row, skip_callbacks=skip_callbacks)
+ if not res:
+ if raise_on_error:
+ raise DestroyException
+ return False
+ row._changes.clear()
+ return True
+
+
+class RowRelationMapper:
+ __slots__ = ["model", "fields", "_lastv", "_cached"]
+
+ def __init__(self, db, relation_data):
+ self.model = db[relation_data.model]._model_
+ self.fields = relation_data.coupled_fields
+ self._lastv = {}
+ self._cached = None
+
+ def __call__(self, obj):
+ pks = {fk: obj[lk] for lk, fk in self.fields}
+ if all(v is None for v in pks.values()):
+ return None
+ if not self._cached or pks != self._lastv:
+ self._lastv = pks
+ self._cached = RowReferenceMulti(pks, self.model.table)
+ return self._cached
+
+
+class StructuredRow(Row):
+ __slots__ = ["_concrete", "_changes", "_compound_rel_mappers"]
+
+ def __init__(self, *args, **kwargs):
+ object.__setattr__(self, "_concrete", kwargs.pop("__concrete", True))
+ object.__setattr__(self, "_changes", {})
+ object.__setattr__(self, "_compound_rel_mappers", {})
+ super().__init__(*args, **kwargs)
+ if self._model._compound_relations_:
+ for key, data in self._model._compound_relations_.items():
+ self._compound_rel_mappers[key] = RowRelationMapper(
+ self._model.db, data
+ )
+
+ def __setattr__(self, key, value):
+ if key in self.__slots__:
+ return
+ prev = self._changes[key][0] if key in self._changes else self.__dict__.get(key)
+ object.__setattr__(self, key, value)
+ if (prev is None and value is not None) or prev != value:
+ self._changes[key] = (prev, value)
+ else:
+ self._changes.pop(key, None)
+
+ def __setitem__(self, key, value):
+ self.__setattr__(key, value)
+
+ def __getstate__(self):
+ return {
+ "__data": self.__dict__.copy(),
+ "__struct": {
+ "_concrete": self._concrete,
+ "_changes": {},
+ "_compound_rel_mappers": {}
+ }
+ }
+
+ def __setstate__(self, state):
+ self.__dict__.update(state["__data"])
+ for key, val in state["__struct"].items():
+ object.__setattr__(self, key, val)
+ if self._model._compound_relations_:
+ for key, data in self._model._compound_relations_.items():
+ self._compound_rel_mappers[key] = RowRelationMapper(
+ self._model.db, data
+ )
+
+ def update(self, *args, **kwargs):
+ for arg in args:
+ for key, val in arg.items():
+ self.__setattr__(key, val)
+ for key, val in kwargs.items():
+ self.__setattr__(key, val)
+
+ @property
+ def changes(self):
+ return sdict(self._changes)
+
+ @property
+ def has_changed(self):
+ return bool(self._changes)
+
+ def has_changed_value(self, key):
+ return key in self._changes
+
+ def get_value_change(self, key):
+ return self._changes.get(key, None)
+
+ def clone(self):
+ fields = {}
+ for key in self._model._fieldset_all:
+ fields[key] = self._changes[key][0] if key in self._changes else self[key]
+ return self.__class__(fields, __concrete=self._concrete)
+
+ def clone_changed(self):
+ return self.__class__(
+ {key: self[key] for key in self._model._fieldset_all},
+ __concrete=self._concrete
+ )
+
+ @property
+ def validation_errors(self):
+ return self._model.validate(self)
+
+ @property
+ def is_valid(self):
+ return not bool(self._model.validate(self))
+
+
+def _relation_mapper_getter_(key):
+ def wrap(obj):
+ return obj._compound_rel_mappers[key](obj)
+ return wrap
+
+
+def _relation_mapper_setter_(key):
+ def wrap(obj, val):
+ if not isinstance(val, (StructuredRow, RowReferenceMulti)):
+ return
+ for local_field, foreign_field in obj._compound_rel_mappers[key].fields:
+ obj[local_field] = val[foreign_field]
+ return wrap
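
Note on the model-layer additions above: together with `StructuredRow`, they enable an active-record style flow (change tracking plus the `save`/`refresh`/`destroy` row methods). A minimal sketch, assuming a database with the model registered — the `Todo` model is illustrative, not part of this diff:

```python
from emmett.orm import Field, Model

class Todo(Model):
    title = Field.string()
    done = Field.int(default=0)

# Model.new builds a non-concrete row: defaults are applied, nothing is persisted
row = Todo.new(title="write docs")
row.save()                          # validates, then inserts; the row becomes concrete

row.done = 1                        # tracked by StructuredRow.__setattr__
assert row.has_changed
assert row.get_value_change("done") == (0, 1)

row.save()                          # now an update filtered on the row's primary key(s)
row.refresh()                       # reloads values from the database, clears changes
row.destroy(skip_callbacks=True)    # deletes, bypassing the destroy/delete callbacks
```
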
diff --git a/emmett/orm/objects.py b/emmett/orm/objects.py
index 78360d36..f6fadb3f 100644
--- a/emmett/orm/objects.py
+++ b/emmett/orm/objects.py
@@ -12,12 +12,24 @@
import copy
import datetime
import decimal
+import operator
import types
from collections import OrderedDict, defaultdict
+from enum import Enum
+from functools import reduce
+from typing import Any, Optional
+
from pydal.objects import (
- Table as _Table, Field as _Field, Set as _Set,
- Row as _Row, Rows as _Rows, IterRows as _IterRows, Query, Expression)
+ Table as _Table,
+ Field as _Field,
+ Set as _Set,
+ Row as _Row,
+ Rows as _Rows,
+ IterRows as _IterRows,
+ Query,
+ Expression
+)
from ..ctx import current
from ..datastructures import sdict
@@ -26,19 +38,160 @@
from ..utils import cachedprop
from ..validators import ValidateFromDict
from .helpers import (
- _IDReference, JoinedIDReference, RelationBuilder, wrap_scope_on_set)
+ RelationBuilder,
+ GeoFieldWrapper,
+ wrap_scope_on_set,
+ typed_row_reference,
+ typed_row_reference_from_record
+)
+
+type_int = int
class Table(_Table):
- def __init__(self, *args, **kwargs):
- super(Table, self).__init__(*args, **kwargs)
+ def __init__(self, db, tablename, *fields, **kwargs):
+        _primary_keys = list(kwargs.get('primarykey', []))
+        _notnulls = {
+            field.name: field.notnull
+            for field in fields if hasattr(field, 'notnull')
+        }
+ super(Table, self).__init__(db, tablename, *fields, **kwargs)
+ self._before_save = []
+ self._after_save = []
+ self._before_destroy = []
+ self._after_destroy = []
+ self._before_commit = []
+ self._before_commit_insert = []
+ self._before_commit_update = []
+ self._before_commit_delete = []
+ self._before_commit_save = []
+ self._before_commit_destroy = []
+ self._after_commit = []
+ self._after_commit_insert = []
+ self._after_commit_update = []
+ self._after_commit_delete = []
+ self._after_commit_save = []
+ self._after_commit_destroy = []
self._unique_fields_validation_ = {}
+ self._primary_keys = _primary_keys
+ #: avoid pyDAL mess in ops and migrations
+ if len(self._primary_keys) == 1 and getattr(self, '_primarykey', None):
+ del self._primarykey
+ for key in self._primary_keys:
+ self[key].notnull = _notnulls[key]
+ if not hasattr(self, '_id'):
+ self._id = None
+
+ @cachedprop
+ def _has_commit_insert_callbacks(self):
+ return any([
+ self._before_commit,
+ self._after_commit,
+ self._before_commit_insert,
+ self._after_commit_insert
+ ])
+
+ @cachedprop
+ def _has_commit_update_callbacks(self):
+ return any([
+ self._before_commit,
+ self._after_commit,
+ self._before_commit_update,
+ self._after_commit_update
+ ])
+
+ @cachedprop
+ def _has_commit_delete_callbacks(self):
+ return any([
+ self._before_commit,
+ self._after_commit,
+ self._before_commit_delete,
+ self._after_commit_delete
+ ])
+
+ @cachedprop
+ def _has_commit_save_callbacks(self):
+ return any([
+ self._before_commit,
+ self._after_commit,
+ self._before_commit_save,
+ self._after_commit_save
+ ])
+
+ @cachedprop
+ def _has_commit_destroy_callbacks(self):
+ return any([
+ self._before_commit,
+ self._after_commit,
+ self._before_commit_destroy,
+ self._after_commit_destroy
+ ])
def _create_references(self):
self._referenced_by = []
self._referenced_by_list = []
self._references = []
+ def _fields_and_values_for_save(self, row, fieldset, op_method):
+ fields = {key: row[key] for key in fieldset}
+ return op_method(fields)
+
+ def insert(self, skip_callbacks=False, **fields):
+ row = self._fields_and_values_for_insert(fields)
+ if not skip_callbacks and any(f(row) for f in self._before_insert):
+ return 0
+ ret = self._db._adapter.insert(self, row.op_values())
+ if not skip_callbacks:
+ if self._has_commit_insert_callbacks:
+ txn = self._db._adapter.top_transaction()
+ if txn:
+ txn._add_op(TransactionOp(
+ TransactionOps.insert,
+ self,
+ TransactionOpContext(
+ values=row,
+ ret=ret
+ )
+ ))
+ if ret:
+ for f in self._after_insert:
+ f(row, ret)
+ return ret
+
+ def validate_and_insert(self, skip_callbacks=False, **fields):
+ response, new_fields = self._validate_fields(fields)
+ if not response.errors:
+ response.id = self.insert(skip_callbacks=skip_callbacks, **new_fields)
+ return response
+
+ def _insert_from_save(self, row, skip_callbacks=False):
+ if not skip_callbacks and any(f(row) for f in self._before_save):
+ return row
+ fields = self._fields_and_values_for_save(
+ row, self._model_._fieldset_initable, self._fields_and_values_for_insert
+ )
+ ret = self.insert(skip_callbacks=skip_callbacks, **fields)
+ if ret:
+ self._model_._set_row_persistence(row, ret)
+ if not skip_callbacks:
+ if self._has_commit_save_callbacks:
+ txn = self._db._adapter.top_transaction()
+ if txn:
+ txn._add_op(TransactionOp(
+ TransactionOps.save,
+ self,
+ TransactionOpContext(
+ values=fields,
+ ret=ret,
+ row=row.clone_changed(),
+ changes=row.changes
+ )
+ ))
+ if row._concrete:
+ for f in self._after_save:
+ f(row)
+ return row
+
class Field(_Field):
_internal_types = {
@@ -105,6 +258,7 @@ def __init__(self, type='string', *args, **kwargs):
if 'length' in kwargs:
kwargs['length'] = int(kwargs['length'])
#: store args and kwargs for `_make_field`
+ self._ormkw = kwargs.pop('_kw', {})
self._args = args
self._kwargs = kwargs
#: increase creation counter (used to keep order of fields)
@@ -149,7 +303,7 @@ def _parse_validation(self):
self._custom_requires
#: `_make_field` will be called by `Model` class or `Form` class
- # it will make intenral Field class compatible with the pyDAL's one
+ # it will make internal Field class compatible with the pyDAL's one
def _make_field(self, name, model=None):
if self._obj_created_:
return self
@@ -157,6 +311,11 @@ def _make_field(self, name, model=None):
self.modelname = model.__class__.__name__
#: convert field type to pyDAL ones if needed
ftype = self._pydal_types.get(self._type, self._type)
+ if ftype.startswith("geo") and model:
+ geometry_type = self._ormkw["geometry_type"]
+ srid = self._ormkw["srid"] or getattr(model.db._adapter, "srid", 4326)
+ dimension = self._ormkw["dimension"] or 2
+ ftype = f"{ftype}({geometry_type},{srid},{dimension})"
#: create pyDAL's Field instance
super(Field, self).__init__(name, ftype, *self._args, **self._kwargs)
#: add automatic validation (if requested)
@@ -251,6 +410,36 @@ def int_list(cls, *args, **kwargs):
def string_list(cls, *args, **kwargs):
return cls('list:string', *args, **kwargs)
+ @classmethod
+ def geography(
+ cls,
+ geometry_type: str = 'GEOMETRY',
+ srid: Optional[type_int] = None,
+ dimension: Optional[type_int] = None,
+ **kwargs
+ ):
+ kwargs['_kw'] = {
+ "geometry_type": geometry_type,
+ "srid": srid,
+ "dimension": dimension
+ }
+ return cls("geography", **kwargs)
+
+ @classmethod
+ def geometry(
+ cls,
+ geometry_type: str = 'GEOMETRY',
+ srid: Optional[type_int] = None,
+ dimension: Optional[type_int] = None,
+ **kwargs
+ ):
+ kwargs['_kw'] = {
+ "geometry_type": geometry_type,
+ "srid": srid,
+ "dimension": dimension
+ }
+ return cls("geometry", **kwargs)
+
def cast(self, value, **kwargs):
return Expression(
self.db, self._dialect.cast, self,
@@ -308,16 +497,29 @@ def _parse_paginate(self, pagination):
def _join_set_builder(self, obj, jdata, auto_select_tables):
return JoinedSet._from_set(
- obj, jdata=jdata, auto_select_tables=auto_select_tables)
+ obj, jdata=jdata, auto_select_tables=auto_select_tables
+ )
def _left_join_set_builder(self, jdata):
return JoinedSet._from_set(
- self, ljdata=jdata, auto_select_tables=[self._model_.table])
+ self, ljdata=jdata, auto_select_tables=[self._model_.table]
+ )
def _run_select_(self, *fields, **options):
- return super(Set, self).select(*fields, **options)
+ tablemap = self.db._adapter.tables(
+ self.query,
+ options.get('join', None),
+ options.get('left', None),
+ options.get('orderby', None),
+ options.get('groupby', None)
+ )
+ fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(
+ fields, tablemap
+ )
+ options['_concrete_tables'] = concrete_tables
+ return self.db._adapter.select(self.query, fields, options)
- def _get_table_from_query(self):
+ def _get_table_from_query(self) -> Table:
if self._model_:
return self._model_.table
return self.db._adapter.get_table(self.query)
@@ -325,7 +527,9 @@ def _get_table_from_query(self):
def select(self, *fields, **options):
obj = self
pagination, including = (
- options.pop('paginate', None), options.pop('including', None))
+ options.pop('paginate', None),
+ options.pop('including', None)
+ )
if pagination:
options['limitby'] = self._parse_paginate(pagination)
if including and self._model_ is not None:
@@ -333,18 +537,68 @@ def select(self, *fields, **options):
obj = self._left_join_set_builder(jdata)
return obj._run_select_(*fields, **options)
- def update(self, **update_fields):
+ def iterselect(self, *fields, **options):
+ pagination = options.pop('paginate', None)
+ if pagination:
+ options['limitby'] = self._parse_paginate(pagination)
+ tablemap = self.db._adapter.tables(
+ self.query,
+ options.get('join', None),
+ options.get('left', None),
+ options.get('orderby', None),
+ options.get('groupby', None)
+ )
+ fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(
+ fields, tablemap
+ )
+ options['_concrete_tables'] = concrete_tables
+ return self.db._adapter.iterselect(self.query, fields, options)
+
+ def update(self, skip_callbacks=False, **update_fields):
table = self._get_table_from_query()
row = table._fields_and_values_for_update(update_fields)
if not row._values:
raise ValueError("No fields to update")
- if any(f(self, row) for f in table._before_update):
+ if not skip_callbacks and any(f(self, row) for f in table._before_update):
return 0
ret = self.db._adapter.update(table, self.query, row.op_values())
- ret and [f(self, row) for f in table._after_update]
+ if not skip_callbacks:
+ if table._has_commit_update_callbacks:
+ txn = self._db._adapter.top_transaction()
+ if txn:
+ txn._add_op(TransactionOp(
+ TransactionOps.update,
+ table,
+ TransactionOpContext(
+ values=row,
+ dbset=self,
+ ret=ret
+ )
+ ))
+ ret and [f(self, row) for f in table._after_update]
+ return ret
+
+ def delete(self, skip_callbacks=False):
+ table = self._get_table_from_query()
+ if not skip_callbacks and any(f(self) for f in table._before_delete):
+ return 0
+ ret = self.db._adapter.delete(table, self.query)
+ if not skip_callbacks:
+ if table._has_commit_delete_callbacks:
+ txn = self._db._adapter.top_transaction()
+ if txn:
+ txn._add_op(TransactionOp(
+ TransactionOps.delete,
+ table,
+ TransactionOpContext(
+ dbset=self,
+ ret=ret
+ )
+ ))
+ ret and [f(self) for f in table._after_delete]
return ret
- def validate_and_update(self, **update_fields):
+ def validate_and_update(self, skip_callbacks=False, **update_fields):
table = self._get_table_from_query()
current._dbvalidation_record_id_ = None
if table._unique_fields_validation_ and self.count() == 1:
@@ -370,15 +624,96 @@ def validate_and_update(self, **update_fields):
row = table._fields_and_values_for_update(new_fields)
if not row._values:
raise ValueError("No fields to update")
- if any(f(self, row) for f in table._before_update):
+ if not skip_callbacks and any(f(self, row) for f in table._before_update):
ret = 0
else:
ret = self.db._adapter.update(
- table, self.query, row.op_values())
- ret and [f(self, row) for f in table._after_update]
+ table, self.query, row.op_values()
+ )
+ if not skip_callbacks and ret:
+ for f in table._after_update:
+ f(self, row)
response.updated = ret
return response
+ def _update_from_save(self, model, row, skip_callbacks=False):
+ table: Table = model.table
+ if not skip_callbacks and any(f(row) for f in table._before_save):
+ return False
+ fields = table._fields_and_values_for_save(
+ row, model._fieldset_editable, table._fields_and_values_for_update
+ )
+ if not skip_callbacks and any(f(self, fields) for f in table._before_update):
+ return False
+ ret = self.db._adapter.update(table, self.query, fields.op_values())
+ if not skip_callbacks:
+ if table._has_commit_update_callbacks or table._has_commit_save_callbacks:
+ txn = self._db._adapter.top_transaction()
+ if txn and table._has_commit_update_callbacks:
+ txn._add_op(TransactionOp(
+ TransactionOps.update,
+ table,
+ TransactionOpContext(
+ values=fields,
+ dbset=self,
+ ret=ret
+ )
+ ))
+ if txn and table._has_commit_save_callbacks:
+ txn._add_op(TransactionOp(
+ TransactionOps.save,
+ table,
+ TransactionOpContext(
+ values=fields,
+ dbset=self,
+ ret=ret,
+ row=row.clone_changed(),
+ changes=row.changes
+ )
+ ))
+ ret and [f(self, fields) for f in table._after_update]
+ ret and [f(row) for f in table._after_save]
+ return bool(ret)
+
+ def _delete_from_destroy(self, model, row, skip_callbacks=False):
+ table: Table = model.table
+ if not skip_callbacks and any(f(row) for f in table._before_destroy):
+ return False
+ if not skip_callbacks and any(f(self) for f in table._before_delete):
+ return 0
+ ret = self.db._adapter.delete(table, self.query)
+ if ret:
+ model._unset_row_persistence(row)
+ if not skip_callbacks:
+ if (
+ table._has_commit_delete_callbacks or
+ table._has_commit_destroy_callbacks
+ ):
+ txn = self._db._adapter.top_transaction()
+ if txn and table._has_commit_delete_callbacks:
+ txn._add_op(TransactionOp(
+ TransactionOps.delete,
+ table,
+ TransactionOpContext(
+ dbset=self,
+ ret=ret
+ )
+ ))
+ if txn and table._has_commit_destroy_callbacks:
+ txn._add_op(TransactionOp(
+ TransactionOps.destroy,
+ table,
+ TransactionOpContext(
+ dbset=self,
+ ret=ret,
+ row=row.clone_changed(),
+ changes=row.changes
+ )
+ ))
+ ret and [f(self) for f in table._after_delete]
+ ret and [f(row) for f in table._after_destroy]
+ return bool(ret)
+
def join(self, *args):
rv = self
if self._model_ is not None:
@@ -406,27 +741,24 @@ def _parse_rjoin(self, arg):
#: match has_many
rel = self._model_._hasmany_ref_.get(arg)
if rel:
- if isinstance(rel, dict) and rel.get('via'):
+ if rel.via:
r = RelationBuilder(rel, self._model_._instance_()).via()
return r[0], r[1]._table, 'many'
- else:
- r = RelationBuilder(rel, self._model_._instance_())
- return r.many(), rel.table, 'many'
+ r = RelationBuilder(rel, self._model_._instance_())
+ return r.many(), rel.table, 'many'
#: match belongs_to and refers_to
- rel = self._model_._belongs_ref_.get(arg)
+ rel = self._model_._belongs_fks_.get(arg)
if rel:
- r = RelationBuilder(
- (rel, arg), self._model_._instance_()
- ).belongs_query()
- return r, self._model_.db[rel], 'belongs'
+ r = RelationBuilder(rel, self._model_._instance_()).belongs_query()
+ return r, self._model_.db[rel.model], 'belongs'
#: match has_one
rel = self._model_._hasone_ref_.get(arg)
if rel:
r = RelationBuilder(rel, self._model_._instance_())
return r.many(), rel.table, 'one'
raise RuntimeError(
- 'Unable to find %s relation of %s model' %
- (arg, self._model_.__name__))
+ f'Unable to find {arg} relation of {self._model_.__name__} model'
+ )
def _parse_left_rjoins(self, args):
if not isinstance(args, (list, tuple)):
@@ -481,8 +813,12 @@ def _model_(self):
return self._relation_.ref.model_instance
@cachedprop
- def _field_(self):
- return self._relation_.ref.field_instance
+ def _fields_(self):
+ pks = self._relation_.model.primary_keys or ["id"]
+ return [
+ (relation_field.name, pks[idx])
+ for idx, relation_field in enumerate(self._relation_.ref.fields_instances)
+ ]
@cachedprop
def _scopes_(self):
@@ -511,14 +847,14 @@ def _last_resultset(self, refresh=False):
def _filter_reload(self, kwargs):
return kwargs.pop('reload', False)
- def create(self, **kwargs):
- attributes = self._get_fields_from_scopes(
- self._scopes_, self._model_.tablename)
- attributes.update(**kwargs)
- attributes[self._field_.name] = self._row_.id
- return self._model_.create(
- **attributes
+ def create(self, skip_callbacks=False, **kwargs):
+ attrs = self._get_fields_from_scopes(
+ self._scopes_, self._model_.tablename
)
+ attrs.update(**kwargs)
+ for ref, local in self._fields_:
+ attrs[ref] = self._row_[local]
+ return self._model_.create(skip_callbacks=skip_callbacks, **attrs)
@staticmethod
def _get_fields_from_scopes(scopes, table_name):
@@ -533,8 +869,10 @@ def _get_fields_from_scopes(scopes, table_name):
components.append(component.second)
components.append(component.first)
else:
- if isinstance(component, Field) and \
- component._tablename == table_name:
+ if (
+ isinstance(component, Field) and
+ component._tablename == table_name
+ ):
current_kv.append(component)
else:
if current_kv:
@@ -569,21 +907,47 @@ def __call__(self, *args, **kwargs):
return self._last_resultset(refresh)
return self.select(*args, **kwargs)
- def add(self, obj):
- attributes = self._get_fields_from_scopes(
- self._scopes_, self._model_.tablename)
- attributes[self._field_.name] = self._row_.id
- return self.db(
- self.db[self._field_._tablename].id == obj.id
- ).validate_and_update(**attributes)
-
- def remove(self, obj):
- if self.db[self._field_._tablename][self._field_.name]._isrefers:
- return self.db(
- self._field_._table.id == obj.id).validate_and_update(
- **{self._field_.name: None}
+ def add(self, obj, skip_callbacks=False):
+ attrs = self._get_fields_from_scopes(
+ self._scopes_, self._model_.tablename
+ )
+ rev_attrs = {}
+ for ref, local in self._fields_:
+ attrs[ref] = self._row_[local]
+ rev_attrs[local] = attrs[ref]
+ rv = self.db(self._model_._query_row(obj)).validate_and_update(
+ skip_callbacks=skip_callbacks, **attrs
+ )
+ if rv:
+ for key, val in attrs.items():
+ obj[key] = val
+ if self._relation_.ref.reverse not in obj._compound_rel_mappers:
+ obj[self._relation_.ref.reverse] = typed_row_reference(
+ rev_attrs if len(rev_attrs) > 1 else rev_attrs[self._fields_[0][1]],
+ self._relation_.model.table
+ )
+ return rv
+
+ def remove(self, obj, skip_callbacks=False):
+ attrs, is_delete = {ref: None for ref, _ in self._fields_}, False
+ if self._model_._belongs_fks_[self._relation_.ref.reverse].is_refers:
+ rv = self.db(self._model_._query_row(obj)).validate_and_update(
+ skip_callbacks=skip_callbacks,
+ **attrs
+ )
+ else:
+ is_delete = True
+ rv = self.db(self._model_._query_row(obj)).delete(
+ skip_callbacks=skip_callbacks
)
- return self.db(self._field_._table.id == obj.id).delete()
+ if rv:
+ for key, val in attrs.items():
+ obj[key] = val
+ if self._relation_.ref.reverse not in obj._compound_rel_mappers:
+ obj[self._relation_.ref.reverse] = None
+ if is_delete:
+ obj._concrete = False
+ return rv
class HasManyViaSet(RelationSet):
@@ -592,11 +956,14 @@ class HasManyViaSet(RelationSet):
@cachedprop
def _viadata(self):
- query, rfield, model_name, rid, via, viadata = \
- super(HasManyViaSet, self)._get_query_()
+ query, rfield, model_name, rid, via, viadata = super()._get_query_()
return sdict(
- query=query, rfield=rfield, model_name=model_name, rid=rid,
- via=via, data=viadata
+ query=query,
+ rfield=rfield,
+ model_name=model_name,
+ rid=rid,
+ via=via,
+ data=viadata
)
def _get_query_(self):
@@ -619,9 +986,12 @@ def __call__(self, *args, **kwargs):
def _get_relation_fields(self):
viadata = self._viadata.data
- self_field = self._model_._hasmany_ref_[viadata.via].field
- rel_field = viadata.field or viadata.name[:-1]
- return self_field, rel_field
+ manyref = self._model_._hasmany_ref_[viadata.via]
+ fkdata = manyref.model_instance._belongs_fks_
+ jref = fkdata[viadata.field or viadata.name[:-1]]
+ fields_src = manyref.fields
+ fields_dst = list(jref.coupled_fields)
+ return fields_src, fields_dst
def _fields_from_scopes(self):
viadata = self._viadata.data
@@ -635,70 +1005,87 @@ def _fields_from_scopes(self):
def create(self, **kwargs):
raise RuntimeError('Cannot create third objects for many relations')
- def add(self, obj, **kwargs):
+ def add(self, obj, skip_callbacks=False, **kwargs):
# works on join tables only!
if self._viadata.via is None:
raise RuntimeError(self._via_error % 'add')
nrow = self._fields_from_scopes()
nrow.update(**kwargs)
#: get belongs references
- self_field, rel_field = self._get_relation_fields()
- nrow[self_field] = self._viadata.rid
- nrow[rel_field] = obj.id
+ self_fields, rel_fields = self._get_relation_fields()
+ for idx, self_field in enumerate(self_fields):
+ nrow[self_field] = self._viadata.rid[idx]
+ for local_field, foreign_field in rel_fields:
+ nrow[local_field] = obj[foreign_field]
#: validate and insert
- return self.db[self._viadata.via]._model_.create(nrow)
+ return self.db[self._viadata.via]._model_.create(
+ nrow, skip_callbacks=skip_callbacks
+ )
- def remove(self, obj):
+ def remove(self, obj, skip_callbacks=False):
# works on join tables only!
if self._viadata.via is None:
raise RuntimeError(self._via_error % 'remove')
#: get belongs references
- self_field, rel_field = self._get_relation_fields()
+ self_fields, rel_fields = self._get_relation_fields()
#: delete
- return self.db(
- (self.db[self._viadata.via][self_field] == self._viadata.rid) &
- (self.db[self._viadata.via][rel_field] == obj.id)).delete()
+ query = reduce(
+ operator.and_, [
+ self.db[self._viadata.via][field] == self._viadata.rid[idx]
+ for idx, field in enumerate(self_fields)
+ ] + [
+ self.db[self._viadata.via][local_field] == obj[foreign_field]
+ for local_field, foreign_field in rel_fields
+ ]
+ )
+ return self.db(query).delete(skip_callbacks=skip_callbacks)
class JoinedSet(Set):
@classmethod
def _from_set(cls, obj, jdata=[], ljdata=[], auto_select_tables=[]):
- rv = cls(
- obj.db, obj.query, obj.query.ignore_common_filters, obj._model_)
+ rv = cls(obj.db, obj.query, obj.query.ignore_common_filters, obj._model_)
rv._stable_ = obj._model_.tablename
rv._jdata_ = list(jdata)
rv._ljdata_ = list(ljdata)
rv._auto_select_tables_ = list(auto_select_tables)
+ rv._pks_ = obj._model_._instance_()._fieldset_pk
return rv
def _clone(self, ignore_common_filters=None, model=None, **changes):
- rv = super(JoinedSet, self)._clone(
- ignore_common_filters, model, **changes)
+ rv = super()._clone(ignore_common_filters, model, **changes)
rv._stable_ = self._stable_
rv._jdata_ = self._jdata_
rv._ljdata_ = self._ljdata_
rv._auto_select_tables_ = self._auto_select_tables_
+ rv._pks_ = self._pks_
return rv
def _join_set_builder(self, obj, jdata, auto_select_tables):
return JoinedSet._from_set(
obj, jdata=self._jdata_ + jdata, ljdata=self._ljdata_,
- auto_select_tables=self._auto_select_tables_ + auto_select_tables)
+ auto_select_tables=self._auto_select_tables_ + auto_select_tables
+ )
def _left_join_set_builder(self, jdata):
return JoinedSet._from_set(
self, jdata=self._jdata_, ljdata=self._ljdata_ + jdata,
- auto_select_tables=self._auto_select_tables_)
+ auto_select_tables=self._auto_select_tables_
+ )
- def _iterselect_rows(self, *fields, **attributes):
+ def _iterselect_rows(self, *fields, **options):
tablemap = self.db._adapter.tables(
- self.query, attributes.get('join', None),
- attributes.get('left', None), attributes.get('orderby', None),
- attributes.get('groupby', None))
- fields = self.db._adapter.expand_all(fields, tablemap)
- colnames, sql = self.db._adapter._select_wcols(
- self.query, fields, **attributes)
- return JoinIterRows(self.db, sql, fields, colnames)
+ self.query,
+ options.get('join', None),
+ options.get('left', None),
+ options.get('orderby', None),
+ options.get('groupby', None)
+ )
+ fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(
+ fields, tablemap
+ )
+ colnames, sql = self.db._adapter._select_wcols(self.query, fields, **options)
+ return JoinIterRows(self.db, sql, fields, concrete_tables, colnames)
def _split_joins(self, joins):
rv = {'belongs': [], 'one': [], 'many': []}
@@ -710,10 +1097,19 @@ def _build_records_from_joined(self, rowmap, inclusions, colnames):
for rid, many_data in inclusions.items():
for jname, included in many_data.items():
rowmap[rid][jname]._cached_resultset = Rows(
- self.db, list(included.values()), [])
+ self.db, list(included.values()), []
+ )
return JoinRows(
self.db, list(rowmap.values()), colnames,
- _jdata=self._jdata_ + self._ljdata_)
+ _jdata=self._jdata_ + self._ljdata_
+ )
+
+ def _select_rowpks_extractor(self, row):
+ if not set(row.keys()).issuperset(self._pks_):
+ return None
+ if len(self._pks_) > 1:
+ return tuple(row[pk] for pk in self._pks_)
+ return row[tuple(self._pks_)[0]]
def _run_select_(self, *fields, **options):
#: build parsers
@@ -733,108 +1129,144 @@ def _run_select_(self, *fields, **options):
#: use iterselect for performance
rows = self._iterselect_rows(*fields, **options)
#: rebuild rowset using nested objects
+ plainrows = []
rowmap = OrderedDict()
inclusions = defaultdict(
lambda: {
- jname: OrderedDict() for jname, jtable in (many_j + many_l)})
+ jname: OrderedDict() for jname, _ in (many_j + many_l)
+ }
+ )
for row in rows:
- rid = row[self._stable_].id
+ if self._stable_ not in row:
+ plainrows.append(row)
+ continue
+ rid = self._select_rowpks_extractor(row[self._stable_])
+ if rid is None:
+ plainrows.append(row)
+ continue
rowmap[rid] = rowmap.get(rid, row[self._stable_])
for parser in parsers:
parser(rowmap, inclusions, row, rid)
+ if not rowmap and plainrows:
+ return Rows(self.db, plainrows, rows.colnames)
return self._build_records_from_joined(
- rowmap, inclusions, rows.colnames)
+ rowmap, inclusions, rows.colnames
+ )
def _build_jparsers(self, belongs, one, many):
rv = []
for jname, jtable in belongs:
- rv.append(self._jbelong_parser(jname, jtable, self.db))
+ rv.append(self._jbelong_parser(self.db, jname, jtable))
for jname, jtable in one:
- rv.append(self._jone_parser(jname, jtable))
+ rv.append(self._jone_parser(self.db, jname, jtable))
for jname, jtable in many:
- rv.append(self._jmany_parser(jname, jtable))
+ rv.append(self._jmany_parser(self.db, jname, jtable))
return rv
def _build_lparsers(self, belongs, one, many):
rv = []
for jname, jtable in belongs:
- rv.append(self._lbelong_parser(jname, jtable, self.db))
+ rv.append(self._lbelong_parser(self.db, jname, jtable))
for jname, jtable in one:
- rv.append(self._lone_parser(jname, jtable))
+ rv.append(self._lone_parser(self.db, jname, jtable))
for jname, jtable in many:
- rv.append(self._lmany_parser(jname, jtable))
+ rv.append(self._lmany_parser(self.db, jname, jtable))
return rv
@staticmethod
- def _jbelong_parser(fieldname, tablename, db):
+ def _jbelong_parser(db, fieldname, tablename):
+ rmodel = db[tablename]._model_
+
def parser(rowmap, inclusions, row, rid):
- rowmap[rid][fieldname] = JoinedIDReference._from_record(
- row[tablename], db[tablename])
+ rowmap[rid][fieldname] = typed_row_reference_from_record(
+ row[tablename], rmodel
+ )
return parser
@staticmethod
- def _jone_parser(fieldname, tablename):
+ def _jone_parser(db, fieldname, tablename):
def parser(rowmap, inclusions, row, rid):
rowmap[rid][fieldname]._cached_resultset = row[tablename]
return parser
@staticmethod
- def _jmany_parser(fieldname, tablename):
+ def _jmany_parser(db, fieldname, tablename):
+ rmodel = db[tablename]._model_
+ pks = rmodel.primary_keys or ["id"]
+ ext = lambda row: tuple(row[pk] for pk in pks) if len(pks) > 1 else row[pks[0]]
+
def parser(rowmap, inclusions, row, rid):
- inclusions[rid][fieldname][row[tablename].id] = \
+ inclusions[rid][fieldname][ext(row[tablename])] = \
inclusions[rid][fieldname].get(
- row[tablename].id, row[tablename])
+ ext(row[tablename]), row[tablename]
+ )
return parser
@staticmethod
- def _lbelong_parser(fieldname, tablename, db):
+ def _lbelong_parser(db, fieldname, tablename):
+ rmodel = db[tablename]._model_
+ pks = rmodel.primary_keys or ["id"]
+ check = lambda row: all(row[pk] for pk in pks)
+
def parser(rowmap, inclusions, row, rid):
- if not row[tablename].id:
+ if not check(row[tablename]):
return
- rowmap[rid][fieldname] = JoinedIDReference._from_record(
- row[tablename], db[tablename])
+ rowmap[rid][fieldname] = typed_row_reference_from_record(
+ row[tablename], rmodel
+ )
return parser
@staticmethod
- def _lone_parser(fieldname, tablename):
+ def _lone_parser(db, fieldname, tablename):
+ rmodel = db[tablename]._model_
+ pks = rmodel.primary_keys or ["id"]
+ check = lambda row: all(row[pk] for pk in pks)
+
def parser(rowmap, inclusions, row, rid):
- if not row[tablename].id:
+ if not check(row[tablename]):
return
rowmap[rid][fieldname]._cached_resultset = row[tablename]
return parser
@staticmethod
- def _lmany_parser(fieldname, tablename):
+ def _lmany_parser(db, fieldname, tablename):
+ rmodel = db[tablename]._model_
+ pks = rmodel.primary_keys or ["id"]
+ ext = lambda row: tuple(row[pk] for pk in pks) if len(pks) > 1 else row[pks[0]]
+ check = lambda row: all(row[pk] for pk in pks)
+
def parser(rowmap, inclusions, row, rid):
- if not row[tablename].id:
+ if not check(row[tablename]):
return
- inclusions[rid][fieldname][row[tablename].id] = \
+ inclusions[rid][fieldname][ext(row[tablename])] = \
inclusions[rid][fieldname].get(
- row[tablename].id, row[tablename])
+ ext(row[tablename]), row[tablename]
+ )
return parser
class Row(_Row):
_as_dict_types_ = tuple(
[type(None)] + [int, float, bool, list, dict, str] +
- [datetime.datetime, datetime.date, datetime.time])
+ [datetime.datetime, datetime.date, datetime.time]
+ )
- def as_dict(self, datetime_to_str=False, custom_types=None):
+ def as_dict(self, datetime_to_str=False, custom_types=None, geo_coordinates=True):
rv = {}
for key, val in self.items():
if isinstance(val, Row):
val = val.as_dict()
- elif isinstance(val, _IDReference):
- val = int(val)
elif isinstance(val, decimal.Decimal):
val = float(val)
+ elif isinstance(val, GeoFieldWrapper) and geo_coordinates:
+ val = val.__json__()
elif not isinstance(val, self._as_dict_types_):
continue
rv[key] = val
return rv
def __getstate__(self):
- return self.as_dict()
+ return self.as_dict(geo_coordinates=False)
def __json__(self):
return self.as_dict()
@@ -843,7 +1275,7 @@ def __xml__(self, key=None, quote=True):
return xml_encode(self.as_dict(), key or 'row', quote)
def __str__(self):
-        return '<Row {}>'.format(self.as_dict())
+        return '<Row {}>'.format(self.as_dict(geo_coordinates=False))
def __repr__(self):
return str(self)
@@ -941,19 +1373,20 @@ def __str__(self):
return str(self.records)
-class JoinIterRows(_IterRows):
- def __init__(self, db, sql, fields, colnames):
+class IterRows(_IterRows):
+ def __init__(self, db, sql, fields, concrete_tables, colnames):
self.db = db
self.fields = fields
+ self.concrete_tables = concrete_tables
self.colnames = colnames
- self.fdata, self.tables = \
- self.db._adapter._parse_expand_colnames(fields)
+ self.fdata, self.tables = self.db._adapter._parse_expand_colnames(fields)
self.cursor = self.db._adapter.cursor
self.db._adapter.execute(sql)
self.db._adapter.lock_cursor(self.cursor)
self._head = None
self.last_item = None
self.last_item_id = None
+ self.compact = True
self.blob_decode = True
self.cacheable = False
self.sql = sql
@@ -962,11 +1395,24 @@ def __next__(self):
db_row = self.cursor.fetchone()
if db_row is None:
raise StopIteration
- return self.db._adapter._parse(
- db_row, self.fdata, self.tables, self.fields, self.colnames,
- self.blob_decode)
+ row = self.db._adapter._parse(
+ db_row,
+ self.fdata,
+ self.tables,
+ self.concrete_tables,
+ self.fields,
+ self.colnames,
+ self.blob_decode
+ )
+ if self.compact:
+ keys = list(row.keys())
+ if len(keys) == 1 and keys[0] != '_extra':
+ row = row[keys[0]]
+ return row
def __iter__(self):
+ if self._head:
+ yield self._head
try:
row = next(self)
while row is not None:
@@ -997,3 +1443,49 @@ def as_list(
else:
items = [item for item in self]
return items
+
+
+class JoinIterRows(IterRows):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.compact = False
+
+
+class TransactionOps(str, Enum):
+ insert = "insert"
+ update = "update"
+ delete = "delete"
+ save = "save"
+ destroy = "destroy"
+
+
+class TransactionOpContext:
+ __slots__ = ["values", "dbset", "return_value", "row", "changes"]
+
+ def __init__(
+ self,
+ values: Any = None,
+ dbset: Optional[Set] = None,
+ ret: Any = None,
+ row: Optional[Row] = None,
+ changes: Optional[sdict] = None
+ ):
+ self.values = values
+ self.dbset = dbset
+ self.return_value = ret
+ self.row = row
+ self.changes = changes
+
+
+class TransactionOp:
+ __slots__ = ["op_type", "table", "context"]
+
+ def __init__(
+ self,
+ op_type: TransactionOps,
+ table: Table,
+ context: TransactionOpContext
+ ):
+ self.op_type = op_type
+ self.table = table
+ self.context = context
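
A short sketch of the new geo field constructors added above (the `Place` model is illustrative): the `_kw` payload is expanded in `_make_field`, so the pyDAL type string is composed lazily, with `srid` falling back to the adapter's value (or 4326) and `dimension` to 2.

```python
from emmett.orm import Field, Model

class Place(Model):
    name = Field.string()
    location = Field.geometry("POINT", srid=4326)   # built as "geometry(POINT,4326,2)"
    area = Field.geography()                        # "geography(GEOMETRY,<srid>,2)"
```

Values of such fields come back wrapped in `GeoFieldWrapper`, which `Row.as_dict` now serializes through `__json__` unless called with `geo_coordinates=False` (as `__getstate__` does).
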
diff --git a/emmett/orm/transactions.py b/emmett/orm/transactions.py
index 808e26a0..f253edb1 100644
--- a/emmett/orm/transactions.py
+++ b/emmett/orm/transactions.py
@@ -14,6 +14,7 @@
"""
import uuid
+
from functools import wraps
@@ -45,6 +46,13 @@ class _transaction(callable_context_manager):
def __init__(self, adapter, lock_type=None):
self.adapter = adapter
self._lock_type = lock_type
+ self._ops = []
+
+ def _add_op(self, op):
+ self._ops.append(op)
+
+ def _add_ops(self, ops):
+ self._ops.extend(ops)
def _begin(self):
if self._lock_type:
@@ -53,7 +61,18 @@ def _begin(self):
self.adapter.begin()
def commit(self, begin=True):
+ for op in self._ops:
+ for callback in op.table._before_commit:
+ callback(op.op_type, op.context)
+ for callback in getattr(op.table, f"_before_commit_{op.op_type}"):
+ callback(op.context)
self.adapter.commit()
+ for op in self._ops:
+ for callback in op.table._after_commit:
+ callback(op.op_type, op.context)
+ for callback in getattr(op.table, f"_after_commit_{op.op_type}"):
+ callback(op.context)
+ self._ops.clear()
if begin:
self._begin()
@@ -87,6 +106,14 @@ def __init__(self, adapter, sid=None):
self.adapter = adapter
self.sid = sid or 's' + uuid.uuid4().hex
self.quoted_sid = self.adapter.dialect.quote(self.sid)
+ self._ops = []
+ self._parent = None
+
+ def _add_op(self, op):
+ self._ops.append(op)
+
+ def _add_ops(self, ops):
+ self._ops.extend(ops)
def _begin(self):
self.adapter.execute('SAVEPOINT %s;' % self.quoted_sid)
@@ -100,6 +127,7 @@ def rollback(self):
self.adapter.execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)
def __enter__(self):
+ self._parent = self.adapter.top_transaction()
self._begin()
self.adapter.push_transaction(self)
return self
@@ -111,6 +139,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
else:
try:
self.commit(begin=False)
+ if self._parent:
+ self._parent._add_ops(self._ops)
except Exception:
self.rollback()
raise
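
These queues power the commit-phase callbacks: ops collected during a transaction run their `before_commit*` hooks just before the SQL COMMIT and the `after_commit*` hooks right after, while savepoints hand their queued ops to the parent on successful exit, so the hooks fire once at the outermost commit. A condensed sketch mirroring the `CommitWatcher` model in the tests below (the `Order` model is illustrative):

```python
from emmett.orm import Field, Model, before_commit, after_commit
from emmett.orm.objects import TransactionOps

class Order(Model):
    total = Field.float()

    @before_commit
    def _watch_any(self, op_type, ctx):
        # fires for every queued op; ctx is a TransactionOpContext
        print("about to commit", op_type, ctx.return_value)

    @after_commit.operation(TransactionOps.save)
    def _watch_saved(self, ctx):
        # save-specific hook: ctx.row and ctx.changes carry the saved state
        print("saved", ctx.row, ctx.changes)
```
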
diff --git a/emmett/templating/templater.py b/emmett/templating/templater.py
index 09f2eb44..cf088534 100644
--- a/emmett/templating/templater.py
+++ b/emmett/templating/templater.py
@@ -46,16 +46,18 @@ def register_namespace(self, namespace: str, path: Optional[str] = None):
path = path or self.path
self._namespaces[namespace] = path
- def _get_namespace_path_elements(self, file_name: str) -> Tuple[str, str]:
+ def _get_namespace_path_elements(
+ self, file_name: str, path: Optional[str]
+ ) -> Tuple[str, str]:
if ":" in file_name:
namespace, file_name = file_name.split(":")
path = self._namespaces.get(namespace, self.path)
else:
- path = self.path
+ path = path or self.path
return path, file_name
- def _preload(self, file_name: str):
- path, file_name = self._get_namespace_path_elements(file_name)
+ def _preload(self, file_name: str, path: Optional[str] = None):
+ path, file_name = self._get_namespace_path_elements(file_name, path)
file_extension = os.path.splitext(file_name)[1]
return reduce(
lambda args, loader: loader(args[0], args[1]),
@@ -63,5 +65,5 @@ def _preload(self, file_name: str):
(path, file_name)
)
- def _no_preload(self, file_name):
- return self._get_namespace_path_elements(file_name)
+ def _no_preload(self, file_name: str, path: Optional[str] = None):
+ return self._get_namespace_path_elements(file_name, path)
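
For context, the threaded `path` lets callers pin a base directory without losing namespaced lookups. A sketch using the internal methods touched above (paths are illustrative):

```python
templater.register_namespace("admin", path="/srv/app/admin/templates")

# "admin:dashboard.html" resolves inside the registered namespace;
# plain names use the explicit path when given, else templater.path
templater._preload("admin:dashboard.html")
templater._preload("index.html", path="/srv/app/other/templates")
```
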
diff --git a/emmett/tools/auth/forms.py b/emmett/tools/auth/forms.py
index be62298d..5a161b5d 100644
--- a/emmett/tools/auth/forms.py
+++ b/emmett/tools/auth/forms.py
@@ -148,8 +148,8 @@ def profile_form(auth, fields, **kwargs):
'keepvalues': True}
opts.update(**kwargs)
return ModelForm(
- auth.models['user'].table,
- record=auth.user,
+ auth.models['user'],
+ record_id=auth.user.id,
fields=fields,
upload=auth.ext.exposer.url('download'),
**opts
diff --git a/emmett/tools/auth/models.py b/emmett/tools/auth/models.py
index de0f0663..fe665f93 100644
--- a/emmett/tools/auth/models.py
+++ b/emmett/tools/auth/models.py
@@ -70,6 +70,7 @@ def _define_(self):
self.__super_method('define_computations')()
self.__super_method('define_callbacks')()
self.__super_method('define_scopes')()
+ self.__super_method('define_query_helpers')()
self.__super_method('define_form_utils')()
self.__define_authform_utils()
self.setup()
diff --git a/emmett/validators/__init__.py b/emmett/validators/__init__.py
index 6ca319f8..2e6eca14 100644
--- a/emmett/validators/__init__.py
+++ b/emmett/validators/__init__.py
@@ -127,15 +127,20 @@ def parse_is_list(self, data, message=None):
) if validator else None
def parse_reference(self, field):
- ref_table = None
- multiple = None
+ ref_table, ref_field, multiple = None, None, None
if field.type.startswith('reference'):
multiple = False
elif field.type.startswith('list:reference'):
multiple = True
if multiple is not None:
ref_table = field.type.split(' ')[1]
- return ref_table, multiple
+ model = field.table._model_
+ #: can't support (yet?) multi pks
+ if model._belongs_ref_[field.name].compound:
+ ref_table = None
+ else:
+ ref_field = model._belongs_ref_[field.name].fk
+ return ref_table, ref_field, multiple
def __call__(self, field, data):
validators = []
@@ -195,7 +200,7 @@ def __call__(self, field, data):
#: allows {'in': {'dbset': lambda db: db.where(query)}}
_dbset = _in.get('dbset')
if callable(_dbset):
- ref_table, multiple = self.parse_reference(field)
+ ref_table, ref_field, multiple = self.parse_reference(field)
if ref_table:
opt_keys = [key for key in list(_in) if key != 'dbset']
for key in opt_keys:
@@ -204,6 +209,7 @@ def __call__(self, field, data):
inDB(
field.db,
ref_table,
+ ref_field,
dbset=_dbset,
multiple=multiple,
message=message,
@@ -289,11 +295,17 @@ def __call__(self, field, data):
validators.append(Not(self(field, data['not']), message=message))
#: insert presence/empty validation if needed
if presence:
- ref_table, multiple = self.parse_reference(field)
+ ref_table, ref_field, multiple = self.parse_reference(field)
if ref_table:
if not _dbset:
validators.append(
- inDB(field.db, ref_table, multiple=multiple, message=message)
+ inDB(
+ field.db,
+ ref_table,
+ ref_field,
+ multiple=multiple,
+ message=message
+ )
)
else:
validators.insert(0, isntEmpty(message=message))
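
With the extra return value, `inDB` receives the concrete foreign field instead of assuming `id`, so references to models with custom-typed or renamed primary keys validate against the right column (compound keys still skip the lookup, per the comment above). A hypothetical resulting validator for a field referencing a `tokens` table keyed on `code`:

```python
inDB(field.db, "tokens", "code", multiple=False, message="invalid reference")
```
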
diff --git a/emmett/validators/consist.py b/emmett/validators/consist.py
index 56c64eec..ea98432d 100644
--- a/emmett/validators/consist.py
+++ b/emmett/validators/consist.py
@@ -96,6 +96,8 @@ def __call__(self, value):
return super().__call__(value.lower() if value else value)
def check(self, value):
+ if isinstance(value, time):
+ return value, None
val = self.rule.match(value)
try:
(h, m, s) = (int(val.group('h')), 0, 0)
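
This makes the time validator a pass-through for values that were already parsed, e.g. when `save` re-validates a row read back from the database. Sketch, assuming `validator` is an instance of the time validator this module defines:

```python
from datetime import time

value, error = validator.check(time(10, 30))
assert value == time(10, 30) and error is None   # no regex round-trip
```
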
diff --git a/emmett/validators/process.py b/emmett/validators/process.py
index a8d0b55b..ab6663cd 100644
--- a/emmett/validators/process.py
+++ b/emmett/validators/process.py
@@ -135,4 +135,7 @@ def __init__(
self.salt = salt
def __call__(self, value):
- return LazyCrypt(self, value), None
+ crypt = LazyCrypt(self, value)
+ if isinstance(value, LazyCrypt) and value == crypt:
+ return value, None
+ return crypt, None
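
The added guard makes password validation idempotent: re-validating a value that is already a `LazyCrypt` (and hashes to the same digest) returns it unchanged instead of wrapping it again. Sketch, assuming `crypt` is a configured `Crypt` validator instance:

```python
hashed, _ = crypt("secret")   # first pass wraps the value in LazyCrypt
again, _ = crypt(hashed)      # second pass short-circuits on equality
assert again is hashed
```
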
diff --git a/pyproject.toml b/pyproject.toml
index 4295c221..76142541 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "Emmett"
-version = "2.3.2"
+version = "2.4.0"
description = "The web framework for inventors"
authors = ["Giovanni Barillari <gi0baro@d4net.org>"]
license = "BSD-3-Clause"
@@ -22,19 +22,18 @@ classifiers = [
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
"Topic :: Internet :: WWW/HTTP :: Dynamic Content",
"Topic :: Software Development :: Libraries :: Python Modules"
]
packages = [
- {include = "emmett"},
+ {include = "emmett/**/*.*", format = "sdist" },
{include = "tests", format = "sdist"}
]
include = [
"CHANGES.md",
"LICENSE",
- "emmett/orm/migrations/migration.tmpl",
- "emmett/assets/**/*",
"docs/**/*"
]
@@ -44,29 +43,30 @@ emmett = "emmett.cli:main"
[tool.poetry.dependencies]
python = "^3.7"
click = ">=6.0"
-h11 = "~0.10.0"
+h11 = "~0.12.0"
h2 = ">= 3.2.0, < 4.1.0"
pendulum = "~2.1.2"
pyaes = "~1.6.1"
pyDAL = "17.3"
python-rapidjson = "^1.0"
pyyaml = "^5.4"
-renoir = "^1.3"
+renoir = "^1.5"
severus = "^1.1"
-uvicorn = "0.14.0"
-websockets = "^9.1"
+uvicorn = "0.16.0"
+websockets = "^10.0"
-httptools = { version = "~0.2.0", markers = "sys_platform != 'win32'" }
-uvloop = { version = "~0.15.3", markers = "sys_platform != 'win32'" }
+httptools = { version = "~0.3.0", markers = "sys_platform != 'win32'" }
+uvloop = { version = "~0.16.0", markers = "sys_platform != 'win32'" }
-orjson = { version = "~3.5.1", optional = true }
-emmett-crypto = { version = "^0.1.0", optional = true }
+orjson = { version = "~3.6.5", optional = true }
+emmett-crypto = { version = "^0.2.0", optional = true }
[tool.poetry.dev-dependencies]
ipaddress = "^1.0"
pylint = "^2.4.4"
-pytest = "^5.3"
-pytest-asyncio = "^0.10"
+pytest = "^6.2"
+pytest-asyncio = "^0.15"
+psycopg2-binary = "^2.9.3"
[tool.poetry.extras]
orjson = ["orjson"]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 32973a63..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-click==8.0.1
-h11==0.10.0
-h2==4.0.0
-httptools==0.2.0; sys_platform != "win32"
-pendulum==2.1.2
-pyaes==1.6.1
-pydal==17.3
-python-rapidjson==1.4
-pyyaml==5.4.1
-renoir==1.4.0
-severus==1.1.1
-uvicorn==0.14.0
-uvloop==0.15.3; sys_platform != "win32"
-websockets==9.1
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 346bfa4a..b4debf9b 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -29,6 +29,7 @@ class Thing(Model):
def app():
rv = App(__name__)
rv.config.mailer.sender = 'nina@massivedynamics.com'
+ rv.config.mailer.suppress = True
rv.config.auth.single_template = True
rv.config.auth.hmac_key = "foobar"
rv.pipeline = [SessionManager.cookies('foobar')]
diff --git a/tests/test_migrations.py b/tests/test_migrations.py
index 130a9a29..9b774177 100644
--- a/tests/test_migrations.py
+++ b/tests/test_migrations.py
@@ -6,10 +6,12 @@
Test Emmett migrations engine
"""
+import uuid
+
import pytest
from emmett import App
-from emmett.orm import Database, Model, Field
+from emmett.orm import Database, Model, Field, belongs_to, refers_to
from emmett.orm.migrations.engine import MetaEngine, Engine
from emmett.orm.migrations.generation import MetaData, Comparator
@@ -196,6 +198,7 @@ def test_step_four_alter_table(app):
class StepFiveThing(Model):
name = Field()
+ code = Field(unique=True)
value = Field.int()
created_at = Field.datetime()
@@ -217,6 +220,7 @@ class StepFiveThingEdit(StepFiveThing):
_step_five_sql_before = [
+ 'CREATE UNIQUE INDEX "step_five_things_widx__code_unique" ON "step_five_things" ("code");',
'CREATE INDEX "step_five_things_widx__name" ON "step_five_things" ("name");',
'CREATE INDEX "step_five_things_widx__name_value" ON "step_five_things" ("name","value");'
]
@@ -241,3 +245,88 @@ def test_step_five_indexes(app):
for op in ops2.ops:
sql = _make_sql(db, op)
assert sql in _step_five_sql_after
+
+
+class StepSixThing(Model):
+ id = Field()
+ name = Field()
+
+ default_values = {"id": lambda: uuid.uuid4()}
+
+
+class StepSixRelate(Model):
+ belongs_to('step_six_thing')
+ name = Field()
+
+
+_step_six_sql_t1 = """CREATE TABLE "step_six_things"(
+ "id" CHAR(512) PRIMARY KEY,
+ "name" CHAR(512)
+);"""
+_step_six_sql_t2 = """CREATE TABLE "step_six_relates"(
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "name" CHAR(512),
+ "step_six_thing" CHAR(512)
+);"""
+_step_six_sql_fk = "".join([
+ 'ALTER TABLE "step_six_relates" ADD CONSTRAINT ',
+ '"step_six_relates_ecnt__fk__stepsixthings_stepsixthing" FOREIGN KEY ',
+ '("step_six_thing") REFERENCES "step_six_things"("id") ON DELETE CASCADE;'
+])
+
+
+def test_step_six_id_types(app):
+ db = Database(app, auto_migrate=False)
+ db.define_models(StepSixThing, StepSixRelate)
+ ops = _make_ops(db)
+
+ assert _make_sql(db, ops.ops[0]) == _step_six_sql_t1
+ assert _make_sql(db, ops.ops[1]) == _step_six_sql_t2
+ assert _make_sql(db, ops.ops[2]) == _step_six_sql_fk
+
+
+class StepSevenThing(Model):
+ primary_keys = ['foo', 'bar']
+
+ foo = Field()
+ bar = Field()
+
+
+class StepSevenRelate(Model):
+ refers_to({'foo': 'StepSevenThing.foo'}, {'bar': 'StepSevenThing.bar'})
+ name = Field()
+
+ foreign_keys = {
+ "test": {
+ "fields": ["foo", "bar"],
+ "foreign_fields": ["foo", "bar"]
+ }
+ }
+
+
+_step_seven_sql_t1 = """CREATE TABLE "step_seven_things"(
+ "foo" CHAR(512),
+ "bar" CHAR(512),
+ PRIMARY KEY("foo", "bar")
+);"""
+_step_seven_sql_t2 = """CREATE TABLE "step_seven_relates"(
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "name" CHAR(512),
+ "foo" CHAR(512),
+ "bar" CHAR(512)
+);"""
+_step_seven_sql_fk = "".join([
+ 'ALTER TABLE "step_seven_relates" ADD CONSTRAINT ',
+ '"step_seven_relates_ecnt__fk__stepseventhings_foo_bar" FOREIGN KEY ("foo","bar") ',
+ 'REFERENCES "step_seven_things"("foo","bar") ON DELETE CASCADE;',
+])
+
+
+def test_step_seven_composed_pks(app):
+ db = Database(app, auto_migrate=False)
+ db.define_models(StepSevenThing, StepSevenRelate)
+ ops = _make_ops(db)
+
+ assert _make_sql(db, ops.ops[0]) == _step_seven_sql_t1
+ assert _make_sql(db, ops.ops[1]) == _step_seven_sql_t2
+ assert _make_sql(db, ops.ops[2]) == _step_seven_sql_fk
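
The composite-key models above also exercise the reworked `Model.get` from the ORM diff, which accepts primary-key values in several equivalent shapes:

```python
# equivalent lookups for StepSevenThing (primary_keys = ['foo', 'bar'])
StepSevenThing.get("a", "b")                  # positional, in primary-key order
StepSevenThing.get(("a", "b"))                # a single tuple is unpacked
StepSevenThing.get({"foo": "a", "bar": "b"})  # a dict maps fields directly
StepSevenThing.get(foo="a", bar="b")
```
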
diff --git a/tests/test_orm.py b/tests/test_orm.py
index d7d6ff00..c31b3691 100644
--- a/tests/test_orm.py
+++ b/tests/test_orm.py
@@ -9,22 +9,52 @@
import pytest
from datetime import datetime, timedelta
+from uuid import uuid4
+
from pydal.objects import Table
from pydal import Field as _Field
-from emmett import App, sdict
+from emmett import App, sdict, now
from emmett.orm import (
Database, Field, Model,
compute,
before_insert, after_insert,
before_update, after_update,
before_delete, after_delete,
+ before_save, after_save,
+ before_destroy, after_destroy,
+ before_commit, after_commit,
rowattr, rowmethod,
has_one, has_many, belongs_to,
scope
)
+from emmett.orm.migrations.utils import generate_runtime_migration
+from emmett.orm.objects import TransactionOps
+from emmett.orm.errors import MissingFieldsForCompute
from emmett.validators import isntEmpty, hasLength
+CALLBACK_OPS = {
+ "before_insert": [],
+ "before_update": [],
+ "before_delete": [],
+    "before_save": [],
+ "before_destroy": [],
+ "after_insert": [],
+ "after_update": [],
+ "after_delete": [],
+    "after_save": [],
+ "after_destroy": []
+}
+COMMIT_CALLBACKS = {
+ "all": [],
+ "insert": [],
+ "update": [],
+ "delete": [],
+ "save": [],
+ "destroy": []
+}
+
+
def _represent_f(value):
return value
@@ -33,32 +63,19 @@ def _widget_f(field, value):
return value
-def _call_bi(fields):
- return fields[:-1]
-
-
-def _call_ai(fields, id):
- return fields[:-1], id + 1
-
-
-def _call_u(set, fields):
- return set, fields[:-1]
-
-
-def _call_d(set):
- return set
-
-
class Stuff(Model):
a = Field.string()
b = Field()
price = Field.float()
quantity = Field.int()
total = Field.float()
+ total_watch = Field.float()
invisible = Field()
validation = {
- "a": {'presence': True}
+ "a": {'presence': True},
+ "total": {"allow": "empty"},
+ "total_watch": {"allow": "empty"}
}
fields_rw = {
@@ -85,36 +102,37 @@ class Stuff(Model):
"a": _widget_f
}
- # def setup(self):
- # self.table.b.requires = notInDb(self.db, self.table.b)
-
@compute('total')
def eval_total(self, row):
return row.price * row.quantity
+ @compute('total_watch', watch=['price', 'quantity'])
+ def eval_total_watch(self, row):
+ return row.price * row.quantity
+
@before_insert
def bi(self, fields):
- return _call_bi(fields)
+ CALLBACK_OPS['before_insert'].append(fields)
@after_insert
def ai(self, fields, id):
- return _call_ai(fields, id)
+ CALLBACK_OPS['after_insert'].append((fields, id))
@before_update
def bu(self, set, fields):
- return _call_u(set, fields)
+ CALLBACK_OPS['before_update'].append((set, fields))
@after_update
def au(self, set, fields):
- return _call_u(set, fields)
+ CALLBACK_OPS['after_update'].append((set, fields))
@before_delete
def bd(self, set):
- return _call_d(set)
+ CALLBACK_OPS['before_delete'].append(set)
@after_delete
def ad(self, set):
- return _call_d(set)
+ CALLBACK_OPS['after_delete'].append(set)
@rowattr('totalv')
def eval_total_v(self, row):
@@ -288,20 +306,135 @@ def filter_status(self, *statuses):
return self.status.belongs(*[self.STATUS[v] for v in statuses])
+class Product(Model):
+ name = Field.string()
+ price = Field.float(default=0.0)
+
+
+class Cart(Model):
+ has_many({"elements": "CartElement"})
+
+ updated_at = Field.datetime(default=now, update=now)
+ total_denorm = Field.float(default=0.0)
+ revision = Field.string(default=lambda: uuid4().hex, update=lambda: uuid4().hex)
+
+ def _sum_elements(self, row):
+ summable = (CartElement.quantity.cast("float") * Product.price).sum()
+ sum = row.elements.join("product").select(summable).first()
+ return sum[summable] or 0.0
+
+ @before_save
+ def _rebuild_total(self, row):
+ row.total_denorm = self._sum_elements(row)
+
+ @rowattr("total")
+ def _compute_total(self, row):
+ return self._sum_elements(row)
+
+
+class CartElement(Model):
+ belongs_to("product", "cart")
+
+ updated_at = Field.datetime(default=now, update=now)
+ quantity = Field.int(default=1)
+ price_denorm = Field.float(default=0.0)
+
+ @before_save
+ def _rebuild_price(self, row):
+ row.price_denorm = row.quantity * row.product.price
+
+ @after_save
+ def _refresh_cart_after_update(self, row):
+ row.cart.save()
+
+ @before_destroy
+ def _undo_quantity_on_removal(self, row):
+ row.quantity = 0
+ row.price_denorm = 0
+
+ @after_destroy
+ def _refresh_cart_after_removal(self, row):
+ row.cart.save()
+
+ @rowattr("price")
+ def _compute_price(self, row):
+ return row.quantity * row.product.price
+
+
+class CustomPKType(Model):
+ id = Field.string()
+
+
+class CustomPKName(Model):
+ primary_keys = ["name"]
+ name = Field.string()
+
+
+class CustomPKMulti(Model):
+ primary_keys = ["first_name", "last_name"]
+ first_name = Field.string()
+ last_name = Field.string()
+
+
+class CommitWatcher(Model):
+ foo = Field.string()
+ created_at = Field.datetime(default=now)
+ updated_at = Field.datetime(default=now, update=now)
+
+ @before_commit
+ def _commit_watch_before(self, op_type, ctx):
+ COMMIT_CALLBACKS["all"].append(("before", op_type, ctx))
+
+ @after_commit
+ def _commit_watch_after(self, op_type, ctx):
+ COMMIT_CALLBACKS["all"].append(("after", op_type, ctx))
+
+ @before_commit.operation(TransactionOps.save)
+ def _commit_watch_before_save(self, ctx):
+ COMMIT_CALLBACKS["save"].append(("before", ctx))
+
+ @after_commit.operation(TransactionOps.save)
+ def _commit_watch_after_save(self, ctx):
+ COMMIT_CALLBACKS["save"].append(("after", ctx))
+
+ @before_commit.operation(TransactionOps.destroy)
+ def _commit_watch_before_destroy(self, ctx):
+ COMMIT_CALLBACKS["destroy"].append(("before", ctx))
+
+ @after_commit.operation(TransactionOps.destroy)
+ def _commit_watch_after_destroy(self, ctx):
+ COMMIT_CALLBACKS["destroy"].append(("after", ctx))
+
+
@pytest.fixture(scope='module')
-def db():
+def _db():
app = App(__name__)
db = Database(
app, config=sdict(
- uri='sqlite://dal.db', auto_connect=True, auto_migrate=True))
- db.define_models([
- Stuff, Person, Thing, Feature, Price, Doctor, Patient, Appointment,
- User, Organization, Membership, House, Mouse, NeedSplit, Zoo, Animal,
- Elephant, Dog, Subscription
- ])
+ uri=f'sqlite://{uuid4().hex}.db',
+ auto_connect=True
+ )
+ )
+ db.define_models(
+ Stuff, Person, Thing, Feature, Price, Dog, Subscription,
+ Doctor, Patient, Appointment,
+ User, Organization, Membership,
+ House, Mouse, NeedSplit, Zoo, Animal, Elephant,
+ Product, Cart, CartElement,
+ CustomPKType, CustomPKName, CustomPKMulti,
+ CommitWatcher
+ )
return db
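+# Per-test fixture: migrate up before each test and down afterwards, so every
+# test starts from an empty schema on the module-scoped database.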
+@pytest.fixture(scope='function')
+def db(_db):
+ migration = generate_runtime_migration(_db)
+ migration.up()
+ yield _db
+ migration.down()
+
+
def test_db_instance(db):
assert isinstance(db, Database)
@@ -354,40 +487,434 @@ def test_widgets(db):
def test_computations(db):
+ #: no watch
row = sdict(price=12.95, quantity=3)
rv = db.Stuff.total.compute(row)
assert rv == 12.95 * 3
+ #: watch fulfill
+ row = sdict(price=12.95, quantity=3)
+ rv = db.Stuff.total_watch.compute(row)
+ assert rv == 12.95 * 3
+ #: watch missing field
+ row = sdict(price=12.95)
+ with pytest.raises(MissingFieldsForCompute):
+ db.Stuff.total_watch.compute(row)
+ #: update flow
+ res = Stuff.create(a="foo", price=12.95, quantity=1)
+ row = Stuff.get(res.id)
+ with pytest.raises(MissingFieldsForCompute):
+ row.update_record(quantity=2)
+ row.update_record(price=row.price, quantity=2)
+ assert row.total == row.price * 2
+ assert row.total_watch == row.price * 2
def test_callbacks(db):
- fields = ["a", "b", "c"]
+ fields = {"a": 1, "b": 2, "c": 3}
id = 12
- rv = db.Stuff._before_insert[-1](fields)
- assert rv == fields[:-1]
- rv = db.Stuff._after_insert[-1](fields, id)
- assert rv[0] == fields[:-1] and rv[1] == id + 1
+ db.Stuff._before_insert[-1](fields)
+ assert CALLBACK_OPS["before_insert"][-1] == fields
+ db.Stuff._after_insert[-1](fields, id)
+ res = CALLBACK_OPS["after_insert"][-1]
+ assert res[0] == fields and res[1] == id
set = {"a": "b"}
- rv = db.Stuff._before_update[-1](set, fields)
- assert rv[0] == set and rv[1] == fields[:-1]
- rv = db.Stuff._after_update[-1](set, fields)
- assert rv[0] == set and rv[1] == fields[:-1]
- rv = db.Stuff._before_delete[-1](set)
- assert rv == set
- rv = db.Stuff._after_delete[-1](set)
- assert rv == set
+ db.Stuff._before_update[-1](set, fields)
+ res = CALLBACK_OPS["before_update"][-1]
+ assert res[0] == set and res[1] == fields
+ db.Stuff._after_update[-1](set, fields)
+ res = CALLBACK_OPS["after_update"][-1]
+ assert res[0] == set and res[1] == fields
+ db.Stuff._before_delete[-1](set)
+ res = CALLBACK_OPS["before_delete"][-1]
+ assert res == set
+ db.Stuff._after_delete[-1](set)
+ res = CALLBACK_OPS["after_delete"][-1]
+ assert res == set
+
+
+def test_save(db):
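+ # saving an element cascades: after_save calls cart.save(), which
+ # rebuilds total_denorm and rotates the revision token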
+ p1 = db.Product.insert(name="foo", price=2.99)
+ p2 = db.Product.insert(name="bar", price=7.49)
+ cart = db.Cart.insert()
+ assert cart.total == 0
+ assert cart.total_denorm == 0
+
+ cart_rev = cart.revision
+ item = CartElement.new(cart=cart, product=p1)
+ item.save()
+ assert item.price == p1.price
+ assert item.price_denorm == p1.price
+ cart = Cart.get(cart.id)
+ assert cart.total == p1.price
+ assert cart.total_denorm == p1.price
+ assert cart.revision != cart_rev
+
+ cart_rev = cart.revision
+ item = CartElement.new(cart=cart, product=p2, quantity=3)
+ item.save()
+ assert item.price == p2.price * 3
+ assert item.price_denorm == p2.price * 3
+ cart = Cart.get(cart.id)
+ assert cart.total == p1.price + p2.price * 3
+ assert cart.total_denorm == p1.price + p2.price * 3
+ assert cart.revision != cart_rev
+
+
+def test_destroy(db):
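+ # destroying an element cascades a cart.save() as well, so totals and
+ # the revision token reflect the removal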
+ p1 = db.Product.insert(name="foo", price=2.99)
+ p2 = db.Product.insert(name="bar", price=7.49)
+ cart = db.Cart.insert()
+
+ item = CartElement.new(cart=cart, product=p1)
+ item.save()
+ item = CartElement.new(cart=cart, product=p2, quantity=3)
+ item.save()
+
+ cart = Cart.get(cart.id)
+ cart_rev = cart.revision
+
+ item.destroy()
+ assert not item.id
+ assert not item.price
+ assert not item.price_denorm
+ cart = Cart.get(cart.id)
+ assert cart.total == p1.price
+ assert cart.total_denorm == p1.price
+ assert cart.revision != cart_rev
+
+
+def test_commit_callbacks(db):
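+ # commit-phase callbacks must only fire on db.commit(), never at
+ # operation time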
+ #: insert
+ row = db.CommitWatcher.insert(foo="test1")
+ assert not COMMIT_CALLBACKS["all"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 2
+
+ before, after = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before
+ assert order == "before"
+ assert op_type == TransactionOps.insert
+ assert ctx.values.foo == "test1"
+ assert ctx.return_value == row.id
+
+ order, op_type, ctx = after
+ assert order == "after"
+ assert op_type == TransactionOps.insert
+ assert ctx.values.foo == "test1"
+ assert ctx.return_value == row.id
+
+ COMMIT_CALLBACKS["all"].clear()
+
+ #: update
+ row.update_record(foo="test1a")
+ assert not COMMIT_CALLBACKS["all"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 2
+
+ before, after = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before
+ assert order == "before"
+ assert op_type == TransactionOps.update
+ assert ctx.dbset
+ assert ctx.values.foo == "test1a"
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = after
+ assert order == "after"
+ assert op_type == TransactionOps.update
+ assert ctx.dbset
+ assert ctx.values.foo == "test1a"
+ assert ctx.return_value == 1
+
+ COMMIT_CALLBACKS["all"].clear()
+
+ #: delete
+ row.delete_record()
+ assert not COMMIT_CALLBACKS["all"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 2
+
+ before, after = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before
+ assert order == "before"
+ assert op_type == TransactionOps.delete
+ assert ctx.dbset
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = after
+ assert order == "after"
+ assert op_type == TransactionOps.delete
+ assert ctx.dbset
+ assert ctx.return_value == 1
+
+ COMMIT_CALLBACKS["all"].clear()
+
+ #: save:insert
+ row = CommitWatcher.new(foo="test2")
+ row.save()
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["save"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 4
+ assert len(COMMIT_CALLBACKS["save"]) == 2
+
+ before_ins, before_save, after_ins, after_save = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before_ins
+ assert order == "before"
+ assert op_type == TransactionOps.insert
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+
+ order, op_type, ctx = after_ins
+ assert order == "after"
+ assert op_type == TransactionOps.insert
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+
+ order, op_type, ctx = before_save
+ assert order == "before"
+ assert op_type == TransactionOps.save
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+ assert ctx.row.id == row.id
+ assert "id" in ctx.changes
+
+ order, op_type, ctx = after_save
+ assert order == "after"
+ assert op_type == TransactionOps.save
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+ assert ctx.row.id == row.id
+ assert "id" in ctx.changes
+
+ before_save, after_save = COMMIT_CALLBACKS["save"]
+
+ order, ctx = before_save
+ assert order == "before"
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+ assert ctx.row.id == row.id
+ assert "id" in ctx.changes
+
+ order, ctx = after_save
+ assert order == "after"
+ assert ctx.values.foo == "test2"
+ assert ctx.return_value == row.id
+ assert ctx.row.id == row.id
+ assert "id" in ctx.changes
+
+ COMMIT_CALLBACKS["all"].clear()
+ COMMIT_CALLBACKS["save"].clear()
+
+ #: save:update
+ row.foo = "test2a"
+ row.save()
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["save"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 4
+ assert len(COMMIT_CALLBACKS["save"]) == 2
+
+ before_upd, before_save, after_upd, after_save = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before_upd
+ assert order == "before"
+ assert op_type == TransactionOps.update
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = after_upd
+ assert order == "after"
+ assert op_type == TransactionOps.update
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = before_save
+ assert order == "before"
+ assert op_type == TransactionOps.save
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+ assert set(ctx.changes.keys()).issubset({"foo", "updated_at"})
+
+ order, op_type, ctx = after_save
+ assert order == "after"
+ assert op_type == TransactionOps.save
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+ assert set(ctx.changes.keys()).issubset({"foo", "updated_at"})
+
+ before_save, after_save = COMMIT_CALLBACKS["save"]
+
+ order, ctx = before_save
+ assert order == "before"
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+ assert set(ctx.changes.keys()).issubset({"foo", "updated_at"})
+
+ order, ctx = after_save
+ assert order == "after"
+ assert ctx.dbset
+ assert ctx.values.foo == "test2a"
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+ assert set(ctx.changes.keys()).issubset({"foo", "updated_at"})
+
+ COMMIT_CALLBACKS["all"].clear()
+ COMMIT_CALLBACKS["save"].clear()
+
+ #: destroy
+ row.destroy()
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["destroy"]
+ db.commit()
+
+ assert len(COMMIT_CALLBACKS["all"]) == 4
+ assert len(COMMIT_CALLBACKS["destroy"]) == 2
+
+ before_del, before_destroy, after_del, after_destroy = COMMIT_CALLBACKS["all"]
+
+ order, op_type, ctx = before_del
+ assert order == "before"
+ assert op_type == TransactionOps.delete
+ assert ctx.dbset
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = after_del
+ assert order == "after"
+ assert op_type == TransactionOps.delete
+ assert ctx.dbset
+ assert ctx.return_value == 1
+
+ order, op_type, ctx = before_destroy
+ assert order == "before"
+ assert op_type == TransactionOps.destroy
+ assert ctx.dbset
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+
+ order, op_type, ctx = after_destroy
+ assert order == "after"
+ assert op_type == TransactionOps.destroy
+ assert ctx.dbset
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+
+ before_destroy, after_destroy = COMMIT_CALLBACKS["destroy"]
+
+ order, ctx = before_destroy
+ assert order == "before"
+ assert ctx.dbset
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+
+ order, ctx = after_destroy
+ assert order == "after"
+ assert ctx.dbset
+ assert ctx.return_value == 1
+ assert ctx.row.id == row.id
+
+ COMMIT_CALLBACKS["all"].clear()
+ COMMIT_CALLBACKS["destroy"].clear()
+
+
+def test_callbacks_skip(db):
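+ # skip_callbacks=True is expected to bypass both row-level and
+ # commit-phase callbacks for every operation flavour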
+ for stack in CALLBACK_OPS.values():
+ stack.clear()
+ for stack in COMMIT_CALLBACKS.values():
+ stack.clear()
+
+ #: insert
+ row = db.CommitWatcher.insert(foo="test1", skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_insert"]
+ assert not CALLBACK_OPS["after_insert"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["insert"]
+
+ #: update
+ row.update_record(foo="test1a", skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_update"]
+ assert not CALLBACK_OPS["after_update"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["update"]
+
+ #: delete
+ row.delete_record(skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_delete"]
+ assert not CALLBACK_OPS["after_delete"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["delete"]
+
+ #: save:insert
+ row = CommitWatcher.new(foo="test2")
+ row.save(skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_save"]
+ assert not CALLBACK_OPS["after_save"]
+ assert not CALLBACK_OPS["before_insert"]
+ assert not CALLBACK_OPS["after_insert"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["save"]
+ assert not COMMIT_CALLBACKS["insert"]
+
+ #: save:update
+ row.foo = "test2a"
+ row.save(skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_save"]
+ assert not CALLBACK_OPS["after_save"]
+ assert not CALLBACK_OPS["before_update"]
+ assert not CALLBACK_OPS["after_update"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["save"]
+ assert not COMMIT_CALLBACKS["update"]
+
+ #: destroy
+ row.destroy(skip_callbacks=True)
+ db.commit()
+ assert not CALLBACK_OPS["before_destroy"]
+ assert not CALLBACK_OPS["after_destroy"]
+ assert not CALLBACK_OPS["before_delete"]
+ assert not CALLBACK_OPS["after_delete"]
+ assert not COMMIT_CALLBACKS["all"]
+ assert not COMMIT_CALLBACKS["destroy"]
+ assert not COMMIT_CALLBACKS["delete"]
def test_rowattrs(db):
db.Stuff._before_insert = []
db.Stuff._after_insert = []
- db.Stuff.insert(a="foo", b="bar", price=12.95, quantity=3)
+ res = db.Stuff.insert(a="foo", b="bar", price=12.95, quantity=3)
db.commit()
- row = db(db.Stuff).select().first()
+ row = Stuff.get(res)
assert row.totalv == 12.95 * 3
def test_rowmethods(db):
- row = db(db.Stuff).select().first()
+ db.Stuff._before_insert = []
+ db.Stuff._after_insert = []
+ res = db.Stuff.insert(a="foo", b="bar", price=12.95, quantity=3)
+ db.commit()
+ row = Stuff.get(res)
assert row.totalm() == 12.95 * 3
@@ -398,25 +925,40 @@ def test_modelmethods(db):
def test_relations(db):
- p = db.Person.insert(name="Giovanni", age=25)
- t = db.Thing.insert(name="apple", color="red", person=p)
- f = db.Feature.insert(name="tasty", thing=t)
+ p1 = db.Person.insert(name="Giovanni", age=25)
+ p2 = db.Person.insert(name="Giorgio", age=30)
+ t1 = db.Thing.insert(name="apple", color="red", person=p1)
+ t2 = db.Thing.insert(name="apple", color="green", person=p1)
+ f = db.Feature.insert(name="tasty", thing=t1)
db.Price.insert(value=5, feature=f)
- p = db.Person(name="Giovanni")
#: belongs, has_one, has_many
- t = p.things()
+ p1 = db.Person(name="Giovanni")
+ p2 = db.Person(name="Giorgio")
+ assert p1.things.count() == 2
+ assert p2.things.count() == 0
+ t = p1.things.where(lambda t: t.color == "red").select()
assert len(t) == 1
- assert t[0].name == "apple" and t[0].color == "red" and \
- t[0].person.id == p.id
- f = p.things()[0].features()
+ assert t[0].name == "apple" and t[0].color == "red" and t[0].person.id == p1.id
+ f = p1.things()[0].features()
assert len(f) == 1
assert f[0].name == "tasty" and f[0].thing.id == t[0].id and \
- f[0].thing.person.id == p.id
- m = p.things()[0].features()[0].price()
- assert m.value == 5 and m.feature.id == f[0].id and \
- m.feature.thing.id == t[0].id and m.feature.thing.person.id == p.id
+ f[0].thing.person.id == p1.id
+ m = p1.things()[0].features()[0].price()
+ assert (
+ m.value == 5 and m.feature.id == f[0].id and
+ m.feature.thing.id == t[0].id and m.feature.thing.person.id == p1.id
+ )
+ p2.things.add(t2)
+ assert p1.things.count() == 1
+ assert p2.things.count() == 1
+ t = p2.things()
+ assert len(t) == 1
+ assert t[0].name == "apple" and t[0].color == "green" and t[0].person.id == p2.id
+ p2.things.remove(t2)
+ assert p2.things.count() == 0
+ assert db(db.things).count() == 1
#: has_many via as shortcut
- assert len(p.features()) == 1
+ assert len(p1.features()) == 1
#: has_many via with join tables logic
doctor = db.Doctor.insert(name="cox")
patient = db.Patient.insert(name="mario")
@@ -437,9 +979,14 @@ def test_relations(db):
assert jim.organizations().first().id == org
assert joe.memberships().first().role == 'admin'
assert jim.memberships().first().role == 'manager'
+ org.users.remove(joe)
+ org.users.remove(jim)
+ assert len(org.users(reload=True)) == 0
+ assert len(joe.organizations(reload=True)) == 0
+ assert len(jim.organizations(reload=True)) == 0
#: has_many with specified field
- db.Dog.insert(name='pongo', owner=p)
- assert len(p.pets()) == 1 and p.pets().first().name == 'pongo'
+ db.Dog.insert(name='pongo', owner=p1)
+ assert len(p1.pets()) == 1 and p1.pets().first().name == 'pongo'
#: has_many via with specified field
zoo = db.Zoo.insert(name='magic zoo')
mouse = db.Mouse.insert(name='jerry')
@@ -504,32 +1051,127 @@ def test_relations_scopes(db):
assert org.admins.count() == 1
assert org.admins2.count() == 1
assert org.admins3.count() == 1
+ org.users.remove(gus)
+ org.users.remove(frank)
+ assert org.admins.count() == 0
+ assert org.admins2.count() == 0
+ assert org.admins3.count() == 0
org2 = db.Organization.insert(name="Laundry", is_cover=True)
org2.users.add(gus, role="admin")
- assert len(gus.cover_orgs()) == 1
+ assert gus.cover_orgs.count() == 1
assert gus.cover_orgs().first().id == org2
+ org2.users.remove(gus)
+ assert gus.cover_orgs.count() == 0
org.delete_record()
org2.delete_record()
#: creation/addition
org = db.Organization.insert(name="Los pollos hermanos")
org.admins.add(gus)
assert org.admins.count() == 1
+ org.admins.remove(gus)
+ assert org.admins.count() == 0
org.delete_record()
org = db.Organization.insert(name="Los pollos hermanos")
org.admins2.add(gus)
assert org.admins2.count() == 1
+ org.admins2.remove(gus)
+ assert org.admins2.count() == 0
org.delete_record()
org = db.Organization.insert(name="Los pollos hermanos")
org.admins3.add(gus)
assert org.admins3.count() == 1
+ org.admins3.remove(gus)
+ assert org.admins3.count() == 0
org.delete_record()
gus = User.get(name="Gus Fring")
org2 = db.Organization.insert(name="Laundry", is_cover=True)
gus.cover_orgs.add(org2)
- assert len(gus.cover_orgs()) == 1
+ assert gus.cover_orgs.count() == 1
assert gus.cover_orgs().first().id == org2
+ gus.cover_orgs.remove(org2)
+ assert gus.cover_orgs.count() == 0
def test_model_where(db):
assert Subscription.where(lambda s: s.status == 1).query == \
db(db.Subscription.status == 1).query
+
+
+def test_model_first(db):
+ p = db.Person.insert(name="Walter", age=50)
+ db.Subscription.insert(
+ name="a",
+ expires_at=datetime.now() + timedelta(hours=20),
+ person=p,
+ status=1
+ )
+ db.Subscription.insert(
+ name="b",
+ expires_at=datetime.now() + timedelta(hours=20),
+ person=p,
+ status=1
+ )
+ db.CustomPKType.insert(id="a")
+ db.CustomPKType.insert(id="b")
+ db.CustomPKName.insert(name="a")
+ db.CustomPKName.insert(name="b")
+ db.CustomPKMulti.insert(first_name="foo", last_name="bar")
+ db.CustomPKMulti.insert(first_name="foo", last_name="baz")
+ db.CustomPKMulti.insert(first_name="bar", last_name="baz")
+
+ assert Subscription.first().id == Subscription.all().select(
+ orderby=Subscription.id,
+ limitby=(0, 1)
+ ).first().id
+ assert CustomPKType.first().id == CustomPKType.all().select(
+ orderby=CustomPKType.id,
+ limitby=(0, 1)
+ ).first().id
+ assert CustomPKName.first().name == CustomPKName.all().select(
+ orderby=CustomPKName.name,
+ limitby=(0, 1)
+ ).first().name
+ assert CustomPKMulti.first() == CustomPKMulti.all().select(
+ orderby=CustomPKMulti.first_name|CustomPKMulti.last_name,
+ limitby=(0, 1)
+ ).first()
+
+
+def test_model_last(db):
+ p = db.Person.insert(name="Walter", age=50)
+ db.Subscription.insert(
+ name="a",
+ expires_at=datetime.now() + timedelta(hours=20),
+ person=p,
+ status=1
+ )
+ db.Subscription.insert(
+ name="b",
+ expires_at=datetime.now() + timedelta(hours=20),
+ person=p,
+ status=1
+ )
+ db.CustomPKType.insert(id="a")
+ db.CustomPKType.insert(id="b")
+ db.CustomPKName.insert(name="a")
+ db.CustomPKName.insert(name="b")
+ db.CustomPKMulti.insert(first_name="foo", last_name="bar")
+ db.CustomPKMulti.insert(first_name="foo", last_name="baz")
+ db.CustomPKMulti.insert(first_name="bar", last_name="baz")
+
+ assert Subscription.last().id == Subscription.all().select(
+ orderby=~Subscription.id,
+ limitby=(0, 1)
+ ).first().id
+ assert CustomPKType.last().id == CustomPKType.all().select(
+ orderby=~CustomPKType.id,
+ limitby=(0, 1)
+ ).first().id
+ assert CustomPKName.last().name == CustomPKName.all().select(
+ orderby=~CustomPKName.name,
+ limitby=(0, 1)
+ ).first().name
+ assert CustomPKMulti.last() == CustomPKMulti.all().select(
+ orderby=~CustomPKMulti.first_name|~CustomPKMulti.last_name,
+ limitby=(0, 1)
+ ).first()
diff --git a/tests/test_orm_gis.py b/tests/test_orm_gis.py
new file mode 100644
index 00000000..0da40b86
--- /dev/null
+++ b/tests/test_orm_gis.py
@@ -0,0 +1,290 @@
+# -*- coding: utf-8 -*-
+"""
+ tests.orm_gis
+ -------------
+
+ Test ORM GIS features
+"""
+
+import os
+import pytest
+
+from emmett import App, sdict
+from emmett.orm import Database, Model, Field, geo
+from emmett.orm.migrations.utils import generate_runtime_migration
+
+require_postgres = pytest.mark.skipif(
+ not os.environ.get("POSTGRES_URI"), reason="No postgres database"
+)
+
+
+class Geography(Model):
+ name = Field.string()
+ geo = Field.geography()
+ point = Field.geography("POINT")
+ line = Field.geography("LINESTRING")
+ polygon = Field.geography("POLYGON")
+ multipoint = Field.geography("MULTIPOINT")
+ multiline = Field.geography("MULTILINESTRING")
+ multipolygon = Field.geography("MULTIPOLYGON")
+
+
+class Geometry(Model):
+ name = Field.string()
+ geo = Field.geometry()
+ point = Field.geometry("POINT")
+ line = Field.geometry("LINESTRING")
+ polygon = Field.geometry("POLYGON")
+ multipoint = Field.geometry("MULTIPOINT")
+ multiline = Field.geometry("MULTILINESTRING")
+ multipolygon = Field.geometry("MULTIPOLYGON")
+
+
+@pytest.fixture(scope='module')
+def _db():
+ app = App(__name__)
+ db = Database(
+ app,
+ config=sdict(
+ uri=f"postgres://{os.environ.get('POSTGRES_URI')}"
+ )
+ )
+ db.define_models(
+ Geography,
+ Geometry
+ )
+ return db
+
+
+@pytest.fixture(scope='function')
+def db(_db):
+ migration = generate_runtime_migration(_db)
+ with _db.connection():
+ migration.up()
+ yield _db
+ migration.down()
+
+
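+# Column types encode (geometry type, SRID, dimensions); 4326 is the WGS84
+# SRID the field helpers default to.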
+@require_postgres
+def test_field_types(_db):
+ assert Geography.geo.type == "geography(GEOMETRY,4326,2)"
+ assert Geography.point.type == "geography(POINT,4326,2)"
+ assert Geography.line.type == "geography(LINESTRING,4326,2)"
+ assert Geography.polygon.type == "geography(POLYGON,4326,2)"
+ assert Geography.multipoint.type == "geography(MULTIPOINT,4326,2)"
+ assert Geography.multiline.type == "geography(MULTILINESTRING,4326,2)"
+ assert Geography.multipolygon.type == "geography(MULTIPOLYGON,4326,2)"
+ assert Geometry.geo.type == "geometry(GEOMETRY,4326,2)"
+ assert Geometry.point.type == "geometry(POINT,4326,2)"
+ assert Geometry.line.type == "geometry(LINESTRING,4326,2)"
+ assert Geometry.polygon.type == "geometry(POLYGON,4326,2)"
+ assert Geometry.multipoint.type == "geometry(MULTIPOINT,4326,2)"
+ assert Geometry.multiline.type == "geometry(MULTILINESTRING,4326,2)"
+ assert Geometry.multipolygon.type == "geometry(MULTIPOLYGON,4326,2)"
+
+
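+# Saved geo values stringify as WKT with six-decimal coordinates; the parsed
+# representation is exposed via .geometry, .coordinates and .groups.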
+@require_postgres
+def test_gis_insert(db):
+ for model in [Geometry, Geography]:
+ row = model.new(
+ point=geo.Point(1, 1),
+ line=geo.Line((0, 0), (20, 80), (80, 80)),
+ polygon=geo.Polygon((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),
+ multipoint=geo.MultiPoint((1, 1), (2, 2)),
+ multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))),
+ multipolygon=geo.MultiPolygon(
+ (
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ ),
+ (
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+ )
+ )
+ row.save()
+
+ assert row.point == "POINT(1.000000 1.000000)"
+ assert row.point.geometry == "POINT"
+ assert row.point.coordinates == (1, 1)
+ assert not row.point.groups
+
+ assert row.line == "LINESTRING({})".format(
+ ",".join([
+ " ".join(f"{v}.000000" for v in tup)
+ for tup in [
+ (0, 0),
+ (20, 80),
+ (80, 80)
+ ]
+ ])
+ )
+ assert row.line.geometry == "LINESTRING"
+ assert row.line.coordinates == ((0, 0), (20, 80), (80, 80))
+ assert not row.line.groups
+
+ assert row.polygon == "POLYGON(({}))".format(
+ ",".join([
+ " ".join(f"{v}.000000" for v in tup)
+ for tup in [
+ (0, 0),
+ (150, 0),
+ (150, 10),
+ (0, 10),
+ (0, 0)
+ ]
+ ])
+ )
+ assert row.polygon.geometry == "POLYGON"
+ assert row.polygon.coordinates == (
+ ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),
+ )
+ assert not row.polygon.groups
+
+ assert row.multipoint == "MULTIPOINT((1.000000 1.000000),(2.000000 2.000000))"
+ assert row.multipoint.geometry == "MULTIPOINT"
+ assert row.multipoint.coordinates == ((1, 1), (2, 2))
+ assert len(row.multipoint.groups) == 2
+ assert row.multipoint.groups[0] == geo.Point(1, 1)
+ assert row.multipoint.groups[1] == geo.Point(2, 2)
+
+ assert row.multiline == "MULTILINESTRING({})".format(
+ ",".join([
+ "({})".format(
+ ",".join([
+ " ".join(f"{v}.000000" for v in tup)
+ for tup in group
+ ])
+ ) for group in [
+ ((1, 1), (2, 2), (3, 3)),
+ ((1, 1), (4, 4), (5, 5))
+ ]
+ ])
+ )
+ assert row.multiline.geometry == "MULTILINESTRING"
+ assert row.multiline.coordinates == (
+ ((1, 1), (2, 2), (3, 3)),
+ ((1, 1), (4, 4), (5, 5))
+ )
+ assert len(row.multiline.groups) == 2
+ assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3))
+ assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5))
+
+ assert row.multipolygon == "MULTIPOLYGON({})".format(
+ ",".join([
+ "({})".format(
+ ",".join([
+ "({})".format(
+ ",".join([
+ " ".join(f"{v}.000000" for v in tup)
+ for tup in group
+ ])
+ ) for group in polygon
+ ])
+ ) for polygon in [
+ (
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ ),
+ (
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+ ]
+ ])
+ )
+ assert row.multipolygon.geometry == "MULTIPOLYGON"
+ assert row.multipolygon.coordinates == (
+ (
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ ),
+ (
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+ )
+ assert len(row.multipolygon.groups) == 2
+ assert row.multipolygon.groups[0] == geo.Polygon(
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ )
+ assert row.multipolygon.groups[1] == geo.Polygon(
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+
+
+@require_postgres
+def test_gis_select(db):
+ for model in [Geometry, Geography]:
+ row = model.new(
+ point=geo.Point(1, 1),
+ line=geo.Line((0, 0), (20, 80), (80, 80)),
+ polygon=geo.Polygon((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),
+ multipoint=geo.MultiPoint((1, 1), (2, 2)),
+ multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))),
+ multipolygon=geo.MultiPolygon(
+ (
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ ),
+ (
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+ )
+ )
+ row.save()
+ row = model.get(row.id)
+
+ assert row.point.geometry == "POINT"
+ assert row.point.coordinates == (1, 1)
+ assert not row.point.groups
+
+ assert row.line.geometry == "LINESTRING"
+ assert row.line.coordinates == ((0, 0), (20, 80), (80, 80))
+ assert not row.line.groups
+
+ assert row.polygon.geometry == "POLYGON"
+ assert row.polygon.coordinates == (
+ ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),
+ )
+ assert not row.polygon.groups
+
+ assert row.multipoint.geometry == "MULTIPOINT"
+ assert row.multipoint.coordinates == ((1, 1), (2, 2))
+ assert len(row.multipoint.groups) == 2
+ assert row.multipoint.groups[0] == geo.Point(1, 1)
+ assert row.multipoint.groups[1] == geo.Point(2, 2)
+
+ assert row.multiline.geometry == "MULTILINESTRING"
+ assert row.multiline.coordinates == (
+ ((1, 1), (2, 2), (3, 3)),
+ ((1, 1), (4, 4), (5, 5))
+ )
+ assert len(row.multiline.groups) == 2
+ assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3))
+ assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5))
+
+ assert row.multipolygon.geometry == "MULTIPOLYGON"
+ assert row.multipolygon.coordinates == (
+ (
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ ),
+ (
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
+ )
+ assert len(row.multipolygon.groups) == 2
+ assert row.multipolygon.groups[0] == geo.Polygon(
+ ((0, 0), (20, 0), (20, 20), (0, 0)),
+ ((0, 0), (30, 0), (30, 30), (0, 0))
+ )
+ assert row.multipolygon.groups[1] == geo.Polygon(
+ ((1, 1), (21, 1), (21, 21), (1, 1)),
+ ((1, 1), (31, 1), (31, 31), (1, 1))
+ )
diff --git a/tests/test_orm_pks.py b/tests/test_orm_pks.py
new file mode 100644
index 00000000..9923f227
--- /dev/null
+++ b/tests/test_orm_pks.py
@@ -0,0 +1,747 @@
+# -*- coding: utf-8 -*-
+"""
+ tests.orm_pks
+ -------------
+
+ Test ORM primary keys handling
+"""
+
+import os
+import pytest
+
+from uuid import uuid4
+
+from emmett import App, sdict
+from emmett.orm import Database, Model, Field, belongs_to, has_many
+from emmett.orm.errors import SaveException
+from emmett.orm.migrations.utils import generate_runtime_migration
+
+require_postgres = pytest.mark.skipif(
+ not os.environ.get("POSTGRES_URI"), reason="No postgres database"
+)
+
+
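+# Model matrix under test: default integer id, custom-typed id, custom-named
+# primary key, and composite primary keys.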
+class Standard(Model):
+ foo = Field.string()
+ bar = Field.string()
+
+
+class CustomType(Model):
+ id = Field.string()
+ foo = Field.string()
+ bar = Field.string()
+
+
+class CustomName(Model):
+ primary_keys = ["foo"]
+ foo = Field.string()
+ bar = Field.string()
+
+
+class CustomMulti(Model):
+ primary_keys = ["foo", "bar"]
+ foo = Field.string()
+ bar = Field.string()
+ baz = Field.string()
+
+
+class SourceCustom(Model):
+ has_many("dest_custom_customs", "dest_custom_multis")
+
+ id = Field.string(default=lambda: uuid4().hex)
+ foo = Field.string()
+
+
+class SourceMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ has_many("dest_multi_customs", "dest_multi_multis")
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ baz = Field.string()
+
+
+class DestCustomCustom(Model):
+ belongs_to("source_custom")
+
+ id = Field.string(default=lambda: uuid4().hex)
+ foo = Field.string()
+
+
+class DestCustomMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ belongs_to("source_custom")
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ baz = Field.string()
+
+
+class DestMultiCustom(Model):
+ belongs_to("source_multi")
+
+ id = Field.string(default=lambda: uuid4().hex)
+ foo = Field.string()
+
+
+class DestMultiMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ belongs_to("source_multi")
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ baz = Field.string()
+
+
+class DoctorCustom(Model):
+ has_many(
+ {"appointments": "AppointmentCustom"},
+ {"patients": {"via": "appointments.patient_custom"}},
+ {"symptoms_to_treat": {"via": "patients.symptoms"}}
+ )
+
+ id = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class DoctorMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ has_many(
+ {"appointments": "AppointmentMulti"},
+ {"patients": {"via": "appointments.patient_multi"}},
+ {"symptoms_to_treat": {"via": "patients.symptoms"}}
+ )
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class PatientCustom(Model):
+ primary_keys = ["code"]
+
+ has_many(
+ {"appointments": "AppointmentCustom"},
+ {"symptoms": "SymptomCustom.patient"},
+ {"doctors": {"via": "appointments.doctor_custom"}}
+ )
+
+ code = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class PatientMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ has_many(
+ {"appointments": "AppointmentMulti"},
+ {"symptoms": "SymptomMulti.patient"},
+ {"doctors": {"via": "appointments.doctor_multi"}}
+ )
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class SymptomCustom(Model):
+ belongs_to({"patient": "PatientCustom"})
+
+ id = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class SymptomMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ belongs_to({"patient": "PatientMulti"})
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class AppointmentCustom(Model):
+ primary_keys = ["code"]
+
+ belongs_to("patient_custom", "doctor_custom")
+
+ code = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+class AppointmentMulti(Model):
+ primary_keys = ["foo", "bar"]
+
+ belongs_to("patient_multi", "doctor_multi")
+
+ foo = Field.string(default=lambda: uuid4().hex)
+ bar = Field.string(default=lambda: uuid4().hex)
+ name = Field.string()
+
+
+@pytest.fixture(scope='module')
+def _db():
+ app = App(__name__)
+ db = Database(
+ app,
+ config=sdict(
+ uri=f'sqlite://{uuid4().hex}.db',
+ auto_connect=True
+ )
+ )
+ db.define_models(
+ Standard,
+ CustomType,
+ CustomName,
+ CustomMulti
+ )
+ return db
+
+
+@pytest.fixture(scope='module')
+def _pgs():
+ app = App(__name__)
+ db = Database(
+ app,
+ config=sdict(
+ uri=f"postgres://{os.environ.get('POSTGRES_URI')}",
+ auto_connect=True
+ )
+ )
+ db.define_models(
+ SourceCustom,
+ SourceMulti,
+ DestCustomCustom,
+ DestCustomMulti,
+ DestMultiCustom,
+ DestMultiMulti,
+ DoctorCustom,
+ PatientCustom,
+ AppointmentCustom,
+ DoctorMulti,
+ PatientMulti,
+ AppointmentMulti,
+ SymptomCustom,
+ SymptomMulti
+ )
+ return db
+
+
+@pytest.fixture(scope='function')
+def db(_db):
+ migration = generate_runtime_migration(_db)
+ migration.up()
+ yield _db
+ migration.down()
+
+
+@pytest.fixture(scope='function')
+def pgs(_pgs):
+ migration = generate_runtime_migration(_pgs)
+ migration.up()
+ yield _pgs
+ migration.down()
+
+
+def test_insert(db):
+ res = db.Standard.insert(foo="test1", bar="test2")
+ assert isinstance(res, int)
+ assert res.id
+ assert res.foo == "test1"
+ assert res.bar == "test2"
+
+ res = db.CustomType.insert(id="test1", foo="test2", bar="test3")
+ assert isinstance(res, str)
+ assert res.id == "test1"
+ assert res.foo == "test2"
+ assert res.bar == "test3"
+
+ res = db.CustomName.insert(foo="test1", bar="test2")
+ assert isinstance(res, str)
+ assert not res.id
+ assert res.foo == "test1"
+ assert res.bar == "test2"
+
+ res = db.CustomMulti.insert(foo="test1", bar="test2", baz="test3")
+ assert isinstance(res, tuple)
+ assert not res.id
+ assert res.foo == "test1"
+ assert res.bar == "test2"
+ assert res.baz == "test3"
+
+
+def test_save_insert(db):
+ row = Standard.new(foo="test1", bar="test2")
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.id
+ assert type(row.id) == int
+
+ row = CustomType.new(id="test1", foo="test2", bar="test3")
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.id == "test1"
+
+ row = CustomName.new(foo="test1", bar="test2")
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert "id" not in row
+ assert row.foo == "test1"
+
+ row = CustomMulti.new(foo="test1", bar="test2", baz="test3")
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert "id" not in row
+ assert row.foo == "test1"
+ assert row.bar == "test2"
+ assert row.baz == "test3"
+
+
+def test_save_update(db):
+ row = Standard.new(foo="test1", bar="test2")
+ row.save()
+ row.bar = "test2a"
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.bar == "test2a"
+ row.id = 123
+ done = row.save()
+ assert not done
+ with pytest.raises(SaveException):
+ row.save(raise_on_error=True)
+
+ row = CustomType.new(id="test1", foo="test2", bar="test3")
+ row.save()
+ row.bar = "test2a"
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.bar == "test2a"
+ row.id = "test1a"
+ done = row.save()
+ assert not done
+ with pytest.raises(SaveException):
+ row.save(raise_on_error=True)
+
+ row = CustomName.new(foo="test1", bar="test2")
+ row.save()
+ row.bar = "test2a"
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.bar == "test2a"
+ row.foo = "test1a"
+ done = row.save()
+ assert not done
+ with pytest.raises(SaveException):
+ row.save(raise_on_error=True)
+
+ row = CustomMulti.new(foo="test1", bar="test2", baz="test3")
+ row.save()
+ row.baz = "test3a"
+ done = row.save()
+ assert done
+ assert row._concrete
+ assert row.baz == "test3a"
+ row.foo = "test1a"
+ done = row.save()
+ assert not done
+ with pytest.raises(SaveException):
+ row.save(raise_on_error=True)
+
+
+def test_destroy_delete(db):
+ row = Standard.new(foo="test1", bar="test2")
+ row.save()
+ done = row.destroy()
+ assert done
+ assert not row._concrete
+ assert row.id is None
+ assert row.foo == "test1"
+
+ row = CustomType.new(id="test1", foo="test2", bar="test3")
+ row.save()
+ done = row.destroy()
+ assert done
+ assert not row._concrete
+ assert row.id is None
+ assert row.foo == "test2"
+
+ row = CustomName.new(foo="test1", bar="test2")
+ row.save()
+ done = row.destroy()
+ assert done
+ assert not row._concrete
+ assert row.foo is None
+ assert row.bar == "test2"
+
+ row = CustomMulti.new(foo="test1", bar="test2", baz="test3")
+ row.save()
+ done = row.destroy()
+ assert done
+ assert not row._concrete
+ assert row.foo is None
+ assert row.bar is None
+ assert row.baz == "test3"
+
+
+@require_postgres
+def test_relations(pgs):
+ sc1 = SourceCustom.new(foo="test1")
+ sc1.save()
+ sc2 = SourceCustom.new(foo="test2")
+ sc2.save()
+ sm1 = SourceMulti.new(baz="test1")
+ sm1.save()
+ sm2 = SourceMulti.new(baz="test2")
+ sm2.save()
+
+ #: create
+ dcc1 = sc1.dest_custom_customs.create(foo="test")
+ assert isinstance(dcc1.id, str)
+ row = sc1.dest_custom_customs().first()
+ assert row.foo == "test"
+ rc = DestCustomCustom.get(row.id)
+ assert rc.foo == row.foo
+ assert isinstance(rc.source_custom, str)
+
+ dcm1 = sc1.dest_custom_multis.create(baz="test")
+ assert isinstance(dcm1.id, tuple)
+ row = sc1.dest_custom_multis().first()
+ assert row.baz == "test"
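+ # composite primary keys accept four lookup styles in Model.get:
+ # positional values, keyword arguments, a tuple, or a dict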
+ rc = DestCustomMulti.get(row.foo, row.bar)
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_custom, str)
+ rc = DestCustomMulti.get(foo=row.foo, bar=row.bar)
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_custom, str)
+ rc = DestCustomMulti.get((row.foo, row.bar))
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_custom, str)
+ rc = DestCustomMulti.get({"foo": row.foo, "bar": row.bar})
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_custom, str)
+
+ dmc1 = sm1.dest_multi_customs.create(foo="test")
+ assert isinstance(dmc1.id, str)
+ row = sm1.dest_multi_customs().first()
+ assert row.foo == "test"
+ rc = DestMultiCustom.get(row.id)
+ assert rc.foo == row.foo
+ assert isinstance(rc.source_multi, tuple)
+ assert rc.source_multi.foo == rc.source_multi_foo
+ assert rc.source_multi.bar == rc.source_multi_bar
+ assert rc.source_multi.baz == "test1"
+
+ dmm1 = sm1.dest_multi_multis.create(baz="test")
+ assert isinstance(dmm1.id, tuple)
+ row = sm1.dest_multi_multis().first()
+ assert row.baz == "test"
+ rc = DestMultiMulti.get(row.foo, row.bar)
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_multi, tuple)
+ assert rc.source_multi.foo == rc.source_multi_foo
+ assert rc.source_multi.bar == rc.source_multi_bar
+ assert rc.source_multi.baz == "test1"
+ rc = DestMultiMulti.get(foo=row.foo, bar=row.bar)
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_multi, tuple)
+ assert rc.source_multi.foo == rc.source_multi_foo
+ assert rc.source_multi.bar == rc.source_multi_bar
+ assert rc.source_multi.baz == "test1"
+ rc = DestMultiMulti.get((row.foo, row.bar))
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_multi, tuple)
+ assert rc.source_multi.foo == rc.source_multi_foo
+ assert rc.source_multi.bar == rc.source_multi_bar
+ assert rc.source_multi.baz == "test1"
+ rc = DestMultiMulti.get({"foo": row.foo, "bar": row.bar})
+ assert rc.foo == row.foo
+ assert rc.bar == row.bar
+ assert isinstance(rc.source_multi, tuple)
+ assert rc.source_multi.foo == rc.source_multi_foo
+ assert rc.source_multi.bar == rc.source_multi_bar
+ assert rc.source_multi.baz == "test1"
+
+ #: add, remove
+ dcc1 = DestCustomCustom.first()
+ sc2.dest_custom_customs.add(dcc1)
+ assert sc1.dest_custom_customs.count() == 0
+ assert sc2.dest_custom_customs.count() == 1
+ assert dcc1.source_custom.id == sc2.id
+ sc2.dest_custom_customs.remove(dcc1)
+ assert sc1.dest_custom_customs.count() == 0
+ assert sc2.dest_custom_customs.count() == 0
+ assert dcc1.source_custom is None
+ assert not dcc1.is_valid
+
+ dcm1 = DestCustomMulti.first()
+ sc2.dest_custom_multis.add(dcm1)
+ assert sc1.dest_custom_multis.count() == 0
+ assert sc2.dest_custom_multis.count() == 1
+ assert dcm1.source_custom.id == sc2.id
+ sc2.dest_custom_multis.remove(dcm1)
+ assert sc1.dest_custom_multis.count() == 0
+ assert sc2.dest_custom_multis.count() == 0
+ assert dcm1.source_custom is None
+ assert not dcm1.is_valid
+
+ dmc1 = DestMultiCustom.first()
+ sm2.dest_multi_customs.add(dmc1)
+ assert sm1.dest_multi_customs.count() == 0
+ assert sm2.dest_multi_customs.count() == 1
+ assert dmc1.source_multi.foo == sm2.foo
+ assert dmc1.source_multi.bar == sm2.bar
+ sm2.dest_multi_customs.remove(dmc1)
+ assert sm1.dest_multi_customs.count() == 0
+ assert sm2.dest_multi_customs.count() == 0
+ assert dmc1.source_multi is None
+ assert not dmc1.is_valid
+
+ dmm1 = DestMultiMulti.first()
+ sm2.dest_multi_multis.add(dmm1)
+ assert sm1.dest_multi_multis.count() == 0
+ assert sm2.dest_multi_multis.count() == 1
+ assert dmm1.source_multi.foo == sm2.foo
+ assert dmm1.source_multi.bar == sm2.bar
+ sm2.dest_multi_multis.remove(dmm1)
+ assert sm1.dest_multi_multis.count() == 0
+ assert sm2.dest_multi_multis.count() == 0
+ assert dmm1.source_multi is None
+ assert not dmm1.is_valid
+
+
+@require_postgres
+def test_via_relations(pgs):
+ doc1 = DoctorCustom.new(name="test1")
+ doc1.save()
+ doc2 = DoctorCustom.new(name="test2")
+ doc2.save()
+ pat1 = PatientCustom.new(name="test1")
+ pat1.save()
+ pat1.symptoms.create(name="test1a")
+ pat2 = PatientCustom.new(name="test2")
+ pat2.save()
+ pat2.symptoms.create(name="test2a")
+ pat2.symptoms.create(name="test2b")
+ doc3 = DoctorMulti.new(name="test1")
+ doc3.save()
+ doc4 = DoctorMulti.new(name="test2")
+ doc4.save()
+ pat3 = PatientMulti.new(name="test1")
+ pat3.save()
+ pat3.symptoms.create(name="test3a")
+ pat3.symptoms.create(name="test3b")
+ pat4 = PatientMulti.new(name="test2")
+ pat4.save()
+ pat4.symptoms.create(name="test4a")
+
+ #: add, remove
+ doc1.patients.add(pat1, name="test1")
+ doc2.patients.add(pat2, name="test2")
+ assert doc1.patients.count() == 1
+ assert doc2.patients.count() == 1
+ assert doc1.symptoms_to_treat.count() == 1
+ assert doc2.symptoms_to_treat.count() == 2
+ doc1.patients.add(pat2, name="test2")
+ assert doc1.patients.count() == 2
+ assert doc2.patients.count() == 1
+ assert doc1.symptoms_to_treat.count() == 3
+ assert doc2.symptoms_to_treat.count() == 2
+ doc2.patients.remove(pat2)
+ assert doc1.patients.count() == 2
+ assert doc2.patients.count() == 0
+ assert doc1.symptoms_to_treat.count() == 3
+ assert doc2.symptoms_to_treat.count() == 0
+
+ doc3.patients.add(pat3, name="test1")
+ doc4.patients.add(pat4, name="test2")
+ assert doc3.patients.count() == 1
+ assert doc4.patients.count() == 1
+ assert doc3.symptoms_to_treat.count() == 2
+ assert doc4.symptoms_to_treat.count() == 1
+ doc3.patients.add(pat4, name="test2")
+ assert doc3.patients.count() == 2
+ assert doc4.patients.count() == 1
+ assert doc3.symptoms_to_treat.count() == 3
+ assert doc4.symptoms_to_treat.count() == 1
+ doc4.patients.remove(pat4)
+ assert doc3.patients.count() == 2
+ assert doc4.patients.count() == 0
+ assert doc3.symptoms_to_treat.count() == 3
+ assert doc4.symptoms_to_treat.count() == 0
+
+
+@require_postgres
+def test_relations_set(pgs):
+ doc1 = DoctorCustom.new(name="test1")
+ doc1.save()
+ doc2 = DoctorCustom.new(name="test2")
+ doc2.save()
+ pat1 = PatientCustom.new(name="test1")
+ pat1.save()
+ pat1.symptoms.create(name="test1a")
+ pat2 = PatientCustom.new(name="test2")
+ pat2.save()
+ pat2.symptoms.create(name="test2a")
+ pat2.symptoms.create(name="test2b")
+ doc3 = DoctorMulti.new(name="test1")
+ doc3.save()
+ doc4 = DoctorMulti.new(name="test2")
+ doc4.save()
+ pat3 = PatientMulti.new(name="test1")
+ pat3.save()
+ pat3.symptoms.create(name="test3a")
+ pat3.symptoms.create(name="test3b")
+ pat4 = PatientMulti.new(name="test2")
+ pat4.save()
+ pat4.symptoms.create(name="test4a")
+
+ doc1.patients.add(pat1, name="test1")
+
+ djoin = DoctorCustom.all().join("appointments").select()
+ assert len(djoin) == 1
+ assert djoin[0].id == doc1.id
+ assert len(djoin[0].appointments()) == 1
+
+ djoin = DoctorCustom.all().join("patients").select()
+ assert len(djoin) == 1
+ assert djoin[0].id == doc1.id
+ assert len(djoin[0].patients()) == 1
+ assert djoin[0].patients()[0].code == pat1.code
+
+ pjoin = PatientCustom.all().join("appointments").select()
+ assert len(pjoin) == 1
+ assert pjoin[0].code == pat1.code
+ assert len(pjoin[0].appointments()) == 1
+
+ pjoin = PatientCustom.all().join("doctors").select()
+ assert len(pjoin) == 1
+ assert pjoin[0].code == pat1.code
+ assert len(pjoin[0].doctors()) == 1
+ assert pjoin[0].doctors()[0].id == doc1.id
+
+ ajoin = AppointmentCustom.all().join("doctor_custom", "patient_custom").select()
+ assert len(ajoin) == 1
+ assert ajoin[0].doctor_custom.id == doc1.id
+ assert ajoin[0].patient_custom.code == pat1.code
+
+ djoin = DoctorCustom.all().select(including=["appointments"])
+ assert len(djoin) == 2
+ assert len(djoin[0].appointments()) == 1
+ assert len(djoin[1].appointments()) == 0
+
+ djoin = DoctorCustom.all().join("appointments").select(including=["patients"])
+ assert len(djoin) == 1
+ assert djoin[0].id == doc1.id
+ assert len(djoin[0].patients()) == 1
+
+ pjoin = PatientCustom.all().select(including=["appointments"])
+ assert len(pjoin) == 2
+ assert len(pjoin[0].appointments()) == 1
+ assert len(pjoin[1].appointments()) == 0
+
+ pjoin = PatientCustom.all().join("appointments").select(including=["doctors"])
+ assert len(pjoin) == 1
+ assert pjoin[0].code == pat1.code
+ assert len(pjoin[0].doctors()) == 1
+
+ doc3.patients.add(pat3, name="test1")
+
+ djoin = DoctorMulti.all().join("appointments").select()
+ assert len(djoin) == 1
+ assert djoin[0].foo == doc3.foo
+ assert djoin[0].bar == doc3.bar
+ assert len(djoin[0].appointments()) == 1
+
+ djoin = DoctorMulti.all().join("patients").select()
+ assert len(djoin) == 1
+ assert djoin[0].foo == doc3.foo
+ assert djoin[0].bar == doc3.bar
+ assert len(djoin[0].patients()) == 1
+ assert djoin[0].patients()[0].foo == pat3.foo
+ assert djoin[0].patients()[0].bar == pat3.bar
+
+ pjoin = PatientMulti.all().join("appointments").select()
+ assert len(pjoin) == 1
+ assert pjoin[0].foo == pat3.foo
+ assert pjoin[0].bar == pat3.bar
+ assert len(pjoin[0].appointments()) == 1
+
+ pjoin = PatientMulti.all().join("doctors").select()
+ assert len(pjoin) == 1
+ assert pjoin[0].foo == pat3.foo
+ assert pjoin[0].bar == pat3.bar
+ assert len(pjoin[0].doctors()) == 1
+ assert pjoin[0].doctors()[0].foo == doc3.foo
+ assert pjoin[0].doctors()[0].bar == doc3.bar
+
+ ajoin = AppointmentMulti.all().join("doctor_multi", "patient_multi").select()
+ assert len(ajoin) == 1
+ assert ajoin[0].doctor_multi.foo == doc3.foo
+ assert ajoin[0].doctor_multi.bar == doc3.bar
+ assert ajoin[0].patient_multi.foo == pat3.foo
+ assert ajoin[0].patient_multi.bar == pat3.bar
+
+ djoin = DoctorMulti.all().select(including=["appointments"])
+ assert len(djoin) == 2
+ assert len(djoin[0].appointments()) == 1
+ assert len(djoin[1].appointments()) == 0
+
+ djoin = DoctorMulti.all().join("appointments").select(including=["patients"])
+ assert len(djoin) == 1
+ assert djoin[0].foo == doc3.foo
+ assert djoin[0].bar == doc3.bar
+ assert len(djoin[0].patients()) == 1
+
+ pjoin = PatientMulti.all().select(including=["appointments"])
+ assert len(pjoin) == 2
+ assert len(pjoin[0].appointments()) == 1
+ assert len(pjoin[1].appointments()) == 0
+
+ pjoin = PatientMulti.all().join("appointments").select(including=["doctors"])
+ assert len(pjoin) == 1
+ assert pjoin[0].foo == pat3.foo
+ assert pjoin[0].bar == pat3.bar
+ assert len(pjoin[0].doctors()) == 1
+
+
+@require_postgres
+def test_row(pgs):
+ sm1 = SourceMulti.new(baz="test1")
+ sm1.save()
+ sm2 = SourceMulti.new(baz="test2")
+ sm2.save()
+
+ dmm1 = DestMultiMulti.new(source_multi=sm1, baz="test")
+ dmm1.save()
+ assert sm1.dest_multi_multis.count() == 1
+ assert sm2.dest_multi_multis.count() == 0
+
+ dmm1.source_multi = sm2
+ assert set(dmm1._changes.keys()).issubset(
+ {"source_multi", "source_multi_foo", "source_multi_bar"}
+ )
+ dmm1.save()
+ assert sm1.dest_multi_multis.count() == 0
+ assert sm2.dest_multi_multis.count() == 1
+
+ DestMultiMulti.create(source_multi=sm1, baz="test")
+ assert sm1.dest_multi_multis.count() == 1
diff --git a/tests/test_orm_row.py b/tests/test_orm_row.py
new file mode 100644
index 00000000..e6579e97
--- /dev/null
+++ b/tests/test_orm_row.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+"""
+ tests.orm_row
+ -------------
+
+ Test ORM row objects
+"""
+
+import pytest
+
+from uuid import uuid4
+
+from emmett import App, sdict, now
+from emmett.orm import Database, Model, Field, belongs_to, has_many, rowmethod
+from emmett.orm.errors import ValidationError
+from emmett.orm.migrations.utils import generate_runtime_migration
+from emmett.orm.objects import Row
+
+
+class One(Model):
+ has_many("twos")
+
+ foo = Field.string(notnull=True)
+ bar = Field.string()
+
+
+class Two(Model):
+ belongs_to("one")
+
+ foo = Field.string()
+ bar = Field.string()
+
+
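+# Soft-delete override: "destroy" only timestamps the row, while
+# "force_destroy" reaches the original destroy behaviour via super_rowmethod.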
+class Override(Model):
+ foo = Field.string()
+ deleted_at = Field.datetime()
+
+ validation = {"deleted_at": {"allow": "empty"}}
+
+ @rowmethod("destroy")
+ def _row_destroy(self, row):
+ row.deleted_at = now()
+ row.save()
+
+ @rowmethod("force_destroy")
+ def _row_force_destroy(self, row):
+ self.super_rowmethod("destroy")(row)
+
+
+@pytest.fixture(scope='module')
+def _db():
+ app = App(__name__)
+ db = Database(
+ app,
+ config=sdict(
+ uri=f'sqlite://{uuid4().hex}.db',
+ auto_connect=True
+ )
+ )
+ db.define_models(One, Two, Override)
+ return db
+
+@pytest.fixture(scope='function')
+def db(_db):
+ migration = generate_runtime_migration(_db)
+ migration.up()
+ yield _db
+ migration.down()
+
+
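+# Full-row selections should yield the model-specific row class; selecting a
+# subset of columns falls back to the generic Row type.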
+def test_rowclass(db):
+ ret = db.One.insert(foo="test1", bar="test2")
+ db.Two.insert(one=ret, foo="test1", bar="test2")
+
+ ret._allocate_()
+ assert type(ret._refrecord) == One._instance_()._rowclass_
+
+ row = One.get(ret.id)
+ assert type(row) == One._instance_()._rowclass_
+
+ row = One.first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = db(db.One).select().first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = db(db.One).select(db.One.ALL).first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = One.all().select().first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = One.where(lambda m: m.id != None).select().first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = db(db.One).select().first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = db(db.One).select(db.One.ALL).first()
+ assert type(row) == One._instance_()._rowclass_
+
+ row = One.all().select(One.bar).first()
+ assert type(row) == Row
+
+ row = db(db.One).select(One.bar).first()
+ assert type(row) == Row
+
+ row = One.all().join("twos").select().first()
+ assert type(row) == One._instance_()._rowclass_
+ assert type(row.twos().first()) == Two._instance_()._rowclass_
+
+ row = One.all().join("twos").select(One.table.ALL, Two.table.ALL).first()
+ assert type(row) == One._instance_()._rowclass_
+ assert type(row.twos().first()) == Two._instance_()._rowclass_
+
+ # row = One.all().join("twos").select(One.table.ALL, Two.foo).first()
+ # assert type(row) == Row
+ # assert type(row.ones) == One._instance_()._rowclass_
+ # assert type(row.twos) == Row
+
+ row = One.all().join("twos").select(One.foo, Two.foo).first()
+ assert type(row) == Row
+ assert type(row.ones) == Row
+ assert type(row.twos) == Row
+
+ row = db(Two.one == One.id).select().first()
+ assert type(row) == Row
+ assert type(row.ones) == One._instance_()._rowclass_
+ assert type(row.twos) == Two._instance_()._rowclass_
+
+ row = db(Two.one == One.id).select(One.table.ALL, Two.foo).first()
+ assert type(row) == Row
+ assert type(row.ones) == One._instance_()._rowclass_
+ assert type(row.twos) == Row
+
+ row = db(Two.one == One.id).select(One.foo, Two.foo).first()
+ assert type(row) == Row
+ assert type(row.ones) == Row
+ assert type(row.twos) == Row
+
+ for row in db(Two.one == One.id).iterselect():
+ assert type(row) == Row
+ assert type(row.ones) == One._instance_()._rowclass_
+ assert type(row.twos) == Two._instance_()._rowclass_
+
+ for row in db(Two.one == One.id).iterselect(One.table.ALL, Two.foo):
+ assert type(row) == Row
+ assert type(row.ones) == One._instance_()._rowclass_
+ assert type(row.twos) == Row
+
+ for row in db(Two.one == One.id).iterselect(One.foo, Two.foo):
+ assert type(row) == Row
+ assert type(row.ones) == Row
+ assert type(row.twos) == Row
+
+
+def test_concrete(db):
+ row = One.new(foo="test")
+ assert not row._concrete
+
+ row.save()
+ assert row._concrete
+
+ row = One.get(row.id)
+ assert row._concrete
+
+
+def test_changes(db):
+ row = One.new(foo="test1")
+ assert not row.has_changed
+
+ row.bar = "test2"
+ assert row.has_changed
+ assert row.has_changed_value("bar")
+ assert row.get_value_change("bar") == (None, "test2")
+
+ row.bar = "test2a"
+ assert row.has_changed
+ assert row.has_changed_value("bar")
+ assert row.get_value_change("bar") == (None, "test2a")
+
+ row.bar = None
+ assert not row.has_changed
+ assert not row.has_changed_value("bar")
+ assert row.get_value_change("bar") is None
+
+ row.update(bar="test2b")
+ assert row.has_changed
+ assert row.has_changed_value("bar")
+ assert row.get_value_change("bar") == (None, "test2b")
+
+ row.update({"bar": "test2c"})
+ assert row.has_changed
+ assert row.has_changed_value("bar")
+ assert row.get_value_change("bar") == (None, "test2c")
+
+ row.update(bar=None)
+ assert not row.has_changed
+ assert not row.has_changed_value("bar")
+ assert row.get_value_change("bar") is None
+
+ row.bar = "test2"
+ row.save()
+ assert not row.has_changed
+
+
+def test_validation_methods(db):
+ row = One.new()
+ assert not row.is_valid
+ assert set(row.validation_errors.keys()).issubset({"foo"})
+ assert not row.save()
+ with pytest.raises(ValidationError):
+ row.save(raise_on_error=True)
+
+ row.foo = "test"
+ assert row.is_valid
+ assert not row.validation_errors
+ assert row.save()
+
+
+def test_clone_methods(db):
+ row = One.new(foo="test1")
+ row.bar = "test2"
+
+ row2 = row.clone()
+ row3 = row.clone_changed()
+
+ assert not row2._concrete
+ assert not row2.has_changed
+ assert not row2.bar
+ assert not row3._concrete
+ assert not row3.has_changed
+ assert row3.bar == "test2"
+
+ row.save()
+ row.foo = "test1a"
+ row2 = row.clone()
+ row3 = row.clone_changed()
+
+ assert row2._concrete
+ assert not row2.has_changed
+ assert row2.foo == "test1"
+ assert row3._concrete
+ assert not row3.has_changed
+ assert row3.foo == "test1a"
+
+
+def test_refresh(db):
+ row = One.new(foo="test1")
+ assert not row.refresh()
+
+ row.save()
+ assert row.refresh()
+
+ row.foo = "test2"
+ assert row.refresh()
+ assert row.foo == "test1"
+
+
+def test_methods_override(db):
+ row = Override.new(foo="test")
+ row.save()
+ assert row.id
+ assert not row.deleted_at
+
+ row.destroy()
+ assert row.id
+ assert row.deleted_at
+
+ row.force_destroy()
+ assert not row.id
diff --git a/tests/test_session.py b/tests/test_session.py
index 2616a49b..5a4766aa 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -31,13 +31,15 @@ def ctx():
current._close_(token)
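+# the session cookie manager is expected to expose identical attributes
+# under both encryption modes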
+@pytest.mark.parametrize("encryption_mode", ["legacy", "modern"])
@pytest.mark.asyncio
-async def test_session_cookie(ctx):
+async def test_session_cookie(ctx, encryption_mode):
session_cookie = SessionManager.cookies(
key='sid',
secure=True,
domain='localhost',
- cookie_name='foo_session'
+ cookie_name='foo_session',
+ encryption_mode=encryption_mode
)
assert session_cookie.key == 'sid'
assert session_cookie.secure is True