Skip to content

Commit

Permalink
Merge pull request #91 from M1ha-Shvn/set-functions-supported-fields-…
Browse files Browse the repository at this point in the history
…as-classes

Ability to add `supported_field_classes` attribute as python path or class object to set function
  • Loading branch information
M1ha-Shvn authored May 4, 2022
2 parents 84b7c47 + b8f4f95 commit 3bc6dff
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 15 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ which contains alias in `names` attribute.
### Custom set function
You can define your own set function, creating `AbstractSetFunction` subclass and implementing:
* `names` attribute
* `supported_field_classes` attribute
* `supported_field_classes` attribute
It can contain class name, class or full python import path of a class.
If a class or path is given, any child class would be accepted too.
* One of:
- `def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs)` method
This method defines new value to set for parameter. It is called from `get_sql(...)` method by default.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

setup(
name='django-pg-bulk-update',
version='3.6.0',
version='3.7.0',
packages=['django_pg_bulk_update'],
package_dir={'': 'src'},
url='https://github.com/M1hacka/django-pg-bulk-update',
Expand Down
64 changes: 51 additions & 13 deletions src/django_pg_bulk_update/set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .compatibility import get_postgres_version, jsonb_available, Postgres94MergeJSONBMigration, hstore_serialize, \
hstore_available, import_pg_field_or_dummy, tz_utc, django_expressions_available
from .types import TDatabase, AbstractFieldFormatter
from .utils import get_subclasses, format_field_value
from .utils import get_subclasses, format_field_value, lazy_import

# When doing increment operations, we need to replace NULL values with something
# This dictionary contains field defaults by it's class name.
Expand Down Expand Up @@ -122,7 +122,10 @@ class AbstractSetFunction(AbstractFieldFormatter):
names = set()

# If set function supports any field class, this should be None.
# Otherwise a set of class names supported
# Otherwise, it must be a set, containing:
# - Field classes (any child classes are accepted too)
# - Field class python import paths (any child classes are accepted too)
# - Filed class names
supported_field_classes = None

# If set functions doesn't need value from input, set this to False.
Expand Down Expand Up @@ -203,7 +206,10 @@ def field_is_supported(self, field): # type: (Field) -> bool
if self.supported_field_classes is None:
return True
else:
return field.__class__.__name__ in self.supported_field_classes
return any(
field.__class__.__name__ == field_cls or lazy_import(field_cls) == field.__class__
for field_cls in self.supported_field_classes
)

def _parse_null_default(self, field, connection, **kwargs):
"""
Expand Down Expand Up @@ -474,10 +480,22 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa
class PlusSetFunction(AbstractSetFunction):
names = {'+', 'incr'}

supported_field_classes = {'IntegerField', 'FloatField', 'AutoField', 'BigAutoField', 'BigIntegerField',
'SmallIntegerField', 'PositiveIntegerField', 'PositiveSmallIntegerField', 'DecimalField',
'IntegerRangeField', 'BigIntegerRangeField', 'FloatRangeField', 'DateTimeRangeField',
'DateRangeField'}
supported_field_classes = {
'django.db.models.IntegerField',
'django.db.models.FloatField',
'django.db.models.AutoField',
'django.db.models.BigAutoField',
'django.db.models.BigIntegerField',
'django.db.models.SmallIntegerField',
'django.db.models.PositiveIntegerField',
'django.db.models.PositiveSmallIntegerField',
'django.db.models.DecimalField',
'django.db.models.IntegerRangeField',
'django.db.models.BigIntegerRangeField',
'django.db.models.FloatRangeField',
'django.db.models.DateTimeRangeField',
'django.db.models.DateRangeField'
}

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
null_default, null_default_params = self._parse_null_default(field, connection, **kwargs)
Expand All @@ -498,9 +516,22 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa
class ConcatSetFunction(AbstractSetFunction):
names = {'||', 'concat'}

supported_field_classes = {'CharField', 'TextField', 'EmailField', 'FilePathField', 'SlugField', 'HStoreField',
'URLField', 'BinaryField', 'JSONField', 'ArrayField', 'CITextField', 'CICharField',
'CIEmailField'}
supported_field_classes = {
'django.db.models.CharField',
'django.db.models.TextField',
'django.db.models.EmailField',
'django.db.models.FilePathField',
'django.db.models.SlugField',
'django.contrib.postgres.fields.HStoreField',
'django.db.models.URLField',
'django.db.models.BinaryField',
'django.db.models.JSONField',
'django.contrib.postgres.fields.JSONField',
'django.contrib.postgres.fields.ArrayField',
'django.db.models.CITextField',
'django.db.models.CICharField',
'django.db.models.CIEmailField'
}

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
null_default, null_default_params = self._parse_null_default(field, connection, **kwargs)
Expand Down Expand Up @@ -531,7 +562,9 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa
class UnionSetFunction(AbstractSetFunction):
names = {'union'}

supported_field_classes = {'ArrayField'}
supported_field_classes = {
'django.contrib.postgres.fields.ArrayField'
}

def get_sql_value(self, field, val, connection, val_as_param=True, with_table=False, for_update=True, **kwargs):
if for_update:
Expand All @@ -548,7 +581,9 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa
class ArrayRemoveSetFunction(AbstractSetFunction):
names = {'array_remove'}

supported_field_classes = {'ArrayField'}
supported_field_classes = {
'django.contrib.postgres.fields.ArrayField'
}

def format_field_value(self, field, val, connection, cast_type=False, **kwargs):
# Support for django 1.8
Expand Down Expand Up @@ -581,7 +616,10 @@ def get_sql_value(self, field, val, connection, val_as_param=True, with_table=Fa

class NowSetFunction(AbstractSetFunction):
names = {'now', 'NOW'}
supported_field_classes = {'DateField', 'DateTimeField'}
supported_field_classes = {
'django.db.models.DateField',
'django.db.models.DateTimeField'
}
needs_value = False

def __init__(self, if_null=False): # type: (bool) -> None
Expand Down
19 changes: 19 additions & 0 deletions src/django_pg_bulk_update/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.core.exceptions import FieldError
from django.db.models import Field
from django.db.models.sql.subqueries import UpdateQuery
from importlib import import_module

from .compatibility import hstore_serialize, hstore_available, get_field_db_type, import_pg_field_or_dummy
from .types import TDatabase
Expand Down Expand Up @@ -185,3 +186,21 @@ def is_auto_set_field(field): # type: (Field) -> bool
:return: Boolean
"""
return getattr(field, 'auto_now', False) or getattr(field, 'auto_now_add', False)


def lazy_import(class_obj): # (str) -> Optional[Any]
"""
Import field class by dot separated import path
:param class_obj: Already imported class or dot separated import path
:return: Class or None if class has not been imported
"""
if not isinstance(class_obj, str):
return class_obj

module_path, cls = class_obj.rsplit('.', 1)
try:
module = import_module(module_path)
except ImportError:
return None

return getattr(module, cls, None)
22 changes: 22 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from django.test import SimpleTestCase

from django_pg_bulk_update.types import FieldDescriptor
from django_pg_bulk_update.utils import lazy_import


class ImportFieldClassTest(SimpleTestCase):
def test_class(self):
cls = lazy_import(FieldDescriptor)
self.assertEqual(FieldDescriptor, cls)

def test_invalid_module(self):
cls = lazy_import('django_pg_bulk_update.invalid_module.InvalidClass')
self.assertIsNone(cls)

def test_invalid_class(self):
cls = lazy_import('django_pg_bulk_update.types.InvalidClass')
self.assertIsNone(cls)

def test_valid(self):
cls = lazy_import('django_pg_bulk_update.types.FieldDescriptor')
self.assertEqual(FieldDescriptor, cls)

0 comments on commit 3bc6dff

Please sign in to comment.