From 3d204061e1fdee59505b390ddd0be1bcfc839a91 Mon Sep 17 00:00:00 2001
From: Arpad Borsos
Date: Thu, 31 Oct 2024 15:46:33 +0100
Subject: [PATCH] Only update the `Upload` once in `UploadFinisher`

This avoids updating the `Upload` state in the first `UploadProcessor` job,
deferring that work to the final `UploadFinisher`, which was updating the
same records anyway.

This way, it is also possible to run the `UPDATE`, as well as all the
`INSERT`s for `UploadError` and `UploadLevelTotals` (the latter of which we
should eventually remove), in bulk.
---
 services/processing/merging.py                | 92 ++++++++++++++++---
 services/processing/processing.py             |  2 +-
 services/processing/types.py                  |  4 +-
 services/report/__init__.py                   | 68 ++------------
 services/tests/test_report.py                 | 47 +---------
 tasks/tests/unit/test_upload_finisher_task.py | 45 +++++++++
 .../tests/unit/test_upload_processing_task.py | 31 +------
 tasks/upload_finisher.py                      | 30 +++---
 tasks/upload_processor.py                     | 11 +--
 9 files changed, 156 insertions(+), 174 deletions(-)

diff --git a/services/processing/merging.py b/services/processing/merging.py
index d4c29ff82..ce7c5c632 100644
--- a/services/processing/merging.py
+++ b/services/processing/merging.py
@@ -1,15 +1,20 @@
+import functools
+from decimal import Decimal
+
 import sentry_sdk
 from shared.reports.editable import EditableReport, EditableReportFile
 from shared.reports.enums import UploadState
-from shared.reports.resources import Report
+from shared.reports.resources import Report, ReportTotals
 from shared.yaml import UserYaml
 from sqlalchemy.orm import Session as DbSession
 
-from database.models.reports import Upload
+from database.models.reports import Upload, UploadError, UploadLevelTotals
+from helpers.number import precise_round
 from services.report import delete_uploads_by_sessionid
 from services.report.raw_upload_processor import clear_carryforward_sessions
+from services.yaml.reader import read_yaml_field
 
-from .types import IntermediateReport, MergeResult
+from .types import IntermediateReport, MergeResult, ProcessingResult
 
 
 @sentry_sdk.trace
@@ -49,7 +54,13 @@ def merge_reports(
 
 
 @sentry_sdk.trace
-def update_uploads(db_session: DbSession, merge_result: MergeResult):
+def update_uploads(
+    db_session: DbSession,
+    commit_yaml: UserYaml,
+    processing_results: list[ProcessingResult],
+    intermediate_reports: list[IntermediateReport],
+    merge_result: MergeResult,
+):
     """
     Updates all the `Upload` records with the `MergeResult`.
     In particular, this updates the `order_number` to match the new `session_id`,
@@ -69,17 +80,74 @@ def update_uploads(db_session: DbSession, merge_result: MergeResult):
             db_session, report_id, merge_result.deleted_sessions
         )
 
-    # then, update all the sessions that have been merged
-    for upload_id, session_id in merge_result.session_mapping.items():
-        update = {
-            Upload.state_id: UploadState.PROCESSED.db_id,
-            Upload.state: "processed",
-            Upload.order_number: session_id,
-        }
-        db_session.query(Upload).filter(Upload.id_ == upload_id).update(update)
+    precision: int = read_yaml_field(commit_yaml, ("coverage", "precision"), 2)
+    rounding: str = read_yaml_field(commit_yaml, ("coverage", "round"), "nearest")
+    make_totals = functools.partial(make_upload_totals, precision, rounding)
+
+    reports = {ir.upload_id: ir.report for ir in intermediate_reports}
+
+    # then, update all the `Upload`s with their state and the final `order_number`,
+    # as well as adding `UploadLevelTotals` or `UploadError`s where appropriate.
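+    # Rather than issuing one `UPDATE` per `Upload`, the loop below collects
+    # plain dicts and writes them out in a single bulk operation per table.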
+
+    all_errors: list[UploadError] = []
+    all_totals: list[UploadLevelTotals] = []
+    all_upload_updates: list[dict] = []
+    for result in processing_results:
+        upload_id = result["upload_id"]
+
+        if result["successful"]:
+            update = {
+                "state_id": UploadState.PROCESSED.db_id,
+                "state": "processed",
+            }
+            report = reports.get(upload_id)
+            if report is not None:
+                all_totals.append(make_totals(upload_id, report.totals))
+        else:
+            update = {
+                "state_id": UploadState.ERROR.db_id,
+                "state": "error",
+            }
+            # `error` is `NotRequired`, so guard the lookup instead of indexing
+            if error := result.get("error"):
+                all_errors.append(
+                    UploadError(
+                        upload_id=upload_id,
+                        error_code=error["code"],
+                        error_params=error["params"],
+                    )
+                )
+
+        update["id"] = upload_id
+        update["order_number"] = merge_result.session_mapping.get(upload_id)
+        all_upload_updates.append(update)
+
+    db_session.bulk_update_mappings(Upload, all_upload_updates)
+    db_session.bulk_save_objects(all_errors)
+    db_session.bulk_save_objects(all_totals)
+    db_session.flush()
+
+
+# TODO(swatinem): we should eventually remove `UploadLevelTotals` completely
+def make_upload_totals(
+    precision: int, rounding: str, upload_id: int, totals: ReportTotals
+) -> UploadLevelTotals:
+    if totals.coverage is not None:
+        coverage = precise_round(Decimal(totals.coverage), precision, rounding)
+    else:
+        coverage = Decimal(0)
+
+    return UploadLevelTotals(
+        upload_id=upload_id,
+        branches=totals.branches,
+        coverage=coverage,
+        hits=totals.hits,
+        lines=totals.lines,
+        methods=totals.methods,
+        misses=totals.misses,
+        partials=totals.partials,
+        files=totals.files,
+    )
+
+
 def change_sessionid(report: EditableReport, old_id: int, new_id: int):
     """
     Modifies the `EditableReport`, changing the session with `old_id` to have
     `new_id` instead.
diff --git a/services/processing/processing.py b/services/processing/processing.py
index 9ae7bd649..1a0a664ae 100644
--- a/services/processing/processing.py
+++ b/services/processing/processing.py
@@ -66,7 +66,7 @@ def process_upload(
             result["successful"] = True
     log.info("Finished processing upload", extra={"result": result})
 
-    report_service.update_upload_with_processing_result(upload, processing_result)
+    # TODO(swatinem): only save the intermediate report on success
     save_intermediate_report(archive_service, commit_sha, upload_id, report)
     state.mark_upload_as_processed(upload_id)
diff --git a/services/processing/types.py b/services/processing/types.py
index 92554c8a8..365f377ba 100644
--- a/services/processing/types.py
+++ b/services/processing/types.py
@@ -3,6 +3,8 @@
 
 from shared.reports.editable import EditableReport
 
+from services.report import ProcessingErrorDict
+
 
 class UploadArguments(TypedDict):
     # TODO(swatinem): migrate this over to `upload_id`
@@ -13,7 +15,7 @@ class ProcessingResult(TypedDict):
     upload_id: int
     arguments: UploadArguments
     successful: bool
-    error: NotRequired[dict]
+    error: NotRequired[ProcessingErrorDict]
 
 
 @dataclass
diff --git a/services/report/__init__.py b/services/report/__init__.py
index 3ebeed5dc..b52f73ebe 100644
--- a/services/report/__init__.py
+++ b/services/report/__init__.py
@@ -5,7 +5,7 @@
 import uuid
 from dataclasses import dataclass
 from time import time
-from typing import Any, Mapping, Sequence
+from typing import Any, Mapping, Sequence, TypedDict
 
 import orjson
 import sentry_sdk
@@ -61,13 +61,18 @@
 from services.yaml.reader import get_paths_from_flags, read_yaml_field
 
 
+class ProcessingErrorDict(TypedDict):
+    code: UploadErrorCode
+    params: dict[str, Any]
+
+
 @dataclass
 class ProcessingError:
     code: UploadErrorCode
     params: 
dict[str, Any] is_retryable: bool = False - def as_dict(self): + def as_dict(self) -> ProcessingErrorDict: return {"code": self.code, "params": self.params} @@ -760,65 +765,6 @@ def build_report_from_raw_content( raw_report_info.error = result.error return result - def update_upload_with_processing_result( - self, upload: Upload, processing_result: ProcessingResult - ): - rounding: str = read_yaml_field( - self.current_yaml, ("coverage", "round"), "nearest" - ) - precision: int = read_yaml_field( - self.current_yaml, ("coverage", "precision"), 2 - ) - db_session = upload.get_db_session() - session = processing_result.session - - if processing_result.error is None: - upload.state_id = UploadState.PROCESSED.db_id - upload.state = "processed" - upload.order_number = session.id - upload_totals = upload.totals - if upload_totals is None: - upload_totals = UploadLevelTotals( - upload_id=upload.id, - branches=0, - coverage=0, - hits=0, - lines=0, - methods=0, - misses=0, - partials=0, - files=0, - ) - db_session.add(upload_totals) - if session.totals is not None: - upload_totals.update_from_totals( - session.totals, precision=precision, rounding=rounding - ) - - # delete all the carryforwarded `Upload` records corresponding to `Session`s - # which have been removed from the report. - # we always have a `session_adjustment` in the non-error case. - assert processing_result.session_adjustment - deleted_sessions = set( - processing_result.session_adjustment.fully_deleted_sessions - ) - if deleted_sessions: - delete_uploads_by_sessionid( - db_session, upload.report_id, deleted_sessions - ) - - else: - error = processing_result.error - upload.state = "error" - upload.state_id = UploadState.ERROR.db_id - error_obj = UploadError( - upload_id=upload.id, - error_code=error.code, - error_params=error.params, - ) - db_session.add(error_obj) - db_session.flush() - @sentry_sdk.trace def save_report(self, commit: Commit, report: Report, report_code=None): if len(report._chunks) > 2 * len(report._files) and len(report._files) > 0: diff --git a/services/tests/test_report.py b/services/tests/test_report.py index b621f1271..b6ef56928 100644 --- a/services/tests/test_report.py +++ b/services/tests/test_report.py @@ -3,22 +3,16 @@ import mock import pytest from celery.exceptions import SoftTimeLimitExceeded -from shared.reports.enums import UploadState from shared.reports.resources import Report, ReportFile, Session, SessionType from shared.reports.types import ReportLine, ReportTotals from shared.torngit.exceptions import TorngitRateLimitError from shared.yaml import UserYaml from database.models import CommitReport, ReportDetails, RepositoryFlag, Upload -from database.tests.factories import CommitFactory, UploadFactory +from database.tests.factories import CommitFactory from helpers.exceptions import RepositoryWithoutValidBotError from services.archive import ArchiveService -from services.report import ( - NotReadyToBuildReportYetError, - ProcessingError, - ProcessingResult, - ReportService, -) +from services.report import NotReadyToBuildReportYetError, ReportService from services.report import log as report_log from services.report.raw_upload_processor import ( SessionAdjustmentResult, @@ -4029,43 +4023,6 @@ def test_create_report_upload(self, dbsession): assert first_flag.flag_name == "unittest" assert first_flag.repository_id == commit.repoid - def test_update_upload_with_processing_result_error(self, mocker, dbsession): - upload_obj = UploadFactory.create(state="started", storage_path="url") - 
dbsession.add(upload_obj)
-        dbsession.flush()
-        assert len(upload_obj.errors) == 0
-        processing_result = ProcessingResult(
-            session=mocker.MagicMock(),
-            error=ProcessingError(code="abclkj", params={"banana": "value"}),
-        )
-        ReportService({}).update_upload_with_processing_result(
-            upload_obj, processing_result
-        )
-        dbsession.refresh(upload_obj)
-        assert upload_obj.state == "error"
-        assert upload_obj.state_id == UploadState.ERROR.db_id
-        assert len(upload_obj.errors) == 1
-        assert upload_obj.errors[0].error_code == "abclkj"
-        assert upload_obj.errors[0].error_params == {"banana": "value"}
-        assert upload_obj.errors[0].report_upload == upload_obj
-
-    def test_update_upload_with_processing_result_success(self, mocker, dbsession):
-        upload_obj = UploadFactory.create(state="started", storage_path="url")
-        dbsession.add(upload_obj)
-        dbsession.flush()
-        assert len(upload_obj.errors) == 0
-        processing_result = ProcessingResult(
-            session=Session(),
-            session_adjustment=SessionAdjustmentResult([], []),
-        )
-        ReportService({}).update_upload_with_processing_result(
-            upload_obj, processing_result
-        )
-        dbsession.refresh(upload_obj)
-        assert upload_obj.state == "processed"
-        assert upload_obj.state_id == UploadState.PROCESSED.db_id
-        assert len(upload_obj.errors) == 0
-
     def test_shift_carryforward_report(
         self, dbsession, sample_report, mocker, mock_repo_provider
     ):
diff --git a/tasks/tests/unit/test_upload_finisher_task.py b/tasks/tests/unit/test_upload_finisher_task.py
index 3c03cf17d..7d5fb02e9 100644
--- a/tasks/tests/unit/test_upload_finisher_task.py
+++ b/tasks/tests/unit/test_upload_finisher_task.py
@@ -8,9 +8,13 @@
 from shared.celery_config import timeseries_save_commit_measurements_task_name
 from shared.yaml import UserYaml
 
+from database.models.reports import CommitReport
 from database.tests.factories import CommitFactory, PullFactory, RepositoryFactory
+from database.tests.factories.core import UploadFactory
 from helpers.checkpoint_logger import CheckpointLogger, _kwargs_key
 from helpers.checkpoint_logger.flows import UploadFlow
+from services.processing.merging import update_uploads
+from services.processing.types import MergeResult, ProcessingResult
 from tasks.upload_finisher import (
     ReportService,
     ShouldCallNotifyResult,
@@ -74,6 +78,47 @@ def test_results_arg_new():
     ]
 
 
+def test_mark_uploads_as_failed(dbsession):
+    commit = CommitFactory.create()
+    dbsession.add(commit)
+    dbsession.flush()
+    report = CommitReport(commit_id=commit.id_)
+    dbsession.add(report)
+    dbsession.flush()
+    upload_1 = UploadFactory.create(report=report, state="started", storage_path="url")
+    upload_2 = UploadFactory.create(report=report, state="started", storage_path="url2")
+    dbsession.add(upload_1)
+    dbsession.add(upload_2)
+    dbsession.flush()
+
+    results: list[ProcessingResult] = [
+        {
+            "upload_id": upload_1.id,
+            "successful": False,
+            "error": {"code": "report_empty", "params": {}},
+        },
+        {
+            "upload_id": upload_2.id,
+            "successful": False,
+            "error": {"code": "report_expired", "params": {}},
+        },
+    ]
+
+    update_uploads(dbsession, UserYaml({}), results, [], MergeResult({}, set()))
+    # the bulk `UPDATE` bypasses in-memory ORM state, so re-load before asserting
+    dbsession.refresh(upload_1)
+    dbsession.refresh(upload_2)
+
+    assert upload_1.state == "error"
+    assert len(upload_1.errors) == 1
+    assert upload_1.errors[0].error_code == "report_empty"
+    assert upload_1.errors[0].error_params == {}
+    assert upload_1.errors[0].report_upload == upload_1
+
+    assert upload_2.state == "error"
+    assert len(upload_2.errors) == 1
+    assert upload_2.errors[0].error_code == "report_expired"
+    assert upload_2.errors[0].error_params == {}
+    assert 
upload_2.errors[0].report_upload == upload_2 + + class TestUploadFinisherTask(object): @pytest.mark.django_db(databases={"default"}) def test_upload_finisher_task_call( diff --git a/tasks/tests/unit/test_upload_processing_task.py b/tasks/tests/unit/test_upload_processing_task.py index b919cfc0f..3cd8f9aa5 100644 --- a/tasks/tests/unit/test_upload_processing_task.py +++ b/tasks/tests/unit/test_upload_processing_task.py @@ -4,14 +4,13 @@ import pytest from celery.exceptions import Retry from shared.config import get_config -from shared.reports.enums import UploadState from shared.reports.resources import Report, ReportFile, ReportLine, ReportTotals from shared.storage.exceptions import FileNotInStorageError from shared.torngit.exceptions import TorngitObjectNotFoundError from shared.upload.constants import UploadErrorCode from shared.yaml import UserYaml -from database.models import CommitReport, ReportDetails, UploadError +from database.models import CommitReport, ReportDetails from database.tests.factories import CommitFactory, UploadFactory from helpers.exceptions import ( ReportEmptyError, @@ -233,7 +232,6 @@ def test_upload_processor_call_with_upload_obj( "successful": True, } ] - assert upload.state == "processed" # storage is overwritten with parsed contents data = mock_storage.read_file("archive", url) @@ -274,8 +272,6 @@ def test_upload_task_call_exception_within_individual_upload( "parse_raw_report_from_storage", return_value="ParsedRawReport()", ) - mocker.patch("tasks.upload_processor.load_commit_diff") - mocker.patch("tasks.upload_processor.save_report_results") mocked_post_process = mocker.patch( "services.processing.processing.rewrite_or_delete_upload" @@ -301,17 +297,6 @@ def test_upload_task_call_exception_within_individual_upload( }, } ] - assert upload.state_id == UploadState.ERROR.db_id - assert upload.state == "error" - - error_obj = ( - dbsession.query(UploadError) - .filter(UploadError.upload_id == upload.id) - .first() - ) - assert error_obj is not None - assert error_obj.error_code == UploadErrorCode.UNKNOWN_PROCESSING - mocked_post_process.assert_called_with( mocker.ANY, mocker.ANY, @@ -468,7 +453,6 @@ def test_upload_task_process_individual_report_with_notfound_report( } ] assert commit.state == "complete" - assert upload.state == "error" def test_upload_task_process_individual_report_with_notfound_report_no_retries_yet( self, dbsession, mocker @@ -591,11 +575,6 @@ def test_upload_task_call_with_empty_report( }, ] assert commit.state == "complete" - assert len(upload_2.errors) == 1 - assert upload_2.errors[0].error_code == "report_empty" - assert upload_2.errors[0].error_params == {} - assert upload_2.errors[0].report_upload == upload_2 - assert len(upload_1.errors) == 0 @pytest.mark.django_db(databases={"default"}) def test_upload_task_call_no_successful_report( @@ -680,14 +659,6 @@ def test_upload_task_call_no_successful_report( "error": {"code": "report_expired", "params": {}}, }, ] - assert len(upload_2.errors) == 1 - assert upload_2.errors[0].error_code == "report_expired" - assert upload_2.errors[0].error_params == {} - assert upload_2.errors[0].report_upload == upload_2 - assert len(upload_1.errors) == 1 - assert upload_1.errors[0].error_code == "report_empty" - assert upload_1.errors[0].error_params == {} - assert upload_1.errors[0].report_upload == upload_1 @pytest.mark.django_db(databases={"default"}) def test_upload_task_call_softtimelimit( diff --git a/tasks/upload_finisher.py b/tasks/upload_finisher.py index b415b18b8..cd9618527 100644 --- 
a/tasks/upload_finisher.py +++ b/tasks/upload_finisher.py @@ -111,6 +111,7 @@ def run_impl( upload_ids = [upload["upload_id"] for upload in processing_results] pr = processing_results[0]["arguments"].get("pr") diff = load_commit_diff(commit, pr, self.name) + commit_yaml = UserYaml(commit_yaml) try: with get_report_lock(repoid, commitid, self.hard_time_limit_task): @@ -121,19 +122,13 @@ def run_impl( archive_service, commit_yaml, commit, - upload_ids, + processing_results, ) log.info( "Saving combined report", - extra=dict( - repoid=repoid, - commit=commitid, - processing_results=processing_results, - parent_task=self.request.parent_id, - ), + extra={"processing_results": processing_results}, ) - save_report_results( report_service, commit, report, diff, pr, report_code ) @@ -444,25 +439,32 @@ def get_report_lock(repoid: int, commitid: str, hard_time_limit: int) -> Lock: def perform_report_merging( report_service: ReportService, archive_service: ArchiveService, - commit_yaml: dict, + commit_yaml: UserYaml, commit: Commit, - upload_ids: list[int], + processing_results: list[ProcessingResult], ) -> Report: master_report = report_service.get_existing_report_for_commit(commit) if master_report is None: master_report = Report() + upload_ids = [ + upload["upload_id"] for upload in processing_results if upload["successful"] + ] intermediate_reports = load_intermediate_reports( archive_service, commit.commitid, upload_ids ) - merge_result = merge_reports( - UserYaml(commit_yaml), master_report, intermediate_reports - ) + merge_result = merge_reports(commit_yaml, master_report, intermediate_reports) # Update the `Upload` in the database with the final session_id # (aka `order_number`) and other statuses - update_uploads(commit.get_db_session(), merge_result) + update_uploads( + commit.get_db_session(), + commit_yaml, + processing_results, + intermediate_reports, + merge_result, + ) return master_report diff --git a/tasks/upload_processor.py b/tasks/upload_processor.py index 3cc2ba034..c6e0a675c 100644 --- a/tasks/upload_processor.py +++ b/tasks/upload_processor.py @@ -103,7 +103,6 @@ def load_commit_diff( commit: Commit, pr: Pull | None, task_name: str | None ) -> dict | None: repository = commit.repository - commitid = commit.commitid try: installation_name_to_use = ( get_installation_name_for_owner_for_task(task_name, repository.owner) @@ -113,7 +112,7 @@ def load_commit_diff( repository_service = get_repo_provider_service( repository, installation_name_to_use=installation_name_to_use ) - return async_to_sync(repository_service.get_commit_diff)(commitid) + return async_to_sync(repository_service.get_commit_diff)(commit.commitid) # TODO: can we maybe get rid of all this logging? except TorngitError: @@ -123,10 +122,6 @@ def load_commit_diff( # alternative of refusing an otherwise "good" report because of the lack of diff log.warning( "Could not apply diff to report because there was an error fetching diff from provider", - extra=dict( - repoid=commit.repoid, - commit=commit.commitid, - ), exc_info=True, ) except RepositoryWithoutValidBotError: @@ -141,10 +136,6 @@ def load_commit_diff( log.warning( "Could not apply diff to report because there is no valid bot found for that repo", - extra=dict( - repoid=commit.repoid, - commit=commit.commitid, - ), exc_info=True, )
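-- 

For context, a minimal sketch of the bulk-update pattern the new
`update_uploads` relies on. This is illustrative only, not part of the patch:
`bulk_mark_processed` is a hypothetical helper, and the `Upload` model and
session wiring are assumed to match the codebase above:

    from sqlalchemy.orm import Session

    from database.models.reports import Upload

    def bulk_mark_processed(
        db_session: Session, session_mapping: dict[int, int]
    ) -> None:
        # one plain dict per row, keyed by primary key; SQLAlchemy turns this
        # into a single executemany() UPDATE instead of one query per `Upload`
        updates = [
            {"id": upload_id, "state": "processed", "order_number": order_number}
            for upload_id, order_number in session_mapping.items()
        ]
        db_session.bulk_update_mappings(Upload, updates)
        db_session.flush()

The tradeoff, as the test above shows, is that bulk operations bypass ORM
change tracking, so already-loaded objects must be refreshed to observe the
new state.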