Skip to content

Commit

Permalink
regenerate overall_statistics if outdated (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
lyuyangh authored Nov 8, 2022
1 parent 6454f3f commit 6fe0758
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
10 changes: 9 additions & 1 deletion backend/src/impl/default_controllers_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from explainaboard_web.impl.auth import get_user
from explainaboard_web.impl.benchmark_utils import BenchmarkUtils
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
from explainaboard_web.impl.db_utils.db_utils import DBUtils
from explainaboard_web.impl.db_utils.system_db_utils import SystemDBUtils
from explainaboard_web.impl.language_code import get_language_codes
from explainaboard_web.impl.private_dataset import is_private_dataset
Expand Down Expand Up @@ -54,6 +55,7 @@
from explainaboard_web.models.user import User as modelUser
from flask import current_app
from pymongo import ASCENDING, DESCENDING
from pymongo.client_session import ClientSession


def _is_creator(system: System, user: authUser) -> bool:
Expand Down Expand Up @@ -403,10 +405,16 @@ def systems_analyses_post(body: SystemsAnalysesBody):
if len(systems) == 0:
return SystemAnalysesReturn(system_analyses)

# Callback for DBUtils.execute_transaction: asks every system in the enclosing
# `systems` collection to refresh its overall_statistics, forwarding the
# session it receives to each call.
# NOTE(review): the loop variable `sys` shares its name with the standard
# `sys` module and would shadow it inside this helper if that module is
# imported in this file; `system` would be a clearer name.
def update_overall_statistics(session: ClientSession) -> None:
for sys in systems:
# refresh overall_statistics if it is outdated
sys.update_overall_statistics(session=session)

# The helper is passed uncalled, so execute_transaction is expected to invoke
# it with a session (presumably so that all refreshes commit or roll back
# together -- TODO confirm against DBUtils.execute_transaction).
DBUtils.execute_transaction(update_overall_statistics)

# performance significance test if there are two systems
sig_info = []
if len(systems) == 2:

system1_info: SystemInfo = systems[0].get_system_info()
system1_info_dict = general_to_dict(system1_info)
system1_output_info = SysOutputInfo.from_dict(system1_info_dict)
Expand Down
36 changes: 25 additions & 11 deletions backend/src/impl/internal_models/system_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import json
import re
from datetime import datetime
from typing import Any
from importlib.metadata import version
from typing import Any, Final

from explainaboard import get_processor
from explainaboard.loaders.file_loader import FileLoaderReturn
Expand All @@ -26,7 +27,8 @@

class SystemModel(System):

_SYSTEM_OUTPUT_CONST = "__SYSOUT__"
_SYSTEM_OUTPUT_CONST: Final = "__SYSOUT__"
_CURRENT_SDK_VERSION: Final = version("explainaboard")
"""Same as System but implements several helper functions that retrieves
additional information and persists data to the DB.
"""
Expand Down Expand Up @@ -150,12 +152,17 @@ def save_system_output(
)

def update_overall_statistics(
self, session: ClientSession | None = None, force_update=False
self, session: ClientSession, force_update=False
) -> None:
"""regenerates overall statistics and updates cache"""
"""If the analysis is outdated or if `force_update`, the analysis is
regenerated and the cache is updated."""
properties = self._get_private_properties(session=session)
if not force_update:
if "system_info" in properties and "metric_stats" in properties:
if (
"system_info" in properties
and "metric_stats" in properties
and properties.get("sdk_version_used") == self._CURRENT_SDK_VERSION
):
# cache hit
return

Expand Down Expand Up @@ -217,13 +224,14 @@ def generate_system_update_values():
system_update_values = {
"results": self.results,
# cache
"sdk_version_used": self._CURRENT_SDK_VERSION,
"system_info": sys_info.to_dict(),
"metric_stats": binarized_metric_stats,
"analysis_cases": update_analysis_cases(),
}
return system_update_values

def update_analysis_cases():
def update_analysis_cases() -> dict[str, str]:
"""saves analysis cases to storage and returns an updated analysis_cases
dict for the DB"""
analysis_cases_lookup: dict[str, str] = {} # level: data_path
Expand All @@ -240,16 +248,22 @@ def update_analysis_cases():
analysis_cases_lookup[analysis_level.name] = blob_name
return analysis_cases_lookup

if properties.get("analysis_cases"):
# invalidate cache
get_storage().delete(properties["analysis_cases"].values())

update_values = generate_system_update_values()
DBUtils.update_one_by_id(
DBUtils.DEV_SYSTEM_METADATA,
self.system_id,
generate_system_update_values(),
update_values,
session=session,
)
if properties.get("analysis_cases"):
# remove stale data. This needs to be the last operation so it is
# protected by the transaction.
blobs_to_delete = [
blob
for blob in properties["analysis_cases"].values()
if blob not in update_values["analysis_cases"].values()
]
get_storage().delete(blobs_to_delete)

def get_raw_system_outputs(
self, output_ids: list[int] | None, session: ClientSession | None = None
Expand Down

0 comments on commit 6fe0758

Please sign in to comment.