Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve test flags in python #898

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions tasks/cache_test_rollups.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime as dt
from datetime import date, timezone

import polars as pl
from django.db import connections
from redis.exceptions import LockError
from shared.celery_config import cache_test_rollups_task_name
from shared.config import get_config
from shared.django_apps.reports.models import LastCacheRollupDate
from shared.django_apps.reports.models import LastCacheRollupDate, RepositoryFlag

from app import celery_app
from django_scaffold import settings
Expand Down Expand Up @@ -59,11 +59,9 @@

TEST_FLAGS_SUBQUERY = """
SELECT test_id,
array_agg(DISTINCT flag_name) AS flags
array_agg(DISTINCT flag_id) AS flags
FROM reports_test_results_flag_bridge tfb
JOIN reports_test rt ON rt.id = tfb.test_id
JOIN reports_repositoryflag rr ON tfb.flag_id = rr.id
WHERE rt.repoid = %(repoid)s
GROUP BY test_id
"""

Expand Down Expand Up @@ -124,6 +122,12 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
else:
connection = connections["default"]

repo_flags = dict(
RepositoryFlag.objects.filter(repository_id=repoid)
Swatinem marked this conversation as resolved.
Show resolved Hide resolved
.values_list("id", "flag_name")
.all()
)

with connection.cursor() as cursor:
for interval_start, interval_end in [
(1, None),
Expand All @@ -143,11 +147,17 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
if interval_end is not None:
query_params["interval_end"] = f"{interval_end} days"

cursor.execute(
base_query,
query_params,
)
aggregation_of_test_results = cursor.fetchall()
cursor.execute(base_query, query_params)
results = cursor.fetchall()

# Translate the aggregated flag IDs of a result row into flag names.
def resolve_flags(result_t: tuple) -> list:
    """Return ``result_t`` as a list with the flag IDs at index 2 replaced by names.

    Names are looked up in the enclosing ``repo_flags`` mapping, which is
    keyed by flag ID. A falsy value at index 2 (``None`` or an empty
    collection) is left untouched. An ID missing from ``repo_flags``
    raises ``KeyError``.
    """
    row = list(result_t)
    flag_ids = row[2]
    if not flag_ids:
        # Nothing to translate for rows without flags.
        return row
    row[2] = [repo_flags[flag_id] for flag_id in flag_ids]
    return row

aggregation_of_test_results = map(resolve_flags, results)

df = pl.DataFrame(
aggregation_of_test_results,
Expand All @@ -158,7 +168,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
"test_id",
"failure_rate",
"flake_rate",
("updated_at", pl.Datetime(time_zone=dt.UTC)),
("updated_at", pl.Datetime(time_zone=timezone.utc)),
"avg_duration",
"total_fail_count",
"total_flaky_fail_count",
Expand All @@ -181,7 +191,7 @@ def run_impl_within_lock(self, repoid: int, branch: str, update_date: bool = Tru
LastCacheRollupDate.objects.update_or_create(
repository_id=repoid,
branch=branch,
defaults=dict(last_rollup_date=dt.date.today()),
defaults=dict(last_rollup_date=date.today()),
)

serialized_table.seek(0) # avoids Stream must be at beginning errors
Expand Down
Loading