From de17841d4d1cf5b317107493c6a3a45344d24560 Mon Sep 17 00:00:00 2001 From: Mateo VG Date: Thu, 22 Feb 2024 20:17:49 -0500 Subject: [PATCH] Use Parameters class in solver --- HARK/core.py | 104 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 33 deletions(-) diff --git a/HARK/core.py b/HARK/core.py index 71a79e2a9..ef46ebf84 100644 --- a/HARK/core.py +++ b/HARK/core.py @@ -1218,47 +1218,85 @@ def solve_one_cycle(agent, solution_last): A list of one period solutions for one "cycle" of the AgentType's microeconomic model. """ - # Calculate number of periods per cycle, defaults to 1 if all variables are time invariant - if len(agent.time_vary) > 0: - # name = agent.time_vary[0] - # T = len(eval('agent.' + name)) - T = len(agent.__dict__[agent.time_vary[0]]) - else: - T = 1 - solve_dict = {parameter: agent.__dict__[parameter] for parameter in agent.time_inv} - solve_dict.update({parameter: None for parameter in agent.time_vary}) + # Check if the agent has a 'Parameters' attribute of the 'Parameters' class + # if so, take advantage of it. Else, use the old method + if hasattr(agent, "params") and isinstance(agent.params, Parameters): + T = agent.params._length - # Initialize the solution for this cycle, then iterate on periods - solution_cycle = [] - solution_next = solution_last + # Initialize the solution for this cycle, then iterate on periods + solution_cycle = [] + solution_next = solution_last - cycles_range = [0] + list(range(T - 1, 0, -1)) - for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range: - # Update which single period solver to use (if it depends on time) - if hasattr(agent.solve_one_period, "__getitem__"): - solve_one_period = agent.solve_one_period[k] - else: - solve_one_period = agent.solve_one_period + cycles_range = [0] + list(range(T - 1, 0, -1)) + for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range: + # Update which single period solver to use (if it depends on time) + if hasattr(agent.solve_one_period, "__getitem__"): + solve_one_period = agent.solve_one_period[k] + else: + solve_one_period = agent.solve_one_period + + if hasattr(solve_one_period, "solver_args"): + these_args = solve_one_period.solver_args + else: + these_args = get_arg_names(solve_one_period) + + # Make a temporary dictionary for this period + temp_pars = agent.params[k] + temp_dict = { + name: solution_next if name == "solution_next" else temp_pars[name] + for name in these_args + } - if hasattr(solve_one_period, "solver_args"): - these_args = solve_one_period.solver_args + # Solve one period, add it to the solution, and move to the next period + solution_t = solve_one_period(**temp_dict) + solution_cycle.insert(0, solution_t) + solution_next = solution_t + + else: + # Calculate number of periods per cycle, defaults to 1 if all variables are time invariant + if len(agent.time_vary) > 0: + # name = agent.time_vary[0] + # T = len(eval('agent.' + name)) + T = len(agent.__dict__[agent.time_vary[0]]) else: - these_args = get_arg_names(solve_one_period) + T = 1 + + solve_dict = { + parameter: agent.__dict__[parameter] for parameter in agent.time_inv + } + solve_dict.update({parameter: None for parameter in agent.time_vary}) + + # Initialize the solution for this cycle, then iterate on periods + solution_cycle = [] + solution_next = solution_last + + cycles_range = [0] + list(range(T - 1, 0, -1)) + for k in range(T - 1, -1, -1) if agent.cycles == 1 else cycles_range: + # Update which single period solver to use (if it depends on time) + if hasattr(agent.solve_one_period, "__getitem__"): + solve_one_period = agent.solve_one_period[k] + else: + solve_one_period = agent.solve_one_period + + if hasattr(solve_one_period, "solver_args"): + these_args = solve_one_period.solver_args + else: + these_args = get_arg_names(solve_one_period) - # Update time-varying single period inputs - for name in agent.time_vary: - if name in these_args: - solve_dict[name] = agent.__dict__[name][k] - solve_dict["solution_next"] = solution_next + # Update time-varying single period inputs + for name in agent.time_vary: + if name in these_args: + solve_dict[name] = agent.__dict__[name][k] + solve_dict["solution_next"] = solution_next - # Make a temporary dictionary for this period - temp_dict = {name: solve_dict[name] for name in these_args} + # Make a temporary dictionary for this period + temp_dict = {name: solve_dict[name] for name in these_args} - # Solve one period, add it to the solution, and move to the next period - solution_t = solve_one_period(**temp_dict) - solution_cycle.insert(0, solution_t) - solution_next = solution_t + # Solve one period, add it to the solution, and move to the next period + solution_t = solve_one_period(**temp_dict) + solution_cycle.insert(0, solution_t) + solution_next = solution_t # Return the list of per-period solutions return solution_cycle