Type hints, naming, and 2 non-working draft model_validators
zaneselvans committed Nov 23, 2023
1 parent f757848 commit ca173dc
Showing 1 changed file with 117 additions and 78 deletions.
195 changes: 117 additions & 78 deletions src/pudl/settings.py
@@ -2,14 +2,15 @@
import itertools
import json
from enum import Enum, unique
from typing import Any, ClassVar
from typing import Any, ClassVar, Self

import fsspec
import pandas as pd
import yaml
from dagster import Field as DagsterField
from pydantic import BaseModel as PydanticBaseModel
from pydantic import (
AnyHttpUrl,
BaseModel,
ConfigDict,
field_validator,
model_validator,
@@ -20,7 +21,7 @@
import pudl
import pudl.workspace.setup
from pudl.metadata.classes import DataSource
from pudl.workspace.datastore import Datastore
from pudl.workspace.datastore import Datastore, ZenodoDoi

logger = pudl.logging_helpers.get_logger(__name__)

@@ -36,13 +37,13 @@ class XbrlFormNumber(Enum):
FORM714 = 714


class BaseModel(PydanticBaseModel):
class FrozenBaseModel(BaseModel):
"""BaseModel with global configuration."""

model_config = ConfigDict(frozen=True, extra="forbid")
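
For context, frozen=True makes validated instances immutable (attribute assignment raises a ValidationError) and extra="forbid" rejects unrecognized fields. A minimal sketch of both behaviors, using a hypothetical Example model (not part of this commit):

from pydantic import BaseModel, ConfigDict, ValidationError

class Example(BaseModel):
    model_config = ConfigDict(frozen=True, extra="forbid")
    years: list[int] = []

ex = Example(years=[2020, 2021])
try:
    ex.years = [2022]  # frozen: assignment is rejected
except ValidationError as err:
    print(err.errors()[0]["type"])  # frozen_instance

try:
    Example(years=[2020], bogus=1)  # extra="forbid": unknown fields rejected
except ValidationError as err:
    print(err.errors()[0]["type"])  # extra_forbidden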


class GenericDatasetSettings(BaseModel):
class GenericDatasetSettings(FrozenBaseModel):
"""An abstract pydantic model for generic datasets.
Each dataset must specify working partitions. A dataset can have an arbitrary number
@@ -53,11 +54,63 @@ class GenericDatasetSettings(BaseModel):
"""

disabled: bool = False
data_source: ClassVar[DataSource]

# TODO[pydantic]: This validator fails because it doesn't reproduce the current
# behavior we have on unspecified datasets.
# @model_validator(mode="before")
@classmethod
def validate_partitions_before(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Refactor with model validator after."""
for (
partition_name,
working_partitions,
) in cls.data_source.working_partitions.items():
try:
partitions = data[partition_name]
except KeyError:
raise ValueError(
f"{cls.__name__} is missing required '{partition_name}' field."
)

# If partition is empty or None, default to using all working_partitions
if not partitions:
data[partition_name] = working_partitions
else:
data[partition_name] = sorted(set(partitions))

nonworking_partitions = list(set(partitions) - set(working_partitions))
if nonworking_partitions:
raise ValueError(
f"'{nonworking_partitions}' {partition_name} are not available."
)

return data

# TODO[pydantic]: This validator fails because the model is immutable after it's
# been defined.
# @model_validator(mode="after")
def validate_partitions_after(self: Self):
"""Refactor with model validator after."""
for name, working_partitions in self.data_source.working_partitions.items():
try:
partition = getattr(self, name)
except KeyError:
raise ValueError(f"{self.__name__} is missing required '{name}' field.")

# If partition is None, default to working_partitions
if getattr(self, name) is None:
setattr(self, name, working_partitions)

nonworking_partitions = list(set(partition) - set(working_partitions))
if nonworking_partitions:
raise ValueError(f"'{nonworking_partitions}' {name} are not available.")
setattr(self, name, sorted(set(partition)))
return self
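
For reference, one way to make this pattern work in pydantic v2 is to do all of the defaulting on the raw input in mode="before", while the data is still a mutable dict, rather than mutating the frozen instance afterward. A sketch under that assumption (it mirrors the drafts above, is not the committed implementation, and assumes the input arrives as a dict keyed by partition name):

    @model_validator(mode="before")
    @classmethod
    def _default_and_check_partitions(cls, data: dict[str, Any]) -> dict[str, Any]:
        for name, working in cls.data_source.working_partitions.items():
            requested = data.get(name)
            # Unspecified, None, or empty partitions default to all working ones.
            data[name] = sorted(set(requested)) if requested else sorted(working)
            nonworking = set(data[name]) - set(working)
            if nonworking:
                raise ValueError(f"{sorted(nonworking)} {name} are not available.")
        return data

Alternatively, a mode="after" validator can bypass the frozen check with object.__setattr__(self, name, value), at the cost of mutating a nominally immutable instance.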

# TODO[pydantic]: Refactor to use model_validator
@root_validator(skip_on_failure=True)
@classmethod
def validate_partitions(cls, partitions: Any): # noqa: N805
def validate_partitions(cls, partitions):
"""Validate the requested data partitions.
Check that all the partitions defined in the ``working_partitions`` of the
@@ -235,7 +288,7 @@ def check_eia860m_date(cls, eia860m: bool) -> bool: # noqa: N805
return eia860m


class GlueSettings(BaseModel):
class GlueSettings(FrozenBaseModel):
"""An immutable pydantic model to validate Glue settings.
Args:
@@ -247,7 +300,7 @@ class GlueSettings(BaseModel):
ferc1: bool = True


class EiaSettings(BaseModel):
class EiaSettings(FrozenBaseModel):
"""An immutable pydantic model to validate EIA datasets settings.
Args:
@@ -262,25 +315,18 @@ class EiaSettings(BaseModel):

@model_validator(mode="before")
@classmethod
def default_load_all(cls, values): # noqa: N805
"""If no datasets are specified default to all.
def default_load_all(cls, data: dict[str, Any]) -> dict[str, Any]:
"""If no datasets are specified default to all."""
if not any(data.values()):
data["eia860"] = Eia860Settings()
data["eia861"] = Eia861Settings()
data["eia923"] = Eia923Settings()

Args:
values (Dict[str, BaseModel]): dataset settings.
Returns:
values (Dict[str, BaseModel]): dataset settings.
"""
if not any(values.values()):
values["eia860"] = Eia860Settings()
values["eia861"] = Eia861Settings()
values["eia923"] = Eia923Settings()

return values
return data

@model_validator(mode="before")
@classmethod
def check_eia_dependencies(cls, values): # noqa: N805
def check_eia_dependencies(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Make sure the dependencies between the eia datasets are satisfied.
Dependencies:
@@ -292,20 +338,18 @@ def check_eia_dependencies(cls, values): # noqa: N805
Returns:
values (Dict[str, BaseModel]): dataset settings.
"""
if not values.get("eia923") and values.get("eia860"):
values["eia923"] = Eia923Settings(years=values["eia860"].years)
if not data.get("eia923") and data.get("eia860"):
data["eia923"] = Eia923Settings(years=data["eia860"].years)

if values.get("eia923") and not values.get("eia860"):
if data.get("eia923") and not data.get("eia860"):
available_years = Eia860Settings().years
values["eia860"] = Eia860Settings(
years=[
year for year in values["eia923"].years if year in available_years
]
data["eia860"] = Eia860Settings(
years=[year for year in data["eia923"].years if year in available_years]
)
return values
return data
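
Assuming the default settings classes above, the dependency backfill behaves roughly like this illustrative sketch (not from the commit):

    settings = EiaSettings(eia923=Eia923Settings(years=[2020, 2021]))
    # Only eia923 was requested, but the validator backfills eia860 with the
    # overlapping available years, so both datasets end up populated.
    assert settings.eia860 is not None
    assert set(settings.eia860.years) <= set(Eia860Settings().years)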


class DatasetsSettings(BaseModel):
class DatasetsSettings(FrozenBaseModel):
"""An immutable pydantic model to validate PUDL Dataset settings.
Args:
@@ -315,35 +359,35 @@ class DatasetsSettings(BaseModel):
epacems: Immutable pydantic model to validate epacems settings.
"""

eia: EiaSettings = None
epacems: EpaCemsSettings = None
ferc1: Ferc1Settings = None
ferc714: Ferc714Settings = None
glue: GlueSettings = None
eia: EiaSettings | None = None
epacems: EpaCemsSettings | None = None
ferc1: Ferc1Settings | None = None
ferc714: Ferc714Settings | None = None
glue: GlueSettings | None = None

@model_validator(mode="before")
@classmethod
def default_load_all(cls, values): # noqa: N805
def default_load_all(cls, data: dict[str, Any]) -> dict[str, Any]:
"""If no datasets are specified default to all.
Args:
values (Dict[str, BaseModel]): dataset settings.
data: dataset settings inputs.
Returns:
values (Dict[str, BaseModel]): dataset settings.
Validated dataset settings inputs.
"""
if not any(values.values()):
values["eia"] = EiaSettings()
values["epacems"] = EpaCemsSettings()
values["ferc1"] = Ferc1Settings()
values["ferc714"] = Ferc714Settings()
values["glue"] = GlueSettings()
if not any(data.values()):
data["eia"] = EiaSettings()
data["epacems"] = EpaCemsSettings()
data["ferc1"] = Ferc1Settings()
data["ferc714"] = Ferc714Settings()
data["glue"] = GlueSettings()

return values
return data

@model_validator(mode="before")
@classmethod
def add_glue_settings(cls, data: Any): # noqa: N805
def add_glue_settings(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Add glue settings if ferc1 and eia data are both requested.
Args:
@@ -358,11 +402,11 @@ def add_glue_settings(cls, data: Any): # noqa: N805
data["glue"] = GlueSettings(ferc1=ferc1, eia=eia)
return data

def get_datasets(self): # noqa: N805
def get_datasets(self: Self):
"""Gets dictionary of dataset settings."""
return vars(self)
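
Worth noting: vars(self) returns the live sub-model instances from the validated model's __dict__, rather than recursively dumping them to plain dicts the way model_dump() would. An illustrative check, assuming the defaults above:

    datasets = DatasetsSettings().get_datasets()
    assert isinstance(datasets["eia"], EiaSettings)  # instances, not plain dicts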

def make_datasources_table(self, ds: Datastore) -> pd.DataFrame:
def make_datasources_table(self: Self, ds: Datastore) -> pd.DataFrame:
"""Compile a table of dataset information.
There are three places we can look for information about a dataset:
Expand Down Expand Up @@ -409,7 +453,7 @@ def make_datasources_table(self, ds: Datastore) -> pd.DataFrame:
for dataset in datasets
],
"doi": [
_make_doi_clickable(ds.get_datapackage_descriptor(dataset).doi)
str(_zenodo_doi_to_url(ds.get_datapackage_descriptor(dataset).doi))
for dataset in datasets
],
}
@@ -431,8 +475,10 @@ def make_datasources_table(self, ds: Datastore) -> pd.DataFrame:
)
],
"doi": [
_make_doi_clickable(
ds.get_datapackage_descriptor("eia860m").doi
str(
_zenodo_doi_to_url(
ds.get_datapackage_descriptor("eia860m").doi
)
)
],
}
@@ -612,27 +658,20 @@ class FercToSqliteSettings(BaseSettings):

@model_validator(mode="before")
@classmethod
def default_load_all(cls, values): # noqa: N805
"""If no datasets are specified default to all.
Args:
values (Dict[str, BaseModel]): dataset settings.
def default_load_all(cls, data: dict[str, Any]) -> dict[str, Any]:
"""If no datasets are specified default to all."""
if not any(data.values()):
data["ferc1_dbf_to_sqlite_settings"] = Ferc1DbfToSqliteSettings()
data["ferc1_xbrl_to_sqlite_settings"] = Ferc1XbrlToSqliteSettings()
data["ferc2_dbf_to_sqlite_settings"] = Ferc2DbfToSqliteSettings()
data["ferc2_xbrl_to_sqlite_settings"] = Ferc2XbrlToSqliteSettings()
data["ferc6_dbf_to_sqlite_settings"] = Ferc6DbfToSqliteSettings()
data["ferc6_xbrl_to_sqlite_settings"] = Ferc6XbrlToSqliteSettings()
data["ferc60_dbf_to_sqlite_settings"] = Ferc60DbfToSqliteSettings()
data["ferc60_xbrl_to_sqlite_settings"] = Ferc60XbrlToSqliteSettings()
data["ferc714_xbrl_to_sqlite_settings"] = Ferc714XbrlToSqliteSettings()

Returns:
values (Dict[str, BaseModel]): dataset settings.
"""
if not any(values.values()):
values["ferc1_dbf_to_sqlite_settings"] = Ferc1DbfToSqliteSettings()
values["ferc1_xbrl_to_sqlite_settings"] = Ferc1XbrlToSqliteSettings()
values["ferc2_dbf_to_sqlite_settings"] = Ferc2DbfToSqliteSettings()
values["ferc2_xbrl_to_sqlite_settings"] = Ferc2XbrlToSqliteSettings()
values["ferc6_dbf_to_sqlite_settings"] = Ferc6DbfToSqliteSettings()
values["ferc6_xbrl_to_sqlite_settings"] = Ferc6XbrlToSqliteSettings()
values["ferc60_dbf_to_sqlite_settings"] = Ferc60DbfToSqliteSettings()
values["ferc60_xbrl_to_sqlite_settings"] = Ferc60XbrlToSqliteSettings()
values["ferc714_xbrl_to_sqlite_settings"] = Ferc714XbrlToSqliteSettings()

return values
return data

def get_xbrl_dataset_settings(
self, form_number: XbrlFormNumber
@@ -705,7 +744,7 @@ def _convert_settings_to_dagster_config(d: dict) -> None:
d[k] = DagsterField(type(v), default_value=v)
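
Only one line of _convert_settings_to_dagster_config survives in this hunk. The surrounding function presumably recurses through the nested settings dict and wraps each leaf value in a DagsterField; a hypothetical standalone sketch of that shape (not necessarily the committed code):

    def _convert_settings_to_dagster_config(d: dict) -> None:
        """Recursively wrap leaf values in DagsterField, mutating d in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                _convert_settings_to_dagster_config(v)
            else:
                d[k] = DagsterField(type(v), default_value=v)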


def create_dagster_config(settings: BaseModel) -> dict:
def create_dagster_config(settings: FrozenBaseModel) -> dict:
"""Create a dictionary of dagster config for the DatasetsSettings Class.
Returns:
Expand All @@ -716,6 +755,6 @@ def create_dagster_config(settings: BaseModel) -> dict:
return ds


def _make_doi_clickable(link):
"""Make a clickable DOI."""
return f"https://doi.org/{link}"
def _zenodo_doi_to_url(doi: ZenodoDoi) -> AnyHttpUrl:
"""Create a DOI URL out o a Zenodo DOI."""
return AnyHttpUrl(f"https://doi.org/{doi}")
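
A quick usage sketch of the renamed helper (the DOI value is hypothetical and passed as a plain string for illustration):

    url = _zenodo_doi_to_url("10.5281/zenodo.123456")
    print(url)  # https://doi.org/10.5281/zenodo.123456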
