Skip to content

Commit

Permalink
Merge pull request #201 from alanlujan91/i195
Browse files Browse the repository at this point in the history
[WIP] create explode_agents method
  • Loading branch information
sbenthall authored Mar 21, 2023
2 parents b2bc4d0 + 7c44892 commit c15f1e3
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 31 deletions.
35 changes: 24 additions & 11 deletions sharkfin/population.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pprint import pprint
Expand Down Expand Up @@ -180,7 +181,9 @@ def agent_data(self):
data_calls = {
"aLvl": lambda a: a.state_now["aLvl"],
"mNrm": lambda a: a.state_now["mNrm"],
"cNrm": lambda a: a.controls["cNrm"] if "cNrm" in a.controls else np.full(a.AgentCount, np.nan),
"cNrm": lambda a: a.controls["cNrm"]
if "cNrm" in a.controls
else np.full(a.AgentCount, np.nan),
"mNrm_ratio_StE": lambda a: a.state_now["mNrm"] / a.mNrmStE,
}

Expand All @@ -205,18 +208,13 @@ def class_stats(self, store=False):
if self.ex_ante_hetero_params is None or len(self.ex_ante_hetero_params) == 0:
cs = agent_data.aggregate(["mean", "std"])

mean_data = cs.loc['mean'].to_dict()
std_data = cs.loc['std'].to_dict()
mean_data = cs.loc["mean"].to_dict()
std_data = cs.loc["std"].to_dict()

# this collapse the data into one row with appropriate column names
all_data = {k + '_mean' : [mean_data[k]] for k in mean_data}
all_data.update({
k + '_std' : [std_data[k]]
for k
in std_data
})
all_data['label'] = ["all"]

all_data = {k + "_mean": [mean_data[k]] for k in mean_data}
all_data.update({k + "_std": [std_data[k]] for k in std_data})
all_data["label"] = ["all"]

cs = pd.DataFrame.from_dict(all_data)

Expand Down Expand Up @@ -261,6 +259,21 @@ def solve_distributed_agents(self):
for agent in self.agents:
agent.solve()

def explode_agents(self, num):

exploded_agents = []
exploded_dicts = []

for i, agent in enumerate(self.agents):
for j in range(num):
exploded_agents.append(deepcopy(agent))
exploded_dicts.append(deepcopy(self.agent_dicts[i]))

self.agents = exploded_agents
self.agent_dicts = exploded_dicts

self.create_database()

def unpack_solutions(self):
self.solution = [agent.solution for agent in self.agents]

Expand Down
94 changes: 75 additions & 19 deletions sharkfin/tests/test_simulation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import numpy as np
from HARK.ConsumptionSaving.ConsPortfolioModel import SequentialPortfolioConsumerType
from pytest import approx

from sharkfin.broker import *
from sharkfin.expectations import *
from sharkfin.population import *
from sharkfin.simulation import *
from simulate.parameters import (
WHITESHARK,
build_population
)
import numpy as np

from pytest import approx
from simulate.parameters import LUCAS0, WHITESHARK, build_population

## MARKET SIMULATIONS


def test_market_simulation():
"""
Sets up and runs a MarketSimulation with no population.
Expand Down Expand Up @@ -51,6 +49,7 @@ def test_calibration_simulation():

assert len(data["prices"]) == 2


def test_series_simulation():
"""
Sets up and runs an agent population simulation
Expand All @@ -63,7 +62,19 @@ def test_series_simulation():
market = None

sim = SeriesSimulation(q=q, r=r, market=market)
sim.simulate(burn_in=2, series=[(10000, 0), (10000, 0), (10000, 0), (10000, 0), (0,10000), (0, 10000), (0, 10000), (0, 10000)])
sim.simulate(
burn_in=2,
series=[
(10000, 0),
(10000, 0),
(10000, 0),
(10000, 0),
(0, 10000),
(0, 10000),
(0, 10000),
(0, 10000),
],
)

assert sim.broker.buy_sell_history[2] == (10000, 0)
# assert(len(sim.history['buy_sell']) == 3) # need the padded day
Expand All @@ -74,6 +85,7 @@ def test_series_simulation():

## MACRO SIMULATIONS


def test_macro_simulation():
"""
Sets up and runs an simulation with an agent population.
Expand All @@ -82,12 +94,12 @@ def test_macro_simulation():
pop = build_population(
SequentialPortfolioConsumerType,
WHITESHARK,
rng = np.random.default_rng(1)
)
rng=np.random.default_rng(1),
)

# arguments to attention simulation

#a = 0.2
# a = 0.2
q = 1
r = 30
market = None
Expand All @@ -97,7 +109,7 @@ def test_macro_simulation():
attsim = MacroSimulation(
pop,
FinanceModel,
#a=a,
# a=a,
q=q,
r=r,
market=market,
Expand All @@ -119,6 +131,7 @@ def test_macro_simulation():

assert len(data["prices"]) == 30


def test_attention_simulation():
"""
Sets up and runs an agent population simulation
Expand All @@ -128,9 +141,8 @@ def test_attention_simulation():
pop = build_population(
SequentialPortfolioConsumerType,
WHITESHARK,
rng = np.random.default_rng(1)
)

rng=np.random.default_rng(1),
)

# arguments to attention simulation

Expand All @@ -149,9 +161,7 @@ def test_attention_simulation():
r=r,
market=market,
days_per_quarter=days_per_quarter,
fm_args = {
'p1' : 0.5
}
fm_args={"p1": 0.5},
)
attsim.simulate(burn_in=20)

Expand All @@ -167,7 +177,7 @@ def test_attention_simulation():
assert attsim.days_per_quarter == days_per_quarter
assert attsim.fm.days_per_quarter == days_per_quarter

assert sim_stats['end_day'] == 30
assert sim_stats["end_day"] == 30

data = attsim.daily_data()

Expand All @@ -184,3 +194,49 @@ def test_attention_simulation():
assert ror_mean_1 == approx(ror_mean_2)


def test_lucas0_simulation():
"""
Sets up and runs an simulation with an agent population.
"""
# initialize population
pop = build_population(
SequentialPortfolioConsumerType,
LUCAS0,
rng=np.random.default_rng(1),
)

assert len(pop.agent_database.index) == LUCAS0["num_per_type"]

# arguments to attention simulation

# a = 0.2
q = 1
r = 30
market = None

days_per_quarter = 30

attsim = MacroSimulation(
pop,
FinanceModel,
# a=a,
q=q,
r=r,
market=market,
days_per_quarter=days_per_quarter,
)
attsim.simulate(burn_in=20)

## testing for existence of this class stat
attsim.pop.class_stats()["mNrm_ratio_StE_mean"]

attsim.daily_data()["sell_macro"]

attsim.sim_stats()

assert attsim.days_per_quarter == days_per_quarter
assert attsim.fm.days_per_quarter == days_per_quarter

data = attsim.daily_data()

assert len(data["prices"]) == 30
7 changes: 6 additions & 1 deletion simulate/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@


def build_population(agent_type, parameters, rng=None, dphm=1500):

num_per_type = parameters.get("num_per_type", 1)

pop = AgentPopulation(
agent_type(), parameters, rng=rng, dollars_per_hark_money_unit=dphm
)
Expand All @@ -21,6 +24,8 @@ def build_population(agent_type, parameters, rng=None, dphm=1500):

pop.solve(merge_by=parameters["ex_post"])

pop.explode_agents(num_per_type)

# initialize population model
pop.init_simulation()

Expand Down Expand Up @@ -99,6 +104,6 @@ def build_population(agent_type, parameters, rng=None, dphm=1500):

lucas0_parameter_dict = lucas0_agent_population_params
lucas0_parameter_dict["AgentCount"] = 10 # TODO: What should this be?

lucas0_parameter_dict["num_per_type"] = 10

LUCAS0 = lucas0_parameter_dict

0 comments on commit c15f1e3

Please sign in to comment.