Skip to content

Commit 6fe0758

Browse files
authored
regenerate overall_statistics if outdated (#494)
1 parent 6454f3f commit 6fe0758

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

backend/src/impl/default_controllers_impl.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from explainaboard_web.impl.auth import get_user
2525
from explainaboard_web.impl.benchmark_utils import BenchmarkUtils
2626
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
27+
from explainaboard_web.impl.db_utils.db_utils import DBUtils
2728
from explainaboard_web.impl.db_utils.system_db_utils import SystemDBUtils
2829
from explainaboard_web.impl.language_code import get_language_codes
2930
from explainaboard_web.impl.private_dataset import is_private_dataset
@@ -54,6 +55,7 @@
5455
from explainaboard_web.models.user import User as modelUser
5556
from flask import current_app
5657
from pymongo import ASCENDING, DESCENDING
58+
from pymongo.client_session import ClientSession
5759

5860

5961
def _is_creator(system: System, user: authUser) -> bool:
@@ -403,10 +405,16 @@ def systems_analyses_post(body: SystemsAnalysesBody):
403405
if len(systems) == 0:
404406
return SystemAnalysesReturn(system_analyses)
405407

408+
def update_overall_statistics(session: ClientSession) -> None:
409+
for sys in systems:
410+
# refresh overall_statistics if it is outdated
411+
sys.update_overall_statistics(session=session)
412+
413+
DBUtils.execute_transaction(update_overall_statistics)
414+
406415
# performance significance test if there are two systems
407416
sig_info = []
408417
if len(systems) == 2:
409-
410418
system1_info: SystemInfo = systems[0].get_system_info()
411419
system1_info_dict = general_to_dict(system1_info)
412420
system1_output_info = SysOutputInfo.from_dict(system1_info_dict)

backend/src/impl/internal_models/system_model.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import json
66
import re
77
from datetime import datetime
8-
from typing import Any
8+
from importlib.metadata import version
9+
from typing import Any, Final
910

1011
from explainaboard import get_processor
1112
from explainaboard.loaders.file_loader import FileLoaderReturn
@@ -26,7 +27,8 @@
2627

2728
class SystemModel(System):
2829

29-
_SYSTEM_OUTPUT_CONST = "__SYSOUT__"
30+
_SYSTEM_OUTPUT_CONST: Final = "__SYSOUT__"
31+
_CURRENT_SDK_VERSION: Final = version("explainaboard")
3032
"""Same as System but implements several helper functions that retrieve
3133
additional information and persists data to the DB.
3234
"""
@@ -150,12 +152,17 @@ def save_system_output(
150152
)
151153

152154
def update_overall_statistics(
153-
self, session: ClientSession | None = None, force_update=False
155+
self, session: ClientSession, force_update=False
154156
) -> None:
155-
"""regenerates overall statistics and updates cache"""
157+
"""If the analysis is outdated or if `force_update`, the analysis is
158+
regenerated and the cache is updated."""
156159
properties = self._get_private_properties(session=session)
157160
if not force_update:
158-
if "system_info" in properties and "metric_stats" in properties:
161+
if (
162+
"system_info" in properties
163+
and "metric_stats" in properties
164+
and properties.get("sdk_version_used") == self._CURRENT_SDK_VERSION
165+
):
159166
# cache hit
160167
return
161168

@@ -217,13 +224,14 @@ def generate_system_update_values():
217224
system_update_values = {
218225
"results": self.results,
219226
# cache
227+
"sdk_version_used": self._CURRENT_SDK_VERSION,
220228
"system_info": sys_info.to_dict(),
221229
"metric_stats": binarized_metric_stats,
222230
"analysis_cases": update_analysis_cases(),
223231
}
224232
return system_update_values
225233

226-
def update_analysis_cases():
234+
def update_analysis_cases() -> dict[str, str]:
227235
"""saves analysis cases to storage and returns an updated analysis_cases
228236
dict for the DB"""
229237
analysis_cases_lookup: dict[str, str] = {} # level: data_path
@@ -240,16 +248,22 @@ def update_analysis_cases():
240248
analysis_cases_lookup[analysis_level.name] = blob_name
241249
return analysis_cases_lookup
242250

243-
if properties.get("analysis_cases"):
244-
# invalidate cache
245-
get_storage().delete(properties["analysis_cases"].values())
246-
251+
update_values = generate_system_update_values()
247252
DBUtils.update_one_by_id(
248253
DBUtils.DEV_SYSTEM_METADATA,
249254
self.system_id,
250-
generate_system_update_values(),
255+
update_values,
251256
session=session,
252257
)
258+
if properties.get("analysis_cases"):
259+
# remove stale data. This needs to be the last operation so it is
260+
# protected by the transaction.
261+
blobs_to_delete = [
262+
blob
263+
for blob in properties["analysis_cases"].values()
264+
if blob not in update_values["analysis_cases"].values()
265+
]
266+
get_storage().delete(blobs_to_delete)
253267

254268
def get_raw_system_outputs(
255269
self, output_ids: list[int] | None, session: ClientSession | None = None

0 commit comments

Comments
 (0)