diff --git a/edx_exams/apps/core/management/commands/bulk_add_course_staff.py b/edx_exams/apps/core/management/commands/bulk_add_course_staff.py index 8edc1778..f6960832 100644 --- a/edx_exams/apps/core/management/commands/bulk_add_course_staff.py +++ b/edx_exams/apps/core/management/commands/bulk_add_course_staff.py @@ -4,7 +4,7 @@ import time from django.core.management.base import BaseCommand -from django.db import transaction +from django.db import IntegrityError, transaction from edx_exams.apps.core.models import CourseStaffRole, User @@ -63,34 +63,31 @@ def add_course_staff_from_csv(self, csv_file, batch_size, batch_delay): Add the given set of course staff provided in csv """ reader = list(csv.DictReader(csv_file)) + users = {} - # bulk create users for i in range(0, len(reader), batch_size): - User.objects.bulk_create( - (User( - username=row.get('username'), - email=row.get('email'), - ) for row in reader[i:i + batch_size]), - ignore_conflicts=True, - ) - CourseStaffRole.objects.bulk_create( - (CourseStaffRole( - user=User.objects.get(username=row.get('username')), - course_id=row.get('course_id'), - role=row.get('role'), - ) for row in reader[i:i + batch_size]), - ignore_conflicts=True, - ) + users_list = [] + for row in reader[i:i + batch_size]: + username = row.get('username') + email = row.get('email') + try: + users_list.append(User.objects.get_or_create(username=row.get('username'), email=row.get('email'))) + except IntegrityError: + logger.warning( + f'User with username={username} and email={email} was not created due to an existing duplicate ' + f'user with username.' + ) + continue + users_dict = {(u.username, u) for (u, c) in users_list} + users.update(users_dict) time.sleep(batch_delay) - # bulk create course staff - # for i in range(0, len(reader), batch_size): - # CourseStaffRole.objects.bulk_create( - # (CourseStaffRole( - # user=User.objects.get(username=row.get('username')), - # course_id=row.get('course_id'), - # role=row.get('role'), - # ) for row in reader[i:i + batch_size]), - # ignore_conflicts=True, - # ) - # time.sleep(batch_delay) + CourseStaffRole.objects.bulk_create( + (CourseStaffRole( + user=users.get(row.get('username')), + course_id=row.get('course_id'), + role=row.get('role'), + ) for row in reader), + ignore_conflicts=True, + batch_size=batch_size, + ) diff --git a/edx_exams/apps/core/management/commands/test/test_bulk_add_course_staff.py b/edx_exams/apps/core/management/commands/test/test_bulk_add_course_staff.py index 023974df..edf3b8b5 100644 --- a/edx_exams/apps/core/management/commands/test/test_bulk_add_course_staff.py +++ b/edx_exams/apps/core/management/commands/test/test_bulk_add_course_staff.py @@ -90,7 +90,7 @@ def test_add_course_staff_with_not_default_batch_size(self): 'sam,sam@pond.com,staff,course-v1:edx+test+f20\n'] with NamedTemporaryFile() as csv: csv = self._write_test_csv(csv, lines) - with self.assertNumQueries(8): + with self.assertNumQueries(12): call_command(self.command, f'--csv_path={csv.name}', '--batch_size=1') def test_add_course_staff_with_batch_size_larger_than_list(self): @@ -99,7 +99,7 @@ def test_add_course_staff_with_batch_size_larger_than_list(self): 'sam,sam@pond.com,staff,course-v1:edx+test+f20\n'] with NamedTemporaryFile() as csv: csv = self._write_test_csv(csv, lines) - with self.assertNumQueries(6): + with self.assertNumQueries(11): call_command(self.command, f'--csv_path={csv.name}', '--batch_size=3') def test_add_course_staff_with_batch_size_smaller_than_list(self): @@ -109,7 +109,7 @@ def test_add_course_staff_with_batch_size_smaller_than_list(self): 'tam,tam@pond.com,staff,course-v1:edx+test+f20\n'] with NamedTemporaryFile() as csv: csv = self._write_test_csv(csv, lines) - with self.assertNumQueries(9): + with self.assertNumQueries(16): call_command(self.command, f'--csv_path={csv.name}', '--batch_size=2') def test_add_course_staff_with_not_default_batch_delay(self): @@ -125,16 +125,16 @@ def test_add_course_staff_with_not_default_batch_delay(self): def test_num_queries_correct(self): """ - Assert the number of queries to be 4 + 1 * number of lines: + Assert the number of queries to be 2 + 1 * number of lines: 2 for savepoint/release savepoint - 1 to bulk create users, 1 to bulk create course role - 1 for each user (to get user) + 1 to bulk create course role + 4 for each user (to get user, and savepoints) """ num_lines = 20 lines = [f'pam{i},pam{i}@pond.com,staff,course-v1:edx+test+f20\n' for i in range(num_lines)] with NamedTemporaryFile() as csv: csv = self._write_test_csv(csv, lines) - with self.assertNumQueries(4 + num_lines): + with self.assertNumQueries(3 + 4 * num_lines): call_command(self.command, f'--csv_path={csv.name}') def test_dupe_user_csv(self):