Skip to content

Commit

Permalink
Merge pull request #193 from leaf-ai/return_preds
Browse files Browse the repository at this point in the history
#189 Make generate_cases_and_stringency_for_prescriptions return the generated DataFrames
  • Loading branch information
ofrancon authored Jan 23, 2021
2 parents 1fc07b2 + 51d41c9 commit 286640d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions covid_xprize/scoring/prescriptor_scoring.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import pandas as pd

from covid_xprize.standard_predictor.xprize_predictor import XPrizePredictor
Expand All @@ -15,21 +17,22 @@ def weight_prescriptions_by_cost(pres_df, cost_df):


def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file):
start_time = time.time()
# Load the prescriptions, handling Date and regions
pres_df = XPrizePredictor.load_original_data(prescription_file)

# Generate predictions for all prescriptions
predictor = XPrizePredictor()
pred_dfs = []
pred_dfs = {}
for idx in pres_df['PrescriptionIndex'].unique():
idx_df = pres_df[pres_df['PrescriptionIndex'] == idx]
idx_df = idx_df.drop(columns='PrescriptionIndex') # Predictor doesn't need this
# Generate the predictions
pred_df = predictor.predict_from_df(start_date, end_date, idx_df)
print(f"Generated predictions for PrescriptionIndex {idx}")
pred_df['PrescriptionIndex'] = idx
pred_dfs.append(pred_df)
pred_df = pd.concat(pred_dfs)
pred_dfs[idx] = pred_df
pred_df = pd.concat(list(pred_dfs.values()))

# Aggregate cases by prescription index and geo
agg_pred_df = pred_df.groupby(['CountryName',
Expand Down Expand Up @@ -65,8 +68,12 @@ def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescr
'PrescriptionIndex',
'PredictedDailyNewCases',
'Stringency']]
end_time = time.time()
elapsed_time = end_time - start_time
elapsed_time_tring = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
print(f"Evaluated {len(pred_dfs)} PrescriptionIndex in {elapsed_time_tring} seconds")

return df
return df, pred_dfs


# Compute domination relationship for each pair of prescriptors for each geo
Expand Down
2 changes: 1 addition & 1 deletion prescriptor_robojudge.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@
"dfs = []\n",
"for prescriptor_name, prescription_file in sorted(prescription_files.items()):\n",
" print(\"Generating predictions for\", prescriptor_name)\n",
" df = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST)\n",
" df, _ = generate_cases_and_stringency_for_prescriptions(START_DATE, END_DATE, prescription_file, TEST_COST)\n",
" df['PrescriptorName'] = prescriptor_name\n",
" dfs.append(df)\n",
"df = pd.concat(dfs)"
Expand Down

0 comments on commit 286640d

Please sign in to comment.