Skip to content

Commit

Permalink
fix: Fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Jul 18, 2024
1 parent 82f92d6 commit d0fcf30
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
6 changes: 4 additions & 2 deletions strawberry_django/integrations/guardian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import dataclasses
import weakref
from typing import Optional, Union, cast
from typing import Optional, Type, Union, cast

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
Expand All @@ -24,7 +24,9 @@ class ObjectPermissionModels:
group: GroupObjectPermissionBase


def get_object_permission_models(model: models.Model):
def get_object_permission_models(
model: Union[models.Model, Type[models.Model]],
) -> ObjectPermissionModels:
return ObjectPermissionModels(
user=cast(UserObjectPermissionBase, get_user_obj_perms_model(model)),
group=cast(GroupObjectPermissionBase, get_group_obj_perms_model(model)),
Expand Down
23 changes: 18 additions & 5 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Iterable,
List,
Type,
TypeVar,
cast,
overload,
Expand Down Expand Up @@ -262,7 +263,10 @@ def prepare_create_update(
(ParsedObject, str),
):
value, value_data = _parse_data( # noqa: PLW2901
info, field.related_model, value, key_attr=key_attr
info,
cast(Type[Model], field.related_model),
value,
key_attr=key_attr,
)
if value is None and not value_data:
value = None # noqa: PLW2901
Expand Down Expand Up @@ -508,7 +512,7 @@ def update_field(info: Info, instance: Model, field: models.Field, value: Any):
and isinstance(field, models.ForeignObject)
and not isinstance(value, Model)
):
value, data = _parse_pk(value, field.related_model)
value, data = _parse_pk(value, cast(Type[Model], field.related_model))

field.save_form_data(instance, value)
# If data was passed to the foreign key, update it recursively
Expand Down Expand Up @@ -574,7 +578,9 @@ def update_m2m(
existing = set(manager.all())
need_remove_cache = need_remove_cache or bool(values)
for v in values:
obj, data = _parse_data(info, manager.model, v, key_attr=key_attr)
obj, data = _parse_data(
info, cast(Type[Model], manager.model), v, key_attr=key_attr
)
if obj:
data.pop(key_attr, None)
through_defaults = data.pop("through_defaults", {})
Expand Down Expand Up @@ -632,7 +638,12 @@ def update_m2m(
else:
need_remove_cache = need_remove_cache or bool(value.add)
for v in value.add or []:
obj, data = _parse_data(info, manager.model, v, key_attr=key_attr)
obj, data = _parse_data(
info,
cast(Type[Model], manager.model),
v,
key_attr=key_attr,
)
if obj and data:
data.pop(key_attr, None)
if full_clean:
Expand All @@ -653,7 +664,9 @@ def update_m2m(

need_remove_cache = need_remove_cache or bool(value.remove)
for v in value.remove or []:
obj, data = _parse_data(info, manager.model, v, key_attr=key_attr)
obj, data = _parse_data(
info, cast(Type[Model], manager.model), v, key_attr=key_attr
)
data.pop(key_attr, None)
assert not data
to_remove.append(obj)
Expand Down

0 comments on commit d0fcf30

Please sign in to comment.