diff --git a/lms/services/assignment.py b/lms/services/assignment.py index 500a839821..0c36eab339 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -269,8 +269,7 @@ def get_assignments( # noqa: PLR0913 if admin_organization_ids: admin_organization_ids_clause = Assignment.id.in_( select(Assignment.id) - .join(AssignmentGrouping) - .join(Grouping) + .join(Course) .join(ApplicationInstance) .join(Organization) .where(Organization.id.in_(admin_organization_ids)) @@ -289,78 +288,27 @@ def get_assignments( # noqa: PLR0913 ) if course_ids: - deduplicated_course_assignments = ( - self._deduplicated_course_assigments_query(course_ids).subquery() - ) - - query = query.where( - # Get only assignment from the candidates above - Assignment.id == deduplicated_course_assignments.c.assignment_id, - deduplicated_course_assignments.c.grouping_id.in_(course_ids), - ) + query = query.where(Assignment.course_id.in_(course_ids)) return query.order_by(Assignment.title, Assignment.id).distinct() - def _deduplicated_course_assigments_query(self, course_ids: list[int]): - # Get all assignment IDs we recorded from this course - raw_course_assignments = select(AssignmentGrouping.assignment_id).where( - AssignmentGrouping.grouping_id.in_(course_ids) - ) - - # Get a list of deduplicated assignments based on raw_course_assignments, - # this will contain assignments that belong (now) to other courses - return ( - select(AssignmentGrouping.assignment_id, AssignmentGrouping.grouping_id) - .distinct(AssignmentGrouping.assignment_id) - .join(Grouping) - .where( - # Only look at courses, otherwise courses and sections will deduplicate each other - Grouping.type == "course", - # Use the previous query to look only at the potential candidates - AssignmentGrouping.assignment_id.in_(raw_course_assignments), - ) - # Deduplicate them based on the updated column, take the last one (together with the distinct clause) - .order_by( - AssignmentGrouping.assignment_id, AssignmentGrouping.updated.desc() - ) - ) - - def get_courses_assignments_count( - self, course_ids: list[int], **kwargs - ) -> dict[int, int]: + def get_courses_assignments_count(self, **kwargs) -> dict[int, int]: """Get the number of assignments a given list of courses has.""" - - assignments_query = ( + query = ( # Query assignments self.get_assignments(**kwargs) - # Only select their IDs - .with_only_columns(Assignment.id) - # Remove any sorting options, we are going to avoid having to worry about sorted columns - .order_by(None) - ) - - # We didn't pass course_ids to get_assigments because we need to deduplicate when we count, not before - deduplicated_course_assignments = self._deduplicated_course_assigments_query( - course_ids - ).subquery() - - counts_query = ( - select( - AssignmentGrouping.grouping_id, - func.count(AssignmentGrouping.assignment_id), + # Change the selected columns + .with_only_columns( + Assignment.course_id, + func.count(Assignment.id), ) - .where( - AssignmentGrouping.assignment_id.in_(assignments_query), - AssignmentGrouping.grouping_id.in_(course_ids), - deduplicated_course_assignments.c.grouping_id - == AssignmentGrouping.grouping_id, - AssignmentGrouping.assignment_id - == deduplicated_course_assignments.c.assignment_id, - ) - .group_by(AssignmentGrouping.grouping_id) + # Remove any sorting options, to avoid having to worry about sorted columns being or not in the select + .order_by(None) + # Group by course to get the counts + .group_by(Assignment.course_id) ) - return {x.grouping_id: x.count for x in self._db.execute(counts_query)} # type: ignore + return {x.course_id: x.count for x in self._db.execute(query)} # type: ignore def factory(_context, request): diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index 8a4e6ca07d..f6dcf339d8 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -38,8 +38,8 @@ def test_update_assignment( misc_plugin, resource_link_title, title, + course, ): - course = factories.Course() pyramid_request.lti_params["resource_link_title"] = resource_link_title misc_plugin.is_speed_grader_launch.return_value = is_speed_grader @@ -51,14 +51,14 @@ def test_update_assignment( course, ) - assignment.title = title - assignment.course_id = course.id if is_speed_grader: assert assignment.extra == {} assert assignment.document_url != sentinel.document_url else: assert assignment.document_url == sentinel.document_url assert assignment.extra["group_set_id"] == sentinel.group_set_id + assert assignment.title == title + assert assignment.course_id == course.id @pytest.mark.parametrize( "param", @@ -102,8 +102,8 @@ def test_get_assignment_for_launch_existing( misc_plugin, get_assignment, _get_copied_from_assignment, + course, ): - course = factories.Course() misc_plugin.get_assignment_configuration.return_value = { "document_url": sentinel.document_url, "group_set_id": sentinel.group_set_id, @@ -130,11 +130,11 @@ def test_get_assignment_for_launch_existing( assert assignment.course_id == course.id def test_get_assignment_returns_None_with_when_no_document( - self, pyramid_request, svc, misc_plugin + self, pyramid_request, svc, misc_plugin, course ): misc_plugin.get_assignment_configuration.return_value = {"document_url": None} - assert not svc.get_assignment_for_launch(pyramid_request, factories.Course()) + assert not svc.get_assignment_for_launch(pyramid_request, course) @pytest.mark.parametrize("group_set_id", [None, "1"]) def test_get_assignment_creates_assignment( @@ -146,8 +146,8 @@ def test_get_assignment_creates_assignment( _get_copied_from_assignment, create_assignment, group_set_id, + course, ): - course = factories.Course() misc_plugin.get_assignment_configuration.return_value = { "document_url": sentinel.document_url, "group_set_id": group_set_id, @@ -176,6 +176,7 @@ def test_get_assignment_created_assignments_point_to_copy( get_assignment, _get_copied_from_assignment, create_assignment, + course, ): misc_plugin.get_assignment_configuration.return_value = { "document_url": sentinel.document_url @@ -183,7 +184,7 @@ def test_get_assignment_created_assignments_point_to_copy( get_assignment.return_value = None _get_copied_from_assignment.return_value = sentinel.original_assignment - assignment = svc.get_assignment_for_launch(pyramid_request, factories.Course()) + assignment = svc.get_assignment_for_launch(pyramid_request, course) _get_copied_from_assignment.assert_called_once_with(pyramid_request.lti_params) create_assignment.assert_called_once_with( @@ -272,10 +273,9 @@ def test_get_assignments( h_userids, assignment_ids, organization, - application_instance, + course, ): factories.User() - course = factories.Course(application_instance=application_instance) user = factories.User() lti_role = factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR) factories.AssignmentMembership.create( @@ -284,7 +284,7 @@ def test_get_assignments( factories.AssignmentMembership.create( assignment=assignment, user=user, lti_role=factories.LTIRole() ) - factories.AssignmentGrouping.create(assignment=assignment, grouping=course) + assignment.course = course db_session.flush() query_parameters = {} @@ -306,8 +306,7 @@ def test_get_assignments( assert db_session.scalars(query).all() == [assignment] - def test_get_assignments_excludes_empty_titles(self, db_session, svc): - course = factories.Course() + def test_get_assignments_excludes_empty_titles(self, db_session, svc, course): assignment = factories.Assignment(title=None) factories.AssignmentGrouping( grouping=course, assignment=assignment, updated=date(2022, 1, 1) @@ -318,55 +317,22 @@ def test_get_assignments_excludes_empty_titles(self, db_session, svc): svc.get_assignments(course_ids=[course.id]) ).all() == [assignment] - def test_get_assignments_by_course_id_with_duplicate( - self, db_session, svc, application_instance, organization - ): - course = factories.Course(application_instance=application_instance) - other_course = factories.Course(application_instance=application_instance) - - assignment = factories.Assignment() - - # other course only has an assignment that `course` has stolen - factories.AssignmentGrouping( - grouping=other_course, assignment=assignment, updated=date(2020, 1, 1) - ) - factories.AssignmentGrouping( - grouping=course, assignment=assignment, updated=date(2022, 1, 1) - ) - db_session.flush() - - assert db_session.scalars( - svc.get_assignments( - course_ids=[course.id], admin_organization_ids=[organization.id] - ) - ).all() == [assignment] - # We don't expect to get the other one at all, now the assignment belongs to the most recent course - assert not db_session.scalars( - svc.get_assignments( - course_ids=[other_course.id], admin_organization_ids=[organization.id] - ) - ).all() - def test_get_courses_assignments_count( - self, svc, db_session, organization, application_instance + self, svc, db_session, organization, course, application_instance ): - course = factories.Course(application_instance=application_instance) + factories.Assignment(course=course) + factories.Assignment(course=course) + factories.Assignment(course=course) + other_course = factories.Course(application_instance=application_instance) - assignment = factories.Assignment() + factories.Assignment(course=other_course) - # other course only has an assignment that `course` has stolen - factories.AssignmentGrouping( - grouping=other_course, assignment=assignment, updated=date(2020, 1, 1) - ) - factories.AssignmentGrouping( - grouping=course, assignment=assignment, updated=date(2022, 1, 1) - ) db_session.flush() assert svc.get_courses_assignments_count( course_ids=[course.id, other_course.id], admin_organization_ids=[organization.id], - ) == {course.id: 1} + ) == {course.id: 3, other_course.id: 1} @pytest.fixture def svc(self, db_session, misc_plugin): @@ -378,6 +344,12 @@ def assignment(self): created=datetime(2000, 1, 1), updated=datetime(2000, 1, 1) ) + @pytest.fixture() + def course(self, db_session, application_instance): + course = factories.Course(application_instance=application_instance) + db_session.flush() + return course + @pytest.fixture def matching_params(self, assignment): return {