Skip to content

Commit

Permalink
Small updates (#231)
Browse files Browse the repository at this point in the history
* Remove unneeded dependencies

* Update docs

* Add CSV loading function

* Add `relative_change` function

* Fix Enum bugs

* Remove print statement

* Fix enum bugs

* Remove microdf imports
  • Loading branch information
nikhilwoodruff authored Aug 2, 2024
1 parent a8ff6dc commit 33a2eb0
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 24 deletions.
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
added:
- Simulation loading from dataframes.
- Simulation `start_instant` attribute.
File renamed without changes.
2 changes: 1 addition & 1 deletion policyengine_core/charts/bar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
from .formatting import *
import plotly.express as px
from microdf import MicroSeries
from policyengine_core.weighting import MicroSeries
from typing import Callable
import numpy as np

Expand Down
7 changes: 3 additions & 4 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,16 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
Returns:
Dataset: The dataset.
"""
file_path = Path(file_path)
dataset = type(
"Dataset",
(Dataset,),
{
"name": file_path.stem,
"label": file_path.stem,
"name": "dataframe",
"label": "DataFrame",
"data_format": Dataset.FLAT_FILE,
"file_path": "dataframe",
"time_period": time_period,
"load": lambda: dataframe,
"load": lambda self: dataframe,
},
)()

Expand Down
5 changes: 3 additions & 2 deletions policyengine_core/enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
# Confusingly, Numpy uses "S" to refer to byte-string arrays
# and "U" to refer to Unicode-string arrays, which are also
# referred to as the "str" type
if array.dtype.kind == "S":
if isinstance(array[0], Enum):
array = np.array([item.name for item in array])
if array.dtype.kind == "S" or array.dtype == object:
# Convert boolean array to string array
array = array.astype(str)

if isinstance(array, np.ndarray) and array.dtype.kind in {"U", "S"}:
# String array
indices = np.select(
Expand Down
9 changes: 9 additions & 0 deletions policyengine_core/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,12 @@ def _get_at_instant(self, instant):
if value_at_instant.instant_str <= instant:
return value_at_instant.value
return None

def relative_change(self, start_instant, end_instant):
start_instant = str(start_instant)
end_instant = str(end_instant)
end_value = self._get_at_instant(end_instant)
start_value = self._get_at_instant(start_instant)
if end_value is None or start_value is None:
return None
return end_value / start_value - 1
16 changes: 16 additions & 0 deletions policyengine_core/populations/group_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from policyengine_core.entities import Entity, Role
from policyengine_core.enums import EnumArray
from policyengine_core.populations.population import Population
from policyengine_core.periods.period_ import Period
from typing import Optional, Container

if TYPE_CHECKING:
from policyengine_core.simulations import Simulation
Expand All @@ -21,6 +23,20 @@ def __init__(self, entity: Entity, members: Population):
self._members_position: ArrayLike = None
self._ordered_members_map = None

def __call__(
self,
variable_name: str,
period: Period = None,
options: Optional[Container[str]] = None,
):
variable = self.simulation.tax_benefit_system.variables.get(
variable_name
)
if variable.entity.is_person:
return self.sum(self.members(variable_name, period, options))
else:
return super().__call__(variable_name, period, options)

def clone(
self, simulation: "Simulation", members: Population
) -> "GroupPopulation":
Expand Down
3 changes: 3 additions & 0 deletions policyengine_core/populations/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def clone(self, simulation: "Simulation") -> "Population":
result.ids = self.ids
return result

def has_any_input(self, variable_name: str) -> bool:
return len(self.get_holder(variable_name).get_known_periods()) > 0

def empty_array(self) -> numpy.ndarray:
return numpy.zeros(self.count)

Expand Down
2 changes: 1 addition & 1 deletion policyengine_core/simulations/microsimulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Type

from microdf import MicroDataFrame, MicroSeries
from policyengine_core.weighting import MicroDataFrame, MicroSeries
import numpy as np
from policyengine_core.data.dataset import Dataset
from policyengine_core.periods import Period
Expand Down
43 changes: 29 additions & 14 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Simulation:
macro_cache_write: bool = True
"""Whether to write to the macro cache."""

start_instant: str = None
"""The earliest data input instant of the simulation."""

def __init__(
self,
tax_benefit_system: "TaxBenefitSystem" = None,
Expand Down Expand Up @@ -155,6 +158,10 @@ def __init__(
)
if isinstance(dataset, type):
self.dataset: Dataset = dataset(require=True)
elif isinstance(dataset, pd.DataFrame):
self.dataset = Dataset.from_dataframe(
dataset, self.default_input_period
)
else:
self.dataset = dataset
self.build_from_dataset()
Expand Down Expand Up @@ -242,6 +249,9 @@ def build_from_dataset(self) -> None:
+ "Make sure you have downloaded or built it using the `policyengine-core data` command."
) from e

if self.dataset.data_format == Dataset.FLAT_FILE:
data = {col: data[col].values for col in data.columns}

person_entity = self.tax_benefit_system.person_entity
entity_id_field = f"{person_entity.key}_id"
if self.dataset.data_format != Dataset.FLAT_FILE:
Expand All @@ -250,14 +260,11 @@ def build_from_dataset(self) -> None:
), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY."
elif entity_id_field not in data:
data[entity_id_field] = np.arange(len(data))
if self.dataset.data_format != Dataset.FLAT_FILE:
get_eternity_array = lambda ds: (
ds[list(ds.keys())[0]]
if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS
else ds
)
else:
get_eternity_array = lambda ds: ds
get_eternity_array = lambda ds: (
ds[list(ds.keys())[0]]
if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS
else ds
)
entity_ids = get_eternity_array(data[entity_id_field])
builder.declare_person_entity(person_entity.key, entity_ids)

Expand All @@ -268,7 +275,12 @@ def build_from_dataset(self) -> None:
entity_id_field in data
), f"Missing {entity_id_field} column in the dataset. Each group entity must have an ID array defined for ETERNITY."
elif entity_id_field not in data:
data[entity_id_field] = np.arange(len(data))
if f"person_{group_entity.key}_id" in data:
data[entity_id_field] = np.arange(
len(np.unique(data[f"person_{group_entity.key}_id"]))
)
else:
data[entity_id_field] = np.arange(len(data))

entity_ids = get_eternity_array(data[entity_id_field])
builder.declare_entity(group_entity.key, entity_ids)
Expand Down Expand Up @@ -333,9 +345,6 @@ def build_from_dataset(self) -> None:
)

if variable_name not in self.tax_benefit_system.variables:
logging.warn(
f"Variable {variable_name} not found. Skipping."
)
continue

variable_meta = self.tax_benefit_system.get_variable(
Expand All @@ -355,7 +364,9 @@ def build_from_dataset(self) -> None:

self.set_input(variable, time_period, entity_level_data)

self.default_calculation_period = self.dataset.time_period
self.default_calculation_period = (
self.dataset.time_period or self.default_calculation_period
)

self.tax_benefit_system.data_modified = False

Expand Down Expand Up @@ -684,6 +695,8 @@ def _calculate(
):
# Variables with a calculate-output property specify
last_known_period = sorted(known_periods)[-1]
if last_known_period.start > period.start:
return holder.default_array()
array = holder.get_array(last_known_period)
else:
array = holder.default_array()
Expand Down Expand Up @@ -1139,10 +1152,12 @@ def set_input(

If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
"""
period = periods.period(period)
if self.start_instant is None or self.start_instant > period.start:
self.start_instant = period.start
variable = self.tax_benefit_system.get_variable(
variable_name, check_existence=True
)
period = periods.period(period)
if (variable.end is not None) and (period.start.date > variable.end):
return
self.get_holder(variable_name).set_input(
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
"psutil<6",
"wheel<1",
"h5py>=3,<4",
"microdf_python>=0.3.0,<1",
"tqdm>=4.46.0,<5",
"requests>=2.27.1,<3",
"pandas>=1",
"plotly>=5.6.0,<6",
Expand Down

0 comments on commit 33a2eb0

Please sign in to comment.