Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: async default value #1498

Merged
merged 3 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Added
- Enhancement for FastAPI lifespan support (#1371)
- Add __eq__ method to Q to more easily test dynamically-built queries (#1506)
- Added PlainToTsQuery function for postgres (#1347)
- Allow field's default keyword to be async function (#1498)

Fixed
^^^^^
Expand Down
21 changes: 21 additions & 0 deletions tests/test_callable_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tests import testmodels
from tortoise.contrib import test


class TestCallableDefault(test.TestCase):
async def test_default_create(self):
model = await testmodels.CallableDefault.create()
self.assertEqual(model.callable_default, "callable_default")
self.assertEqual(model.async_default, "async_callable_default")

async def test_default_by_save(self):
saved_model = testmodels.CallableDefault()
await saved_model.save()
self.assertEqual(saved_model.callable_default, "callable_default")
self.assertEqual(saved_model.async_default, "async_callable_default")

async def test_async_default_change(self):
default_change = testmodels.CallableDefault()
default_change.async_default = "changed"
await default_change.save()
self.assertEqual(default_change.async_default, "changed")
14 changes: 14 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,17 @@ class PydanticMeta:
alias_generator=camelize_var,
populate_by_name=True,
)


def callable_default() -> str:
return "callable_default"


async def async_callable_default() -> str:
return "async_callable_default"


class CallableDefault(Model):
id = fields.IntField(pk=True)
callable_default = fields.CharField(max_length=32, default=callable_default)
async_default = fields.CharField(max_length=32, default=async_callable_default)
23 changes: 21 additions & 2 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,15 +667,25 @@ def __init__(self, **kwargs: Any) -> None:
self._partial = False
self._saved_in_db = False
self._custom_generated_pk = False
self._await_when_save: Dict[str, Callable[[], Awaitable[Any]]] = {}

# Assign defaults for missing fields
for key in meta.fields.difference(self._set_kwargs(kwargs)):
field_object = meta.fields_map[key]
if callable(field_object.default):
setattr(self, key, field_object.default())
field_default = field_object.default
if inspect.iscoroutinefunction(field_default):
self._await_when_save[key] = field_default
elif callable(field_default):
setattr(self, key, field_default())
else:
setattr(self, key, deepcopy(field_object.default))

def __setattr__(self, key, value):
# set field value override async default function
if hasattr(self, "_await_when_save"):
self._await_when_save.pop(key, None)
super().__setattr__(key, value)

def _set_kwargs(self, kwargs: dict) -> Set[str]:
meta = self._meta

Expand Down Expand Up @@ -719,6 +729,7 @@ def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL:
self._partial = False
self._saved_in_db = True
self._custom_generated_pk = self._meta.db_pk_column not in self._meta.generated_db_fields
self._await_when_save = {}

meta = self._meta

Expand Down Expand Up @@ -845,6 +856,13 @@ def register_listener(cls, signal: Signals, listener: Callable):
if listener not in cls_listeners:
cls_listeners.append(listener)

async def _set_async_default_field(self) -> None:
"""retrieve value from field's async default value"""
if hasattr(self, "_await_when_save"):
for k, v in self._await_when_save.copy().items():
setattr(self, k, await v())
self._await_when_save = {}

async def _pre_delete(
self,
using_db: Optional[BaseDBAsyncClient] = None,
Expand Down Expand Up @@ -921,6 +939,7 @@ async def save(
:raises IncompleteInstanceError: If the model is partial and the fields are not available for persistence.
:raises IntegrityError: If the model can't be created or updated (specifically if force_create or force_update has been set)
"""
await self._set_async_default_field()
db = using_db or self._choose_db(True)
executor = db.executor_class(model=self.__class__, db=db)
if self._partial:
Expand Down
Loading