diff --git a/tasks/cache_test_rollups.py b/tasks/cache_test_rollups.py index e1d5daa6a..7005ccae6 100644 --- a/tasks/cache_test_rollups.py +++ b/tasks/cache_test_rollups.py @@ -1,11 +1,11 @@ -import datetime as dt +from datetime import date, timezone import polars as pl from django.db import connections from redis.exceptions import LockError from shared.celery_config import cache_test_rollups_task_name from shared.config import get_config -from shared.django_apps.reports.models import LastCacheRollupDate +from shared.django_apps.reports.models import LastCacheRollupDate, RepositoryFlag from app import celery_app from django_scaffold import settings @@ -59,11 +59,9 @@ TEST_FLAGS_SUBQUERY = """ SELECT test_id, - array_agg(DISTINCT flag_name) AS flags + array_agg(DISTINCT flag_id) AS flags FROM reports_test_results_flag_bridge tfb JOIN reports_test rt ON rt.id = tfb.test_id -JOIN reports_repositoryflag rr ON tfb.flag_id = rr.id -WHERE rt.repoid = %(repoid)s GROUP BY test_id """ @@ -124,6 +122,12 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru else: connection = connections["default"] + repo_flags = dict( + RepositoryFlag.objects.filter(repository_id=repoid) + .values_list("id", "flag_name") + .all() + ) + with connection.cursor() as cursor: for interval_start, interval_end in [ (1, None), @@ -143,11 +147,17 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru if interval_end is not None: query_params["interval_end"] = f"{interval_end} days" - cursor.execute( - base_query, - query_params, - ) - aggregation_of_test_results = cursor.fetchall() + cursor.execute(base_query, query_params) + results = cursor.fetchall() + + # manually map the flags from IDs to their names + def resolve_flags(result_t: tuple) -> list: + result = list(result_t) + if result[2]: + result[2] = [repo_flags[flag_id] for flag_id in result[2]] + return result + + aggregation_of_test_results = map(resolve_flags, results) df = pl.DataFrame( aggregation_of_test_results, @@ -158,7 +168,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru "test_id", "failure_rate", "flake_rate", - ("updated_at", pl.Datetime(time_zone=dt.UTC)), + ("updated_at", pl.Datetime(time_zone=timezone.utc)), "avg_duration", "total_fail_count", "total_flaky_fail_count", @@ -181,7 +191,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru LastCacheRollupDate.objects.update_or_create( repository_id=repoid, branch=branch, - defaults=dict(last_rollup_date=dt.date.today()), + defaults=dict(last_rollup_date=date.today()), ) serialized_table.seek(0) # avoids Stream must be at beginning errors