Skip to content

Commit

Permalink
Simplify assignment queries to use assignment.course_id
Browse files Browse the repository at this point in the history
We no longer have to de-duplicate AssignmentGrouping as this is now
denormalized in Assignment.course_id
  • Loading branch information
marcospri committed Aug 7, 2024
1 parent dab1822 commit 05be2b2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 118 deletions.
78 changes: 13 additions & 65 deletions lms/services/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ def get_assignments( # noqa: PLR0913
if admin_organization_ids:
admin_organization_ids_clause = Assignment.id.in_(
select(Assignment.id)
.join(AssignmentGrouping)
.join(Grouping)
.join(Course)
.join(ApplicationInstance)
.join(Organization)
.where(Organization.id.in_(admin_organization_ids))
Expand All @@ -289,78 +288,27 @@ def get_assignments( # noqa: PLR0913
)

if course_ids:
deduplicated_course_assignments = (
self._deduplicated_course_assigments_query(course_ids).subquery()
)

query = query.where(
# Get only assignment from the candidates above
Assignment.id == deduplicated_course_assignments.c.assignment_id,
deduplicated_course_assignments.c.grouping_id.in_(course_ids),
)
query = query.where(Assignment.course_id.in_(course_ids))

return query.order_by(Assignment.title, Assignment.id).distinct()

def _deduplicated_course_assigments_query(self, course_ids: list[int]):
# Get all assignment IDs we recorded from this course
raw_course_assignments = select(AssignmentGrouping.assignment_id).where(
AssignmentGrouping.grouping_id.in_(course_ids)
)

# Get a list of deduplicated assignments based on raw_course_assignments,
# this will contain assignments that belong (now) to other courses
return (
select(AssignmentGrouping.assignment_id, AssignmentGrouping.grouping_id)
.distinct(AssignmentGrouping.assignment_id)
.join(Grouping)
.where(
# Only look at courses, otherwise courses and sections will deduplicate each other
Grouping.type == "course",
# Use the previous query to look only at the potential candidates
AssignmentGrouping.assignment_id.in_(raw_course_assignments),
)
# Deduplicate them based on the updated column, take the last one (together with the distinct clause)
.order_by(
AssignmentGrouping.assignment_id, AssignmentGrouping.updated.desc()
)
)

def get_courses_assignments_count(
self, course_ids: list[int], **kwargs
) -> dict[int, int]:
def get_courses_assignments_count(self, **kwargs) -> dict[int, int]:
"""Get the number of assignments a given list of courses has."""

assignments_query = (
query = (
# Query assignments
self.get_assignments(**kwargs)
# Only select their IDs
.with_only_columns(Assignment.id)
# Remove any sorting options, we are going to avoid having to worry about sorted columns
.order_by(None)
)

# We didn't pass course_ids to get_assigments because we need to deduplicate when we count, not before
deduplicated_course_assignments = self._deduplicated_course_assigments_query(
course_ids
).subquery()

counts_query = (
select(
AssignmentGrouping.grouping_id,
func.count(AssignmentGrouping.assignment_id),
# Change the selected columns
.with_only_columns(
Assignment.course_id,
func.count(Assignment.id),
)
.where(
AssignmentGrouping.assignment_id.in_(assignments_query),
AssignmentGrouping.grouping_id.in_(course_ids),
deduplicated_course_assignments.c.grouping_id
== AssignmentGrouping.grouping_id,
AssignmentGrouping.assignment_id
== deduplicated_course_assignments.c.assignment_id,
)
.group_by(AssignmentGrouping.grouping_id)
# Remove any sorting options, to avoid having to worry about sorted columns being or not in the select
.order_by(None)
# Group by course to get the counts
.group_by(Assignment.course_id)
)

return {x.grouping_id: x.count for x in self._db.execute(counts_query)} # type: ignore
return {x.course_id: x.count for x in self._db.execute(query)} # type: ignore


def factory(_context, request):
Expand Down
78 changes: 25 additions & 53 deletions tests/unit/lms/services/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_update_assignment(
misc_plugin,
resource_link_title,
title,
course,
):
course = factories.Course()
pyramid_request.lti_params["resource_link_title"] = resource_link_title
misc_plugin.is_speed_grader_launch.return_value = is_speed_grader

Expand All @@ -51,14 +51,14 @@ def test_update_assignment(
course,
)

assignment.title = title
assignment.course_id = course.id
if is_speed_grader:
assert assignment.extra == {}
assert assignment.document_url != sentinel.document_url
else:
assert assignment.document_url == sentinel.document_url
assert assignment.extra["group_set_id"] == sentinel.group_set_id
assert assignment.title == title
assert assignment.course_id == course.id

@pytest.mark.parametrize(
"param",
Expand Down Expand Up @@ -102,8 +102,8 @@ def test_get_assignment_for_launch_existing(
misc_plugin,
get_assignment,
_get_copied_from_assignment,
course,
):
course = factories.Course()
misc_plugin.get_assignment_configuration.return_value = {
"document_url": sentinel.document_url,
"group_set_id": sentinel.group_set_id,
Expand All @@ -130,11 +130,11 @@ def test_get_assignment_for_launch_existing(
assert assignment.course_id == course.id

def test_get_assignment_returns_None_with_when_no_document(
self, pyramid_request, svc, misc_plugin
self, pyramid_request, svc, misc_plugin, course
):
misc_plugin.get_assignment_configuration.return_value = {"document_url": None}

assert not svc.get_assignment_for_launch(pyramid_request, factories.Course())
assert not svc.get_assignment_for_launch(pyramid_request, course)

@pytest.mark.parametrize("group_set_id", [None, "1"])
def test_get_assignment_creates_assignment(
Expand All @@ -146,8 +146,8 @@ def test_get_assignment_creates_assignment(
_get_copied_from_assignment,
create_assignment,
group_set_id,
course,
):
course = factories.Course()
misc_plugin.get_assignment_configuration.return_value = {
"document_url": sentinel.document_url,
"group_set_id": group_set_id,
Expand Down Expand Up @@ -176,14 +176,15 @@ def test_get_assignment_created_assignments_point_to_copy(
get_assignment,
_get_copied_from_assignment,
create_assignment,
course,
):
misc_plugin.get_assignment_configuration.return_value = {
"document_url": sentinel.document_url
}
get_assignment.return_value = None
_get_copied_from_assignment.return_value = sentinel.original_assignment

assignment = svc.get_assignment_for_launch(pyramid_request, factories.Course())
assignment = svc.get_assignment_for_launch(pyramid_request, course)

_get_copied_from_assignment.assert_called_once_with(pyramid_request.lti_params)
create_assignment.assert_called_once_with(
Expand Down Expand Up @@ -272,10 +273,9 @@ def test_get_assignments(
h_userids,
assignment_ids,
organization,
application_instance,
course,
):
factories.User()
course = factories.Course(application_instance=application_instance)
user = factories.User()
lti_role = factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR)
factories.AssignmentMembership.create(
Expand All @@ -284,7 +284,7 @@ def test_get_assignments(
factories.AssignmentMembership.create(
assignment=assignment, user=user, lti_role=factories.LTIRole()
)
factories.AssignmentGrouping.create(assignment=assignment, grouping=course)
assignment.course = course
db_session.flush()

query_parameters = {}
Expand All @@ -306,8 +306,7 @@ def test_get_assignments(

assert db_session.scalars(query).all() == [assignment]

def test_get_assignments_excludes_empty_titles(self, db_session, svc):
course = factories.Course()
def test_get_assignments_excludes_empty_titles(self, db_session, svc, course):
assignment = factories.Assignment(title=None)
factories.AssignmentGrouping(
grouping=course, assignment=assignment, updated=date(2022, 1, 1)
Expand All @@ -318,55 +317,22 @@ def test_get_assignments_excludes_empty_titles(self, db_session, svc):
svc.get_assignments(course_ids=[course.id])
).all() == [assignment]

def test_get_assignments_by_course_id_with_duplicate(
self, db_session, svc, application_instance, organization
):
course = factories.Course(application_instance=application_instance)
other_course = factories.Course(application_instance=application_instance)

assignment = factories.Assignment()

# other course only has an assignment that `course` has stolen
factories.AssignmentGrouping(
grouping=other_course, assignment=assignment, updated=date(2020, 1, 1)
)
factories.AssignmentGrouping(
grouping=course, assignment=assignment, updated=date(2022, 1, 1)
)
db_session.flush()

assert db_session.scalars(
svc.get_assignments(
course_ids=[course.id], admin_organization_ids=[organization.id]
)
).all() == [assignment]
# We don't expect to get the other one at all, now the assignment belongs to the most recent course
assert not db_session.scalars(
svc.get_assignments(
course_ids=[other_course.id], admin_organization_ids=[organization.id]
)
).all()

def test_get_courses_assignments_count(
self, svc, db_session, organization, application_instance
self, svc, db_session, organization, course, application_instance
):
course = factories.Course(application_instance=application_instance)
factories.Assignment(course=course)
factories.Assignment(course=course)
factories.Assignment(course=course)

other_course = factories.Course(application_instance=application_instance)
assignment = factories.Assignment()
factories.Assignment(course=other_course)

# other course only has an assignment that `course` has stolen
factories.AssignmentGrouping(
grouping=other_course, assignment=assignment, updated=date(2020, 1, 1)
)
factories.AssignmentGrouping(
grouping=course, assignment=assignment, updated=date(2022, 1, 1)
)
db_session.flush()

assert svc.get_courses_assignments_count(
course_ids=[course.id, other_course.id],
admin_organization_ids=[organization.id],
) == {course.id: 1}
) == {course.id: 3, other_course.id: 1}

@pytest.fixture
def svc(self, db_session, misc_plugin):
Expand All @@ -378,6 +344,12 @@ def assignment(self):
created=datetime(2000, 1, 1), updated=datetime(2000, 1, 1)
)

@pytest.fixture()
def course(self, db_session, application_instance):
course = factories.Course(application_instance=application_instance)
db_session.flush()
return course

@pytest.fixture
def matching_params(self, assignment):
return {
Expand Down

0 comments on commit 05be2b2

Please sign in to comment.