diff --git a/sharkfin/population.py b/sharkfin/population.py index b4f49ac..910e426 100644 --- a/sharkfin/population.py +++ b/sharkfin/population.py @@ -178,16 +178,18 @@ def agent_data(self): agent_data = self.agent_database[self.ex_ante_hetero_params + ["agents"]] data_calls = { - "aLvl": lambda a: a.state_now["aLvl"][0], - "mNrm": lambda a: a.state_now["mNrm"][0], - "cNrm": lambda a: a.controls["cNrm"][0] if "cNrm" in a.controls else None, - "mNrm_ratio_StE": lambda a: a.state_now["mNrm"][0] / a.mNrmStE, + "aLvl": lambda a: a.state_now["aLvl"], + "mNrm": lambda a: a.state_now["mNrm"], + "cNrm": lambda a: a.controls["cNrm"] if "cNrm" in a.controls else np.full(a.AgentCount, np.nan), + "mNrm_ratio_StE": lambda a: a.state_now["mNrm"] / a.mNrmStE, } for dc in data_calls: col = agent_data.loc[:, "agents"].apply(data_calls[dc]) agent_data[dc] = col + agent_data = agent_data.explode(list(data_calls.keys())) + pd.options.mode.chained_assignment = pdomca return agent_data @@ -201,16 +203,23 @@ def class_stats(self, store=False): agent_data = self.agent_data().drop(columns="agents") if self.ex_ante_hetero_params is None or len(self.ex_ante_hetero_params) == 0: - cs = agent_data.copy() - cs["aLvl_mean"] = agent_data["aLvl"] - cs["aLvl_std"] = 0 - cs["mNrm_mean"] = agent_data["mNrm"] - cs["mNrm_std"] = 0 - cs["cNrm_mean"] = agent_data["cNrm"] - cs["cNrm_std"] = 0 - cs["mNrm_ratio_StE_mean"] = agent_data["mNrm_ratio_StE"] - cs["mNrm_ratio_StE_std"] = 0 - cs["label"] = "all" + cs = agent_data.aggregate(["mean", "std"]) + + mean_data = cs.loc['mean'].to_dict() + std_data = cs.loc['std'].to_dict() + + # this collapse the data into one row with appropriate column names + all_data = {k + '_mean' : [mean_data[k]] for k in mean_data} + all_data.update({ + k + '_std' : [std_data[k]] + for k + in std_data + }) + all_data['label'] = ["all"] + + + cs = pd.DataFrame.from_dict(all_data) + else: cs = ( agent_data.groupby(self.ex_ante_hetero_params) @@ -227,10 +236,6 @@ def class_stats(self, store=False): cs.columns = ["_".join(col).strip("_") for col in cs.columns.values] - # print(cs.columns) - - # print(cs.columns) - if store: self.stored_class_stats = cs diff --git a/simulate/parameters.py b/simulate/parameters.py index 2ca3f72..db1419c 100644 --- a/simulate/parameters.py +++ b/simulate/parameters.py @@ -98,7 +98,7 @@ def build_population(agent_type, parameters, rng=None, dphm=1500): } lucas0_parameter_dict = lucas0_agent_population_params -lucas0_parameter_dict["AgentCount"] = 1 # TODO: What should this be? +lucas0_parameter_dict["AgentCount"] = 10 # TODO: What should this be? LUCAS0 = lucas0_parameter_dict