From fd33fa953f56af1d38d670d80845dd399352859c Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Sun, 8 Sep 2024 10:26:23 -0400 Subject: [PATCH] update Parameters --- HARK/core.py | 806 +++++++++++++++++++++++++++++++--------- HARK/tests/test_core.py | 160 ++++---- 2 files changed, 687 insertions(+), 279 deletions(-) diff --git a/HARK/core.py b/HARK/core.py index adf141556..cda1dbb65 100644 --- a/HARK/core.py +++ b/HARK/core.py @@ -15,7 +15,7 @@ from copy import copy, deepcopy from dataclasses import dataclass, field from time import time -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union from warnings import warn import numpy as np @@ -61,205 +61,318 @@ def set_verbosity_level(level): class Parameters: """ - This class defines an object that stores all of the parameters for a model - as an internal dictionary. It is designed to also handle the age-varying - dynamics of parameters. + A smart container for model parameters that handles age-varying dynamics. - Attributes - ---------- + This class stores parameters as an internal dictionary and manages their + age-varying properties. It provides both attribute-style and dictionary-style + access to parameters. - _length : int - The terminal age of the agents in the model. - _invariant_params : list - A list of the names of the parameters that are invariant over time. - _varying_params : list - A list of the names of the parameters that vary over time. + Attributes: + _length (int): The terminal age of the agents in the model. + _invariant_params (Set[str]): A set of parameter names that are invariant over time. + _varying_params (Set[str]): A set of parameter names that vary over time. + _parameters (Dict[str, Any]): The internal dictionary storing all parameters. """ - def __init__(self, **parameters: Any): - """ - Initializes a Parameters object and parses the age-varying - dynamics of the parameters. + __slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters") - Parameters - ---------- - - parameters : keyword arguments - Any number of keyword arguments of the form key=value. - To parse a dictionary of parameters, use the ** operator. + def __init__(self, **parameters: Any) -> None: """ - params = parameters.copy() - self._length = params.pop("T_cycle", None) - self._invariant_params = set() - self._varying_params = set() - self._parameters: Dict[str, Union[int, float, np.ndarray, list, tuple]] = {} - - for key, value in params.items(): - self._parameters[key] = self.__infer_dims__(key, value) + Initialize a Parameters object and parse the age-varying dynamics of parameters. - def __infer_dims__( - self, key: str, value: Union[int, float, np.ndarray, list, tuple, None] - ) -> Union[int, float, np.ndarray, list, tuple]: + Args: + **parameters (Any): Any number of parameters in the form key=value. """ - Infers the age-varying dimensions of a parameter. + self._length: int = parameters.pop("T_cycle", 1) + self._invariant_params: Set[str] = set() + self._varying_params: Set[str] = set() + self._parameters: Dict[str, Any] = {"T_cycle": self._length} - If the parameter is a scalar, numpy array, boolean, distribution, callable or None, - it is assumed to be invariant over time. If the parameter is a list or - tuple, it is assumed to be varying over time. If the parameter is a list - or tuple of length greater than 1, the length of the list or tuple must match - the `_term_age` attribute of the Parameters object. - - Parameters - ---------- - key : str - name of parameter - value : Any - value of parameter + for key, value in parameters.items(): + self[key] = value + def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]: """ - if isinstance( - value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) - ): - self.__add_to_invariant__(key) - return value - if isinstance(value, (list, tuple)): - if len(value) == 1: - self.__add_to_invariant__(key) - return value[0] - if self._length is None or self._length == 1: - self._length = len(value) - if len(value) == self._length: - self.__add_to_varying__(key) - return value - raise ValueError( - f"Parameter {key} must be of length 1 or {self._length}, not {len(value)}" - ) - raise ValueError(f"Parameter {key} has unsupported type {type(value)}") + Access parameters by age index or parameter name. - def __add_to_invariant__(self, key: str): - """ - Adds parameter name to invariant set and removes from varying set. - """ - self._varying_params.discard(key) - self._invariant_params.add(key) + Args: + item_or_key (Union[int, str]): Age index or parameter name. - def __add_to_varying__(self, key: str): - """ - Adds parameter name to varying set and removes from invariant set. - """ - self._invariant_params.discard(key) - self._varying_params.add(key) + Returns: + Union[Parameters, Any]: A new Parameters object for the specified age, + or the value of the specified parameter. - def __getitem__(self, item_or_key: Union[int, str]): - """ - If item_or_key is an integer, returns a Parameters object with the parameters - that apply to that age. This includes all invariant parameters and the - `item_or_key`th element of all age-varying parameters. If item_or_key is a string, - it returns the value of the parameter with that name. + Raises: + ValueError: If the age index is out of bounds. + KeyError: If the parameter name is not found. + TypeError: If the key is neither an integer nor a string. """ if isinstance(item_or_key, int): if item_or_key >= self._length: raise ValueError( - f"Age {item_or_key} is greater than or equal to terminal age {self._length}." + f"Age {item_or_key} is out of bounds (max: {self._length - 1})." ) params = {key: self._parameters[key] for key in self._invariant_params} params.update( { key: self._parameters[key][item_or_key] + if isinstance(self._parameters[key], (list, tuple, np.ndarray)) + else self._parameters[key] for key in self._varying_params } ) return Parameters(**params) elif isinstance(item_or_key, str): return self._parameters[item_or_key] + else: + raise TypeError("Key must be an integer (age) or string (parameter name).") - def __setitem__(self, key: str, value: Any): + def __setitem__(self, key: str, value: Any) -> None: """ - Sets the value of a parameter. + Set parameter values, automatically inferring time variance. - Parameters - ---------- - key : str - name of parameter - value : Any - value of parameter + Args: + key (str): Name of the parameter. + value (Any): Value of the parameter. + Raises: + ValueError: If the parameter name is not a string or if the value type is unsupported. + ValueError: If the parameter value is inconsistent with the current model length. """ if not isinstance(key, str): - raise ValueError("Parameters must be set with a string key") - self._parameters[key] = self.__infer_dims__(key, value) + raise ValueError(f"Parameter name must be a string, got {type(key)}") + + if isinstance( + value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) + ): + self._invariant_params.add(key) + self._varying_params.discard(key) + elif isinstance(value, (list, tuple)): + if len(value) == 1: + value = value[0] + self._invariant_params.add(key) + self._varying_params.discard(key) + elif self._length is None or self._length == 1: + self._length = len(value) + self._varying_params.add(key) + self._invariant_params.discard(key) + elif len(value) == self._length: + self._varying_params.add(key) + self._invariant_params.discard(key) + else: + raise ValueError( + f"Parameter {key} must have length 1 or {self._length}, not {len(value)}" + ) + else: + raise ValueError(f"Unsupported type for parameter {key}: {type(value)}") - def keys(self): + self._parameters[key] = value + + def __iter__(self) -> Iterator[str]: + """Allow iteration over parameter names.""" + return iter(self._parameters) + + def __len__(self) -> int: + """Return the number of parameters.""" + return len(self._parameters) + + def keys(self) -> Iterator[str]: + """Return a view of parameter names.""" + return self._parameters.keys() + + def values(self) -> Iterator[Any]: + """Return a view of parameter values.""" + return self._parameters.values() + + def items(self) -> Iterator[Tuple[str, Any]]: + """Return a view of parameter (name, value) pairs.""" + return self._parameters.items() + + def to_dict(self) -> Dict[str, Any]: """ - Returns a list of the names of the parameters. + Convert parameters to a plain dictionary. + + Returns: + Dict[str, Any]: A dictionary containing all parameters. """ - return self._invariant_params | self._varying_params + return dict(self._parameters) - def values(self): + def to_namedtuple(self) -> namedtuple: """ - Returns a list of the values of the parameters. + Convert parameters to a namedtuple. + + Returns: + namedtuple: A namedtuple containing all parameters. """ - return list(self._parameters.values()) + return namedtuple("Parameters", self.keys())(**self.to_dict()) - def items(self): + def update(self, other: Union["Parameters", Dict[str, Any]]) -> None: """ - Returns a list of tuples of the form (name, value) for each parameter. + Update parameters from another Parameters object or dictionary. + + Args: + other (Union[Parameters, Dict[str, Any]]): The source of parameters to update from. + + Raises: + TypeError: If the input is neither a Parameters object nor a dictionary. """ - return list(self._parameters.items()) + if isinstance(other, Parameters): + for key, value in other._parameters.items(): + self[key] = value + elif isinstance(other, dict): + for key, value in other.items(): + self[key] = value + else: + raise TypeError( + "Update source must be a Parameters object or a dictionary." + ) - def __iter__(self): + def __repr__(self) -> str: + """Return a detailed string representation of the Parameters object.""" + return ( + f"Parameters(_length={self._length}, " + f"_invariant_params={self._invariant_params}, " + f"_varying_params={self._varying_params}, " + f"_parameters={self._parameters})" + ) + + def __str__(self) -> str: + """Return a simple string representation of the Parameters object.""" + return f"Parameters({str(self._parameters)})" + + def __getattr__(self, name: str) -> Any: """ - Allows for iterating over the parameter names. + Allow attribute-style access to parameters. + + Args: + name (str): Name of the parameter to access. + + Returns: + Any: The value of the specified parameter. + + Raises: + AttributeError: If the parameter name is not found. """ - return iter(self.keys()) + if name.startswith("_"): + return super().__getattribute__(name) + try: + return self._parameters[name] + except KeyError: + raise AttributeError(f"'Parameters' object has no attribute '{name}'") - def __deepcopy__(self, memo): + def __setattr__(self, name: str, value: Any) -> None: """ - Returns a deep copy of the Parameters object. + Allow attribute-style setting of parameters. + + Args: + name (str): Name of the parameter to set. + value (Any): Value to set for the parameter. """ - return Parameters(**deepcopy(self.to_dict(), memo)) + if name.startswith("_"): + super().__setattr__(name, value) + else: + self[name] = value + + def __contains__(self, item: str) -> bool: + """Check if a parameter exists in the Parameters object.""" + return item in self._parameters - def to_dict(self): + def copy(self) -> "Parameters": """ - Returns a dictionary of the parameters. + Create a deep copy of the Parameters object. + + Returns: + Parameters: A new Parameters object with the same contents. """ - return {key: self._parameters[key] for key in self.keys()} + return deepcopy(self) - def to_namedtuple(self): + def add_to_time_vary(self, *params: str) -> None: """ - Returns a namedtuple of the parameters. + Adds any number of parameters to the time-varying set. + + Args: + *params (str): Any number of strings naming parameters to be added to time_vary. """ - return namedtuple("Parameters", self.keys())(**self.to_dict()) + for param in params: + if param in self._parameters: + self._varying_params.add(param) + self._invariant_params.discard(param) + else: + warn( + f"Parameter '{param}' does not exist and cannot be added to time_vary." + ) - def update(self, other_params): + def add_to_time_inv(self, *params: str) -> None: """ - Updates the parameters with the values from another - Parameters object or a dictionary. + Adds any number of parameters to the time-invariant set. - Parameters - ---------- - other_params : Parameters or dict - Parameters object or dictionary of parameters to update with. + Args: + *params (str): Any number of strings naming parameters to be added to time_inv. """ - if isinstance(other_params, Parameters): - self._parameters.update(other_params.to_dict()) - elif isinstance(other_params, dict): - self._parameters.update(other_params) - else: - raise ValueError("Parameters must be a dict or a Parameters object") + for param in params: + if param in self._parameters: + self._invariant_params.add(param) + self._varying_params.discard(param) + else: + warn( + f"Parameter '{param}' does not exist and cannot be added to time_inv." + ) - def __str__(self): + def del_from_time_vary(self, *params: str) -> None: + """ + Removes any number of parameters from the time-varying set. + + Args: + *params (str): Any number of strings naming parameters to be removed from time_vary. + """ + for param in params: + self._varying_params.discard(param) + + def del_from_time_inv(self, *params: str) -> None: + """ + Removes any number of parameters from the time-invariant set. + + Args: + *params (str): Any number of strings naming parameters to be removed from time_inv. + """ + for param in params: + self._invariant_params.discard(param) + + def get(self, key: str, default: Any = None) -> Any: """ - Returns a simple string representation of the Parameters object. + Get a parameter value, returning a default if not found. + + Args: + key (str): The parameter name. + default (Any, optional): The default value to return if the key is not found. + + Returns: + Any: The parameter value or the default. """ - return f"Parameters({str(self.to_dict())})" + return self._parameters.get(key, default) - def __repr__(self): + def set_many(self, **kwargs: Any) -> None: """ - Returns a detailed string representation of the Parameters object. + Set multiple parameters at once. + + Args: + **kwargs: Keyword arguments representing parameter names and values. + """ + for key, value in kwargs.items(): + self[key] = value + + def is_time_varying(self, key: str) -> bool: """ - return f"Parameters( _age_inv = {self._invariant_params}, _age_var = {self._varying_params}, | {self.to_dict()})" + Check if a parameter is time-varying. + + Args: + key (str): The parameter name. + + Returns: + bool: True if the parameter is time-varying, False otherwise. + """ + return key in self._varying_params class Model: @@ -277,15 +390,12 @@ def assign_parameters(self, **kwds): """ Assign an arbitrary number of attributes to this agent. - Parameters - ---------- - **kwds : keyword arguments - Any number of keyword arguments of the form key=value. Each value - will be assigned to the attribute named in self. + Args: + **kwds (keyword arguments): Any number of keyword arguments of the form key=value. + Each value will be assigned to the attribute named in self. - Returns - ------- - none + Returns: + None """ self.parameters.update(kwds) for key in kwds: @@ -295,15 +405,11 @@ def get_parameter(self, name): """ Returns a parameter of this model - Parameters - ---------- - name : string - The name of the parameter to get + Args: + name (str): The name of the parameter to get - Returns - ------- - value : - The value of the parameter + Returns: + value: The value of the parameter """ return self.parameters[name] @@ -335,15 +441,12 @@ def del_param(self, param_name): Deletes a parameter from this instance, removing it both from the object's namespace (if it's there) and the parameters dictionary (likewise). - Parameters - ---------- - param_name : str - A string naming a parameter or data to be deleted from this instance. - Removes information from self.parameters dictionary and own namespace. + Args: + param_name (str): A string naming a parameter or data to be deleted from this instance. + Removes information from self.parameters dictionary and own namespace. - Returns - ------- - None. + Returns: + None """ if param_name in self.parameters: del self.parameters[param_name] @@ -363,21 +466,17 @@ def construct(self, *args, force=False): missing data) will be named in self._missing_key_data. Other errors are recorded in the dictionary attribute _constructor_errors. - Parameters - ---------- - *args : str, optional - Keys of self.constructors that are requested to be constructed. If - no arguments are passed, *all* elements of the dictionary are implied. - force : bool, optional - When True, the method will force its way past any errors, including - missing constructors, missing arguments for constructors, and errors - raised during execution of constructors. Information about all such - errors is stored in the dictionary attributes described above. When - False (default), any errors or exception will be raised. + Args: + *args (str, optional): Keys of self.constructors that are requested to be constructed. + If no arguments are passed, *all* elements of the dictionary are implied. + force (bool, optional): When True, the method will force its way past any errors, including + missing constructors, missing arguments for constructors, and errors + raised during execution of constructors. Information about all such + errors is stored in the dictionary attributes described above. When + False (default), any errors or exception will be raised. - Returns - ------- - None + Returns: + None """ # Set up the requested work if len(args) > 0: @@ -490,15 +589,12 @@ def describe_constructors(self, *args): including their names, the function that constructs them, the names of those functions inputs, and whether those inputs are present. - Parameters - ---------- - *args : str - Optional list of strings naming constructed inputs to be described. - If none are passed, all constructors are described. + Args: + *args (str): Optional list of strings naming constructed inputs to be described. + If none are passed, all constructors are described. - Returns - ------- - None. + Returns: + None """ if len(args) > 0: keys = args @@ -551,6 +647,353 @@ def describe_constructors(self, *args): return +from typing import Any, Dict, Iterator, List, Set, Tuple, Union + + +class Parameters: + """ + A smart container for model parameters that handles age-varying dynamics. + + This class stores parameters as an internal dictionary and manages their + age-varying properties. It provides both attribute-style and dictionary-style + access to parameters. + + Attributes: + _length (int): The terminal age of the agents in the model. + _invariant_params (Set[str]): A set of parameter names that are invariant over time. + _varying_params (Set[str]): A set of parameter names that vary over time. + _parameters (Dict[str, Any]): The internal dictionary storing all parameters. + """ + + __slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters") + + def __init__(self, **parameters: Any) -> None: + """ + Initialize a Parameters object and parse the age-varying dynamics of parameters. + + Args: + **parameters (Any): Keyword arguments representing parameter names and values. + """ + self._length: int = parameters.pop("T_cycle", 1) + self._invariant_params: Set[str] = set() + self._varying_params: Set[str] = set() + self._parameters: Dict[str, Any] = {"T_cycle": self._length} + + for key, value in parameters.items(): + self[key] = value + + def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]: + """ + Access parameters by age index or parameter name. + + Args: + item_or_key (Union[int, str]): Age index or parameter name. + + Returns: + Union[Parameters, Any]: A new Parameters object for the specified age, + or the value of the specified parameter. + + Raises: + ValueError: If the age index is out of bounds. + KeyError: If the parameter name is not found. + TypeError: If the key is neither an integer nor a string. + """ + if isinstance(item_or_key, int): + if item_or_key >= self._length: + raise ValueError( + f"Age {item_or_key} is out of bounds (max: {self._length - 1})." + ) + + params = {key: self._parameters[key] for key in self._invariant_params} + params.update( + { + key: self._parameters[key][item_or_key] + if isinstance(self._parameters[key], (list, tuple, np.ndarray)) + else self._parameters[key] + for key in self._varying_params + } + ) + return Parameters(**params) + elif isinstance(item_or_key, str): + return self._parameters[item_or_key] + else: + raise TypeError("Key must be an integer (age) or string (parameter name).") + + def __setitem__(self, key: str, value: Any) -> None: + """ + Set parameter values, automatically inferring time variance. + + Args: + key (str): Name of the parameter. + value (Any): Value of the parameter. + + Raises: + ValueError: If the parameter name is not a string or if the value type is unsupported. + ValueError: If the parameter value is inconsistent with the current model length. + """ + if not isinstance(key, str): + raise ValueError(f"Parameter name must be a string, got {type(key)}") + + if isinstance( + value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) + ): + self._invariant_params.add(key) + self._varying_params.discard(key) + elif isinstance(value, (list, tuple)): + if len(value) == 1: + value = value[0] + self._invariant_params.add(key) + self._varying_params.discard(key) + elif self._length is None or self._length == 1: + self._length = len(value) + self._varying_params.add(key) + self._invariant_params.discard(key) + elif len(value) == self._length: + self._varying_params.add(key) + self._invariant_params.discard(key) + else: + raise ValueError( + f"Parameter {key} must have length 1 or {self._length}, not {len(value)}" + ) + else: + raise ValueError(f"Unsupported type for parameter {key}: {type(value)}") + + self._parameters[key] = value + + def __getattr__(self, name: str) -> Any: + """ + Allow attribute-style access to parameters. + + Args: + name (str): Name of the parameter to access. + + Returns: + Any: The value of the specified parameter. + + Raises: + AttributeError: If the parameter name is not found. + """ + if name.startswith("_"): + return super().__getattribute__(name) + try: + return self._parameters[name] + except KeyError: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, name: str, value: Any) -> None: + """ + Allow attribute-style setting of parameters. + + Args: + name (str): Name of the parameter to set. + value (Any): Value to set for the parameter. + """ + if name.startswith("_"): + super().__setattr__(name, value) + else: + self[name] = value + + def __contains__(self, key: str) -> bool: + """ + Check if a parameter exists. + + Args: + key (str): The name of the parameter. + + Returns: + bool: True if the parameter exists, False otherwise. + """ + return key in self._parameters + + def __iter__(self) -> Iterator[str]: + """ + Iterate over parameter names. + + Returns: + Iterator[str]: An iterator over parameter names. + """ + return iter(self._parameters) + + def __len__(self) -> int: + """ + Get the number of parameters. + + Returns: + int: The number of parameters. + """ + return len(self._parameters) + + def __repr__(self) -> str: + """ + Get a string representation of the Parameters object. + + Returns: + str: A string representation of the Parameters object. + """ + return f"Parameters(_length={self._length}, _invariant_params={self._invariant_params}, _varying_params={self._varying_params}, _parameters={self._parameters})" + + def __str__(self) -> str: + """ + Get a string representation of the Parameters object. + + Returns: + str: A string representation of the Parameters object. + """ + return self.__repr__() + + def keys(self) -> Set[str]: + """ + Get the names of all parameters. + + Returns: + Set[str]: The names of all parameters. + """ + return set(self._parameters.keys()) + + def values(self) -> List[Any]: + """ + Get the values of all parameters. + + Returns: + List[Any]: The values of all parameters. + """ + return list(self._parameters.values()) + + def items(self) -> List[Tuple[str, Any]]: + """ + Get the names and values of all parameters. + + Returns: + List[Tuple[str, Any]]: The names and values of all parameters. + """ + return list(self._parameters.items()) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert parameters to a plain dictionary. + + Returns: + Dict[str, Any]: A dictionary containing all parameters. + """ + return dict(self._parameters) + + def to_namedtuple(self) -> namedtuple: + """ + Convert parameters to a namedtuple. + + Returns: + namedtuple: A namedtuple containing all parameters. + """ + return namedtuple("Parameters", self.keys())(**self.to_dict()) + + def update(self, other: Union["Parameters", Dict[str, Any]]) -> None: + """ + Update parameters from another Parameters object or dictionary. + + Args: + other (Union[Parameters, Dict[str, Any]]): The source of parameters to update from. + + Raises: + TypeError: If the input is neither a Parameters object nor a dictionary. + """ + if isinstance(other, Parameters): + for key, value in other._parameters.items(): + self[key] = value + elif isinstance(other, dict): + for key, value in other.items(): + self[key] = value + else: + raise TypeError(f"Expected Parameters or dict, got {type(other)}") + + def copy(self) -> "Parameters": + """ + Create a deep copy of the Parameters object. + + Returns: + Parameters: A new Parameters object with the same contents. + """ + return deepcopy(self) + + def add_to_time_vary(self, *params: str) -> None: + """ + Adds any number of parameters to the time-varying set. + + Args: + *params (str): Any number of strings naming parameters to be added to time_vary. + """ + for param in params: + if param in self._parameters: + self._varying_params.add(param) + + def add_to_time_inv(self, *params: str) -> None: + """ + Adds any number of parameters to the time-invariant set. + + Args: + *params (str): Any number of strings naming parameters to be added to time_inv. + """ + for param in params: + if param in self._parameters: + self._invariant_params.add(param) + + def del_from_time_vary(self, *params: str) -> None: + """ + Removes any number of parameters from the time-varying set. + + Args: + *params (str): Any number of strings naming parameters to be removed from time_vary. + """ + for param in params: + self._varying_params.discard(param) + + def del_from_time_inv(self, *params: str) -> None: + """ + Removes any number of parameters from the time-invariant set. + + Args: + *params (str): Any number of strings naming parameters to be removed from time_inv. + """ + for param in params: + self._invariant_params.discard(param) + + def get(self, key: str, default: Any = None) -> Any: + """ + Get a parameter value, returning a default if not found. + + Args: + key (str): The parameter name. + default (Any, optional): The default value to return if the key is not found. + + Returns: + Any: The parameter value or the default. + """ + return self._parameters.get(key, default) + + def set_many(self, **kwargs: Any) -> None: + """ + Set multiple parameters at once. + + Args: + **kwargs: Keyword arguments representing parameter names and values. + """ + for key, value in kwargs.items(): + self[key] = value + + def is_time_varying(self, key: str) -> bool: + """ + Check if a parameter is time-varying. + + Args: + key (str): The parameter name. + + Returns: + bool: True if the parameter is time-varying, False otherwise. + """ + return key in self._varying_params + + class AgentType(Model): """ A superclass for economic agents in the HARK framework. Each model should @@ -701,8 +1144,7 @@ def unpack(self, parameter): """ Unpacks a parameter from a solution object for easier access. After the model has been solved, the parameters (like consumption function) - reside in the attributes of each element of `ConsumerType.solution` - (e.g. `cFunc`). This method creates a (time varying) attribute of the given + reside in the attributes of each element of `ConsumerType.solution` (e.g. `cFunc`). This method creates a (time varying) attribute of the given parameter name that contains a list of functions accessible by `ConsumerType.parameter`. Parameters @@ -1576,7 +2018,7 @@ class Market(Model): A list of all the AgentTypes in this market. sow_vars : [string] Names of variables generated by the "aggregate market process" that should - be "sown" to the agents in the market. Aggregate state, etc. + "sown" to the agents in the market. Aggregate state, etc. reap_vars : [string] Names of variables to be collected ("reaped") from agents in the market to be used in the "aggregate market process". diff --git a/HARK/tests/test_core.py b/HARK/tests/test_core.py index 102d23deb..f08d05af5 100644 --- a/HARK/tests/test_core.py +++ b/HARK/tests/test_core.py @@ -182,103 +182,69 @@ def test_create_agents(self): self.assertEqual(len(self.agent_pop.agents), 12) -class test_parameters(unittest.TestCase): - def setUp(self): - self.params = Parameters( - T_cycle=3, - a=1, - b=[2, 3, 4], - c=np.array([5, 6, 7]), - d=[lambda x: x, lambda x: x**2, lambda x: x**3], - e=Uniform(), - f=[True, False, True], - ) - - def test_init(self): - self.assertEqual(self.params._length, 3) - self.assertEqual(self.params._invariant_params, {"a", "c", "e"}) - self.assertEqual(self.params._varying_params, {"b", "d", "f"}) - - def test_getitem(self): - self.assertEqual(self.params["a"], 1) - self.assertEqual(self.params[0]["b"], 2) - self.assertEqual(self.params["c"][1], 6) - - def test_setitem(self): - self.params["d"] = 8 - self.assertEqual(self.params["d"], 8) - - def test_update(self): - self.params.update({"a": 9, "b": [10, 11, 12]}) - self.assertEqual(self.params["a"], 9) - self.assertEqual(self.params[0]["b"], 10) - - def test_initialization(self): - params = Parameters(a=1, b=[1, 2], T_cycle=2) - assert params._length == 2 - assert params._invariant_params == {"a"} - assert params._varying_params == {"b"} - - def test_infer_dims_scalar(self): - params = Parameters(a=1) - assert params["a"] == 1 - - def test_infer_dims_array(self): - params = Parameters(b=np.array([1, 2])) - assert all(params["b"] == np.array([1, 2])) - - def test_infer_dims_list_varying(self): - params = Parameters(b=[1, 2], T_cycle=2) - assert params["b"] == [1, 2] - - def test_infer_dims_list_invariant(self): - params = Parameters(b=[1]) - assert params["b"] == 1 - - def test_setitem(self): - params = Parameters(a=1) - params["b"] = 2 - assert params["b"] == 2 - - def test_keys_values_items(self): - params = Parameters(a=1, b=2) - assert set(params.keys()) == {"a", "b"} - assert set(params.values()) == {1, 2} - assert set(params.items()) == {("a", 1), ("b", 2)} - - def test_to_dict(self): - params = Parameters(a=1, b=2) - assert params.to_dict() == {"a": 1, "b": 2} - - def test_to_namedtuple(self): - params = Parameters(a=1, b=2) - named_tuple = params.to_namedtuple() - assert named_tuple.a == 1 - assert named_tuple.b == 2 - - def test_update_params(self): - params1 = Parameters(a=1, b=2) - params2 = Parameters(a=3, c=4) - params1.update(params2) - assert params1["a"] == 3 - assert params1["c"] == 4 - - def test_unsupported_type_error(self): +import pytest +import numpy as np +from HARK.distribution import Uniform +from HARK.core import Parameters + + +@pytest.fixture +def sample_params(): + return Parameters(a=1, b=[2, 3, 4], c=5.0, d=[6.0, 7.0, 8.0], T_cycle=3) + + +class TestParameters: + def test_initialization(self, sample_params): + assert sample_params._length == 3 + assert sample_params._invariant_params == {"a", "c"} + assert sample_params._varying_params == {"b", "d"} + assert sample_params._parameters["T_cycle"] == 3 + + def test_getitem(self, sample_params): + assert sample_params["a"] == 1 + assert sample_params["b"] == [2, 3, 4] + assert sample_params[0]["b"] == 2 + assert sample_params[1]["d"] == 7.0 + + def test_setitem(self, sample_params): + sample_params["e"] = 9 + assert sample_params["e"] == 9 + assert "e" in sample_params._invariant_params + + sample_params["f"] = [10, 11, 12] + assert sample_params["f"] == [10, 11, 12] + assert "f" in sample_params._varying_params + + def test_get(self, sample_params): + assert sample_params.get("a") == 1 + assert sample_params.get("z", 100) == 100 + + def test_set_many(self, sample_params): + sample_params.set_many(g=13, h=[14, 15, 16]) + assert sample_params["g"] == 13 + assert sample_params["h"] == [14, 15, 16] + + def test_is_time_varying(self, sample_params): + assert sample_params.is_time_varying("b") is True + assert sample_params.is_time_varying("a") is False + + def test_to_dict(self, sample_params): + params_dict = sample_params.to_dict() + assert isinstance(params_dict, dict) + assert params_dict["a"] == 1 + assert params_dict["b"] == [2, 3, 4] + + def test_update(self, sample_params): + new_params = Parameters(a=100, e=200) + sample_params.update(new_params) + assert sample_params["a"] == 100 + assert sample_params["e"] == 200 + + @pytest.mark.parametrize("invalid_key", [1, 2.0, None, []]) + def test_setitem_invalid_key(self, sample_params, invalid_key): with pytest.raises(ValueError): - Parameters(b={1, 2}) + sample_params[invalid_key] = 42 - def test_get_item_dimension_error(self): - params = Parameters(b=[1, 2], T_cycle=2) + def test_setitem_invalid_value_length(self, sample_params): with pytest.raises(ValueError): - params[2] - - def test_getitem_with_key(self): - params = Parameters(a=1, b=[2, 3], T_cycle=2) - assert params["a"] == 1 - assert params["b"] == [2, 3] - - def test_getitem_with_item(self): - params = Parameters(a=1, b=[2, 3], T_cycle=2) - age_params = params[1] - assert age_params["a"] == 1 - assert age_params["b"] == 3 + sample_params["invalid"] = [1, 2] # Should be length 1 or 3