Skip to content

Commit

Permalink
Added tests on issue-42 (#45)
Browse files Browse the repository at this point in the history
Fixed issue #42
  • Loading branch information
M1ha-Shvn authored Oct 19, 2022
1 parent b2cb098 commit fc362e8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/django_clickhouse/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _update_returning_param(self, returning):
if returning is None:
returning = pk_name
elif isinstance(returning, str):
returning = [pk_name, returning]
returning = [pk_name, returning] if returning != '*' else '*'
else:
returning = list(returning) + [pk_name]

Expand Down
50 changes: 50 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ def test_pg_bulk_create(self):
self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_pg_bulk_create_returning(self):
now_dt = now()
res = self.django_model.objects.pg_bulk_create([
{'value': i, 'created': now_dt, 'created_date': now_dt.date()}
for i in range(5)
], returning='*')

self.assertEqual(5, len(res))
for i, instance in enumerate(res):
self.assertEqual(instance.created, now_dt)
self.assertEqual(instance.created_date, now_dt.date())
self.assertEqual(i, instance.value)

self.assertSetEqual({('insert', "%s.%d" % (self.db_alias, instance.pk)) for instance in res},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_pg_bulk_update(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))

Expand All @@ -115,6 +131,21 @@ def test_pg_bulk_update(self):
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_pg_bulk_update_returning(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))

res = self.django_model.objects.pg_bulk_update([
{'id': instance.pk, 'value': instance.pk * 10}
for instance in items
], returning='*')

self.assertEqual(2, len(res))
for instance in res:
self.assertEqual(instance.value, instance.pk * 10)

self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_pg_bulk_update_or_create(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))

Expand All @@ -135,6 +166,25 @@ def test_pg_bulk_update_or_create(self):
self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in items},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_pg_bulk_update_or_create_returning(self):
items = list(self.django_model.objects.filter(pk__in={1, 2}))

data = [{
'id': instance.pk,
'value': instance.pk * 10,
'created_date': instance.created_date,
'created': instance.created
} for instance in items] + [{'id': 11, 'value': 110, 'created_date': datetime.date.today(), 'created': now()}]

res = self.django_model.objects.pg_bulk_update_or_create(data, returning='*')

self.assertEqual(3, len(res))
for instance in res:
self.assertEqual(instance.value, instance.pk * 10)

self.assertSetEqual({('update', "%s.%d" % (self.db_alias, instance.pk)) for instance in res},
set(self.storage.get_operations(self.clickhouse_model.get_import_key(), 10)))

def test_get_or_create(self):
instance, created = self.django_model.objects. \
get_or_create(pk=100, defaults={'created_date': datetime.date.today(), 'created': datetime.datetime.now(),
Expand Down

0 comments on commit fc362e8

Please sign in to comment.