Skip to content

Commit

Permalink
test fixes and CR changes for #43 (#46)
Browse files Browse the repository at this point in the history
* test fixes and CR changes

* fix tests
  • Loading branch information
JBKahn authored Sep 29, 2016
1 parent 39076e9 commit 54edb29
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 20 deletions.
20 changes: 12 additions & 8 deletions django_sharding_library/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class ShardedRouter(object):
read or write from.
"""

def get_shard_for_instance(self, instance):
return instance._state.db or instance.get_shard()

def get_shard_for_id_field(self, model, sharded_by_field_id):
try:
return model.get_shard_from_id(sharded_by_field_id)
Expand All @@ -31,9 +34,10 @@ def get_shard_for_postgres_pk_field(self, model, pk_value):
shard_id_to_find = int(bin(pk_value)[-23:-10], 2) # We know where the shard id is stored in the PK's bits.

# We can check the shard id from the PK against the shard ID in the databases config
for alias in settings.DATABASES.keys():
if settings.DATABASES[alias]["SHARD_GROUP"] == group and settings.DATABASES[alias]["SHARD_ID"] == shard_id_to_find:
for alias, db_settings in settings.DATABASES.items():
if db_settings["SHARD_GROUP"] == group and db_settings["SHARD_ID"] == shard_id_to_find:
return alias

return None # Return None if we could not determine the shard so we can fall through to the next shard grab attempt

def get_read_db_routing_strategy(self, shard_group):
Expand Down Expand Up @@ -64,12 +68,12 @@ def _get_shard(self, model, **hints):
)
if sharded_by_field_id:
shard = self.get_shard_for_id_field(model, sharded_by_field_id)
if shard is None and isinstance(getattr(model._meta, 'pk'), PostgresShardGeneratedIDField) and \
(hints.get('exact_lookups', {}).get('pk') is not None or hints.get('exact_lookups', {}).get('id') is not None):
return self.get_shard_for_postgres_pk_field(
model,
hints.get('exact_lookups', {}).get('pk') or hints.get('exact_lookups', {}).get('id')
)

is_pk_postgres_generated_id_field = isinstance(getattr(model._meta, 'pk'), PostgresShardGeneratedIDField)
lookup_pk = hints.get('exact_lookups', {}).get('pk') or hints.get('exact_lookups', {}).get('id')

if shard is None and is_pk_postgres_generated_id_field and lookup_pk is not None:
return self.get_shard_for_postgres_pk_field(model, lookup_pk)

return shard

Expand Down
6 changes: 3 additions & 3 deletions runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@
'name': 'app_shard_003',
'shard_group': 'postgres',
'environment_variable': 'SHARD_003_DATABASE_URL',
'default_database_url': 'postgres://postgres:@localhost/shard_003' if TRAVISCI else 'sqlite://testing128'
'default_database_url': 'sqlite://testing125'
},
{
'name': 'app_shard_004',
'shard_group': 'postgres',
'environment_variable': 'SHARD_004_DATABASE_URL',
'default_database_url': 'postgres://postgres:@localhost/shard_004' if TRAVISCI else 'sqlite://testing129'
'default_database_url': 'sqlite://testing125'
},
]
})
Expand Down Expand Up @@ -94,7 +94,7 @@ def run_tests(*test_args):
test_runner = TestRunner()

location = (os.environ.get('TRAVIS') and "postgres and mysql") or "sqlite"
print("I am running tests on {}".format(location))
print("I am running tests on {}".format(location)) # noqa

failures = test_runner.run_tests(test_args, interactive=False)

Expand Down
7 changes: 7 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,17 @@ def get_shard_key(self):
return self.test.user_pk


@model_config(database='app_shard_001')
class PostgresCustomIDModelBackupField(TableStrategyModel):
pass


@model_config(shard_group="postgres", sharded_by_field="user_pk")
class PostgresCustomIDModel(models.Model):
if settings.DATABASES['default']['ENGINE'] in Backends.POSTGRES:
id = PostgresShardGeneratedIDField(primary_key=True)
else:
id = TableShardedIDField(primary_key=True, source_table=PostgresCustomIDModelBackupField)
random_string = models.CharField(max_length=120)
user_pk = models.PositiveIntegerField()

Expand Down
7 changes: 2 additions & 5 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,6 @@ def test_pre_save_calls_save_shard(self):

class PostgresShardIdFieldTestCase(TestCase):

def setUp(self):
from django.contrib.auth import get_user_model
self.user = PostgresShardUser.objects.create_user(username='username', password='pwassword', email='[email protected]')

@unittest.skipIf(settings.DATABASES['default']['ENGINE'] not in Backends.POSTGRES, "Not a postgres backend")
def test_check_shard_id_function(self):
cursor = connections['default'].cursor()
Expand All @@ -201,7 +197,8 @@ def test_check_shard_id_function(self):

@unittest.skipIf(settings.DATABASES['default']['ENGINE'] not in Backends.POSTGRES, "Not a postgres backend")
def test_check_shard_id_returns_with_model_save(self):
created_model = PostgresCustomIDModel.objects.create(random_string='Test String', user_pk=self.user.id)
user = PostgresShardUser.objects.create_user(username='username', password='pwassword', email='[email protected]')
created_model = PostgresCustomIDModel.objects.create(random_string='Test String', user_pk=user.id)
self.assertTrue(getattr(created_model, 'id'))

# Same as above, lets create an id that would have been made 10 seconds ago and make sure the one that was
Expand Down
4 changes: 2 additions & 2 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_router_gets_hints_correctly_with_positional_arguments_like_Q_in_filter(

with patch.object(ShardedRouter, 'db_for_write', wraps=self.sut.db_for_write) as write_route_function:
with patch.object(ShardedRouter, 'db_for_read', wraps=self.sut.db_for_read) as read_route_function:
list(TestModel.objects.filter(Q(test_string="test") | Q(test_string__isnull=True), user_pk=self.user.pk))
list(TestModel.objects.filter(Q(random_string="test") | Q(random_string__isnull=True), user_pk=self.user.pk))

self.assertEqual(
[call(TestModel, **lookups_to_find), call(get_user_model())],
Expand All @@ -267,7 +267,7 @@ def test_router_gets_hints_correctly_with_positional_arguments_like_Q_in_get(sel

with patch.object(ShardedRouter, 'db_for_write', wraps=self.sut.db_for_write) as write_route_function:
with patch.object(ShardedRouter, 'db_for_read', wraps=self.sut.db_for_read) as read_route_function:
list(TestModel.objects.get(Q(test_string="test") | Q(test_string__isnull=True), user_pk=self.user.pk))
TestModel.objects.get(Q(random_string="test") | Q(random_string__isnull=True), user_pk=self.user.pk)

self.assertEqual(
[call(TestModel, **lookups_to_find), call(get_user_model())],
Expand Down
10 changes: 8 additions & 2 deletions tests/test_travis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ class TravisTestCase(TransactionTestCase):
def test_travis_uses_non_sqlite_databases(self):
TRAVISCI = os.environ.get('TRAVIS')

self.assertEqual(len(settings.DATABASES), 7)

if TRAVISCI:
self.assertIn(settings.DATABASES['default']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_001']['ENGINE'], Backends.MYSQL)
self.assertIn(settings.DATABASES['app_shard_002']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_001_replica_001']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_001_replica_002']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_002']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_003']['ENGINE'], Backends.POSTGRES)
self.assertIn(settings.DATABASES['app_shard_004']['ENGINE'], Backends.POSTGRES)
else:
self.assertIn(settings.DATABASES['default']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_001']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_002']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_001_replica_001']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_001_replica_002']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_002']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_003']['ENGINE'], Backends.SQLITE)
self.assertIn(settings.DATABASES['app_shard_004']['ENGINE'], Backends.SQLITE)

0 comments on commit 54edb29

Please sign in to comment.