From 742a77d13716409c335c9b5985f785a24ae52ee6 Mon Sep 17 00:00:00 2001 From: pjwozny Date: Fri, 16 Aug 2024 14:01:58 +0200 Subject: [PATCH] made clubs part of config --- rice.py | 7 +++++++ scripts/create_submission_zip.py | 7 +++++-- scripts/rice_rllib_discrete.yaml | 3 +++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/rice.py b/rice.py index ff85355..577d9a6 100644 --- a/rice.py +++ b/rice.py @@ -46,6 +46,8 @@ def __init__( temperature_calibration="base", prescribed_emissions=None, pct_reward=False, + clubs_enabled = False, + club_members = [] ): self.action_space_type = action_space_type self.num_discrete_action_levels = num_discrete_action_levels @@ -67,6 +69,11 @@ def __init__( self.pct_reward = pct_reward self.global_state = {} + #clubs + self.clubs_enabled = clubs_enabled + if self.clubs_enabled: + self.club_members = club_members + self.set_dtypes() self.set_all_region_params() diff --git a/scripts/create_submission_zip.py b/scripts/create_submission_zip.py index 2ca05e1..4143c56 100644 --- a/scripts/create_submission_zip.py +++ b/scripts/create_submission_zip.py @@ -79,8 +79,11 @@ def prepare_submission(results_dir=None): if file.endswith(".state_dict") ] sorted_policy_models = sorted(policy_models, key=os.path.getmtime) - # Delete all but the last policy model file - for policy_model in sorted_policy_models[:-1]: + + #in the case of multi-model, there will be multiple state dictionaries per model. + policy_prefixes = set([model_name.split("/")[-1].split("_")[0]for model_name in sorted_policy_models]) + # Delete all but the last policy model file of each unique prefix + for policy_model in sorted_policy_models[:-len(policy_prefixes)]: os.remove(os.path.join(results_dir_copy, policy_model.split("/")[-1])) shutil.make_archive(submission_file, "zip", results_dir_copy) diff --git a/scripts/rice_rllib_discrete.yaml b/scripts/rice_rllib_discrete.yaml index ca5fb85..46a61ae 100644 --- a/scripts/rice_rllib_discrete.yaml +++ b/scripts/rice_rllib_discrete.yaml @@ -37,6 +37,8 @@ env: carbon_model: "base" temperature_calibration: "base" pct_reward: False + clubs_enabled: True + club_members: [1] regions: num_agents: 3 #can be either {3,7,20,27} @@ -52,6 +54,7 @@ logging: # Policy network settings policy: + multi_model: True #only active if club_enabled also set to True regions: vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss entropy_coeff_schedule: # loss coefficient schedule for the entropy loss