diff --git a/docs/release-notes.md b/docs/release-notes.md index 08df9ae9..fdcc6eff 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -15,6 +15,7 @@ hide: ### Fixed - `hash_names` with unique_together incorrectly fed the `uc` prefix into the hasher. This makes the names unidentificatable. +- `BooleanField` logic for automic migrations works by using server_default. ## 0.28.0 diff --git a/edgy/core/db/fields/core.py b/edgy/core/db/fields/core.py index 9884c80d..ee065fa7 100644 --- a/edgy/core/db/fields/core.py +++ b/edgy/core/db/fields/core.py @@ -260,19 +260,25 @@ class BooleanField(FieldFactory, bool_type): def __new__( # type: ignore cls, *, - default: Optional[bool] = False, + default: Union[None, bool, Callable[[], bool]] = False, **kwargs: Any, ) -> BaseFieldType: - kwargs = { - **kwargs, - **{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS}, - } + if default is not None: + kwargs["default"] = default return super().__new__(cls, **kwargs) @classmethod def get_column_type(cls, kwargs: dict[str, Any]) -> Any: return sqlalchemy.Boolean() + @classmethod + def validate(cls, kwargs: dict[str, Any]) -> None: + super().validate(kwargs) + + default = kwargs.get("default") + if default is not None and isinstance(default, bool): + kwargs.setdefault("server_default", sqlalchemy.text("true" if default else "false")) + class DateTimeField(_AutoNowMixin, datetime.datetime): """Representation of a datetime field""" diff --git a/tests/cli/main.py b/tests/cli/main.py index 4e0b7e0c..cb0990ee 100644 --- a/tests/cli/main.py +++ b/tests/cli/main.py @@ -36,6 +36,8 @@ class User(edgy.StrictModel): # simple default active = edgy.fields.BooleanField(server_default=sqlalchemy.text("true"), default=False) profile = edgy.fields.ForeignKey("Profile", null=False, default=complex_default) + # auto server defaults + is_staff = edgy.fields.BooleanField() class Meta: registry = models diff --git a/tests/cli/test_nullable_fields.py b/tests/cli/test_nullable_fields.py index 4718dc3a..45f1406d 100644 --- a/tests/cli/test_nullable_fields.py +++ b/tests/cli/test_nullable_fields.py @@ -201,6 +201,7 @@ async def main(): async with main.models: user = await main.User.query.get(name="edgy") assert user.active + assert not user.is_staff assert user.content_type.name == "User" assert user.profile == await main.Profile.query.get(name="edgy") assert user.profile.content_type.name == "Profile" @@ -211,6 +212,7 @@ async def main(): async with main.models: user = await main.User.query.get(name="edgy") assert user.active + assert not user.is_staff assert user.content_type.name == "User" assert user.profile == await main.Profile.query.get(name="edgy")