diff --git a/services/report/__init__.py b/services/report/__init__.py
index 1cc6a4629..03d7213b1 100644
--- a/services/report/__init__.py
+++ b/services/report/__init__.py
@@ -12,7 +12,6 @@
 from asgiref.sync import async_to_sync
 from celery.exceptions import SoftTimeLimitExceeded
 from shared.django_apps.reports.models import ReportType
-from shared.metrics import metrics
 from shared.reports.carryforward import generate_carryforward_report
 from shared.reports.editable import EditableReport
 from shared.reports.enums import UploadState, UploadType
@@ -32,6 +31,7 @@
     ReportLevelTotals,
     RepositoryFlag,
     UploadLevelTotals,
+    uploadflagmembership,
 )
 from helpers.exceptions import (
     OwnerWithoutValidBotError,
@@ -52,7 +52,10 @@
     RAW_UPLOAD_RAW_REPORT_COUNT,
     RAW_UPLOAD_SIZE,
 )
-from services.report.raw_upload_processor import process_raw_upload
+from services.report.raw_upload_processor import (
+    SessionAdjustmentResult,
+    process_raw_upload,
+)
 from services.repository import get_repo_provider_service
 from services.yaml.reader import get_paths_from_flags, read_yaml_field
 
@@ -72,6 +75,7 @@ class ProcessingResult:
     session: Session
     report: Report | None = None
     error: ProcessingError | None = None
+    session_adjustment: SessionAdjustmentResult | None = None
 
 
 @dataclass
@@ -347,17 +351,18 @@ def fetch_repo_flags(self, db_session, repoid: int) -> dict[str, RepositoryFlag]
     @sentry_sdk.trace
     def build_report(
-        self, chunks, files, sessions, totals, report_class=None
+        self, chunks, files, sessions: dict, totals, report_class=None
     ) -> Report:
         if report_class is None:
             report_class = Report
 
-        for sess in sessions.values():
-            if isinstance(sess, Session):
-                if sess.session_type == SessionType.carriedforward:
+        for session_id, session in sessions.items():
+            if isinstance(session, Session):
+                if session.session_type == SessionType.carriedforward:
                     report_class = EditableReport
             else:
-                # sess is an encoded dict
-                if sess.get("st") == "carriedforward":
+                # make sure the `Session` objects get an `id` when decoded:
+                session["id"] = int(session_id)
+                if session.get("st") == "carriedforward":
                     report_class = EditableReport
 
         return report_class.from_chunks(
@@ -726,16 +731,17 @@ def build_report_from_raw_content(
         log.debug("Retrieved report for processing from url %s", archive_url)
 
         try:
-            with metrics.timer(f"{self.metrics_prefix}.process_report") as t:
-                process_result = process_raw_upload(
-                    self.current_yaml,
-                    report,
-                    raw_report,
-                    flags,
-                    session,
-                    upload,
-                )
-            result.report = process_result.report
+            process_result = process_raw_upload(
+                self.current_yaml,
+                report,
+                raw_report,
+                flags,
+                session,
+                upload=upload,
+            )
+            result.report = process_result.report
+            result.session_adjustment = process_result.session_adjustment
+
             log.info(
                 "Successfully processed report",
                 extra=dict(
@@ -745,7 +751,6 @@
                     commit=commit.commitid,
                     reportid=reportid,
                     commit_yaml=self.current_yaml.to_dict(),
-                    timing_ms=t.ms,
                     content_len=raw_report.size,
                 ),
             )
@@ -793,7 +798,7 @@
         return result
 
     def update_upload_with_processing_result(
-        self, upload_obj: Upload, processing_result: ProcessingResult
+        self, upload: Upload, processing_result: ProcessingResult
    ):
         rounding: str = read_yaml_field(
             self.current_yaml, ("coverage", "round"), "nearest"
@@ -801,8 +806,9 @@
         precision: int = read_yaml_field(
             self.current_yaml, ("coverage", "precision"), 2
         )
-        db_session = upload_obj.get_db_session()
+        db_session = upload.get_db_session()
         session = processing_result.session
+
         if processing_result.error is None:
             # this should be enabled for the actual rollout of parallel upload processing.
             # if PARALLEL_UPLOAD_PROCESSING_BY_REPO.check_value(
@@ -811,13 +817,13 @@
             #     upload_obj.state_id = UploadState.PARALLEL_PROCESSED.db_id
             #     upload_obj.state = "parallel_processed"
             # else:
-            upload_obj.state_id = UploadState.PROCESSED.db_id
-            upload_obj.state = "processed"
-            upload_obj.order_number = session.id
-            upload_totals = upload_obj.totals
+            upload.state_id = UploadState.PROCESSED.db_id
+            upload.state = "processed"
+            upload.order_number = session.id
+            upload_totals = upload.totals
             if upload_totals is None:
                 upload_totals = UploadLevelTotals(
-                    upload_id=upload_obj.id,
+                    upload_id=upload.id,
                     branches=0,
                     coverage=0,
                     hits=0,
@@ -832,12 +838,22 @@
             upload_totals.update_from_totals(
                 session.totals, precision=precision, rounding=rounding
             )
+
+            # delete all the carried-forward `Upload` records corresponding to `Session`s
+            # which have been removed from the report.
+            # we always have a `session_adjustment` in the non-error case.
+            assert processing_result.session_adjustment
+            deleted_sessions = (
+                processing_result.session_adjustment.fully_deleted_sessions
+            )
+            delete_uploads_by_sessionid(upload, deleted_sessions)
+
         else:
             error = processing_result.error
-            upload_obj.state = "error"
-            upload_obj.state_id = UploadState.ERROR.db_id
+            upload.state = "error"
+            upload.state_id = UploadState.ERROR.db_id
             error_obj = UploadError(
-                upload_id=upload_obj.id,
+                upload_id=upload.id,
                 error_code=error.code,
                 error_params=error.params,
             )
@@ -1029,3 +1045,31 @@ def save_parallel_report_to_archive(
         "chunks_path": chunks_url,
         "files_and_sessions_path": files_and_sessions_url,
     }
+
+
+@sentry_sdk.trace
+def delete_uploads_by_sessionid(upload: Upload, session_ids: list[int]):
+    """
+    Deletes all the `Upload` records of this report that match the given `session_ids`.
+    """
+    db_session = upload.get_db_session()
+    uploads = (
+        db_session.query(Upload.id_)
+        .filter(Upload.report == upload.report, Upload.order_number.in_(session_ids))
+        .all()
+    )
+    upload_ids = [upload[0] for upload in uploads]
+
+    db_session.query(UploadError).filter(UploadError.upload_id.in_(upload_ids)).delete(
+        synchronize_session=False
+    )
+    db_session.query(UploadLevelTotals).filter(
+        UploadLevelTotals.upload_id.in_(upload_ids)
+    ).delete(synchronize_session=False)
+    db_session.query(uploadflagmembership).filter(
+        uploadflagmembership.c.upload_id.in_(upload_ids)
+    ).delete(synchronize_session=False)
+    db_session.query(Upload).filter(Upload.id_.in_(upload_ids)).delete(
+        synchronize_session=False
+    )
+    db_session.flush()
+ """ + db_session = upload.get_db_session() + uploads = ( + db_session.query(Upload.id_) + .filter(Upload.report == upload.report, Upload.order_number.in_(session_ids)) + .all() + ) + upload_ids = [upload[0] for upload in uploads] + + db_session.query(UploadError).filter(UploadError.upload_id.in_(upload_ids)).delete( + synchronize_session=False + ) + db_session.query(UploadLevelTotals).filter( + UploadLevelTotals.upload_id.in_(upload_ids) + ).delete(synchronize_session=False) + db_session.query(uploadflagmembership).filter( + uploadflagmembership.c.upload_id.in_(upload_ids) + ).delete(synchronize_session=False) + db_session.query(Upload).filter(Upload.id_.in_(upload_ids)).delete( + synchronize_session=False + ) + db_session.flush() diff --git a/services/report/raw_upload_processor.py b/services/report/raw_upload_processor.py index 0185fe810..cc75f7026 100644 --- a/services/report/raw_upload_processor.py +++ b/services/report/raw_upload_processor.py @@ -33,7 +33,7 @@ class SessionAdjustmentResult: @dataclass class UploadProcessingResult: report: Report # NOTE: this is just returning the input argument, and primarily used in tests - session_adjustment: SessionAdjustmentResult # NOTE: this is only ever used in tests + session_adjustment: SessionAdjustmentResult @sentry_sdk.trace diff --git a/services/tests/test_report.py b/services/tests/test_report.py index 4925a15da..79e4ef607 100644 --- a/services/tests/test_report.py +++ b/services/tests/test_report.py @@ -4059,7 +4059,7 @@ def test_update_upload_with_processing_result_success(self, mocker, dbsession): assert len(upload_obj.errors) == 0 processing_result = ProcessingResult( session=Session(), - error=None, + session_adjustment=SessionAdjustmentResult([], []), ) assert ( ReportService({}).update_upload_with_processing_result( diff --git a/tasks/tests/integration/test_upload_e2e.py b/tasks/tests/integration/test_upload_e2e.py index f7121a0d7..386e43483 100644 --- a/tasks/tests/integration/test_upload_e2e.py +++ b/tasks/tests/integration/test_upload_e2e.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session as DbSession from database.models.core import Commit, CompareCommit, Repository +from database.models.reports import Upload from database.tests.factories import CommitFactory, RepositoryFactory from database.tests.factories.core import PullFactory from rollouts import PARALLEL_UPLOAD_PROCESSING_BY_REPO @@ -457,8 +458,10 @@ def test_full_carryforward( report = report_service.get_existing_report_for_commit( base_commit, report_code=None ) - assert report + + base_sessions = report.sessions + assert set(report.files) == {"a.rs", "b.rs"} a = report.get("a.rs") @@ -553,6 +556,14 @@ def test_full_carryforward( ) assert carriedforward_sessions == 2 + # the `Upload`s in the database should match the `sessions` in the report: + uploads = ( + dbsession.query(Upload).filter(Upload.report_id == commit.report.id_).all() + ) + assert {upload.order_number for upload in uploads} == { + session.id for session in sessions.values() + } + # and then overwrite data related to "b" as well do_upload( b""" @@ -591,3 +602,18 @@ def test_full_carryforward( ] assert len(report.sessions) == 2 + uploads = ( + dbsession.query(Upload).filter(Upload.report_id == commit.report.id_).all() + ) + assert {upload.order_number for upload in uploads} == { + session.id for session in report.sessions.values() + } + + # just as a sanity check: any cleanup for the followup commit did not touch + # data of the base commit: + uploads = ( + 
dbsession.query(Upload).filter(Upload.report_id == base_commit.report.id_).all() + ) + assert {upload.order_number for upload in uploads} == { + session.id for session in base_sessions.values() + }
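Note on the `build_report` change above (illustration only, not part of the patch):
a JSON-decoded `sessions` mapping is keyed by the stringified session id, and the
per-session dicts carry no id field of their own, so `Session` objects rebuilt from
them would otherwise not know their own id. The new loop backfills it before
`report_class.from_chunks(...)` runs. A minimal sketch -- the `"st"` key is taken
from the diff, everything else about the decoded blob layout is assumed:

    # hypothetical decoded `files_and_sessions` blob; JSON object keys are strings
    sessions = {
        "0": {"st": "uploaded"},
        "1": {"st": "carriedforward"},  # no "id" key of its own
    }
    # what `build_report` now does while scanning the sessions:
    for session_id, session in sessions.items():
        session["id"] = int(session_id)
    assert sessions["1"]["id"] == 1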
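End-to-end, the new plumbing works like this: `process_raw_upload` now always
returns a `session_adjustment`, `build_report_from_raw_content` copies it onto the
`ProcessingResult`, and `update_upload_with_processing_result` uses its
`fully_deleted_sessions` to prune `Upload` rows for carried-forward sessions that
this upload fully replaced. A hedged sketch of the success path -- it assumes the
first positional field of `SessionAdjustmentResult` is `fully_deleted_sessions`
(only that field is read by this patch; the two-argument constructor matches the
updated test):

    # illustrative sketch only -- not part of the diff
    from services.report import ProcessingResult, ReportService
    from services.report.raw_upload_processor import SessionAdjustmentResult

    def finalize_upload(report_service: ReportService, upload, session):
        # suppose this upload fully replaced carried-forward sessions 3 and 5:
        result = ProcessingResult(
            session=session,
            session_adjustment=SessionAdjustmentResult([3, 5], []),
        )
        # the non-error branch marks `upload` as processed, persists its totals,
        # and then calls `delete_uploads_by_sessionid(upload, [3, 5])`, which
        # bulk-deletes those `Upload` rows together with their `UploadError`,
        # `UploadLevelTotals`, and `uploadflagmembership` rows -- children first,
        # because `Query.delete()` bypasses ORM-level cascades.
        report_service.update_upload_with_processing_result(upload, result)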