Skip to content

Commit

Permalink
Resolve test flags in python (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
Swatinem authored Nov 29, 2024
1 parent a56f25f commit 2ac43ea
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions tasks/cache_test_rollups.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime as dt
from datetime import date, timezone

import polars as pl
from django.db import connections
from redis.exceptions import LockError
from shared.celery_config import cache_test_rollups_task_name
from shared.config import get_config
from shared.django_apps.reports.models import LastCacheRollupDate
from shared.django_apps.reports.models import LastCacheRollupDate, RepositoryFlag

from app import celery_app
from django_scaffold import settings
Expand Down Expand Up @@ -59,11 +59,9 @@

TEST_FLAGS_SUBQUERY = """
SELECT test_id,
array_agg(DISTINCT flag_name) AS flags
array_agg(DISTINCT flag_id) AS flags
FROM reports_test_results_flag_bridge tfb
JOIN reports_test rt ON rt.id = tfb.test_id
JOIN reports_repositoryflag rr ON tfb.flag_id = rr.id
WHERE rt.repoid = %(repoid)s
GROUP BY test_id
"""

Expand Down Expand Up @@ -124,6 +122,12 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
else:
connection = connections["default"]

repo_flags = dict(
RepositoryFlag.objects.filter(repository_id=repoid)
.values_list("id", "flag_name")
.all()
)

with connection.cursor() as cursor:
for interval_start, interval_end in [
(1, None),
Expand All @@ -143,11 +147,17 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
if interval_end is not None:
query_params["interval_end"] = f"{interval_end} days"

cursor.execute(
base_query,
query_params,
)
aggregation_of_test_results = cursor.fetchall()
cursor.execute(base_query, query_params)
results = cursor.fetchall()

# manually map the flags from IDs to their names
def resolve_flags(result_t: tuple) -> list:
    """Return the result row as a list with its flag IDs swapped for names.

    The element at index 2 of the row is treated as a collection of flag
    IDs. When it is truthy, each ID is looked up in the ``repo_flags``
    mapping from the enclosing scope and the element is replaced by the
    list of looked-up values. A falsy element (``None`` or empty) is
    left exactly as it was.

    Args:
        result_t: One row as returned by the cursor.

    Returns:
        A new list holding the row's values, with index 2 resolved.

    Raises:
        KeyError: If an ID in the row is not present in ``repo_flags``.
    """
    row = list(result_t)
    flag_ids = row[2]
    if not flag_ids:
        # Nothing to resolve; keep the original falsy value untouched.
        return row
    row[2] = [repo_flags[flag_id] for flag_id in flag_ids]
    return row

aggregation_of_test_results = map(resolve_flags, results)

df = pl.DataFrame(
aggregation_of_test_results,
Expand All @@ -158,7 +168,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
"test_id",
"failure_rate",
"flake_rate",
("updated_at", pl.Datetime(time_zone=dt.UTC)),
("updated_at", pl.Datetime(time_zone=timezone.utc)),
"avg_duration",
"total_fail_count",
"total_flaky_fail_count",
Expand All @@ -181,7 +191,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
LastCacheRollupDate.objects.update_or_create(
repository_id=repoid,
branch=branch,
defaults=dict(last_rollup_date=dt.date.today()),
defaults=dict(last_rollup_date=date.today()),
)

serialized_table.seek(0) # avoids Stream must be at beginning errors
Expand Down

0 comments on commit 2ac43ea

Please sign in to comment.