From cf8d4f1fd9cbe48ae51855c3af09b984a7fac745 Mon Sep 17 00:00:00 2001 From: Zane Selvans Date: Wed, 15 Nov 2023 15:55:46 -0600 Subject: [PATCH] WIP: Initial refactor for pydantic v2 * import missing modules, remove unused demonstration asset module. * Import working; syntax issues fixed; real unit tests broken. --- .gitignore | 2 + pyproject.toml | 10 +- src/pudl/__init__.py | 1 + src/pudl/etl/__init__.py | 1 + src/pudl/etl/analysis_assets.py | 34 --- src/pudl/metadata/classes.py | 224 ++++++++---------- src/pudl/metadata/resources/eia860.py | 13 +- src/pudl/metadata/resources/eia923.py | 21 +- .../resources/ferc1_eia_record_linkage.py | 6 +- src/pudl/metadata/resources/glue.py | 6 +- src/pudl/output/ferc1.py | 56 +++-- src/pudl/settings.py | 139 +++++------ src/pudl/transform/classes.py | 83 ++++--- src/pudl/transform/ferc1.py | 27 ++- src/pudl/workspace/datastore.py | 22 +- src/pudl/workspace/setup.py | 33 +-- test/conftest.py | 5 +- test/unit/settings_test.py | 8 +- 18 files changed, 302 insertions(+), 389 deletions(-) delete mode 100644 src/pudl/etl/analysis_assets.py diff --git a/.gitignore b/.gitignore index 173bc5da8b..a582a409d5 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ docs/data_dictionaries/pudl_db.rst .ipynb_checkpoints/ .cache/ +.ruff_cache/ +.mypy_cache/ .pytest_cache/* .DS_Store build/ diff --git a/pyproject.toml b/pyproject.toml index 1d250f22e5..b5166cef6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,10 +50,8 @@ dependencies = [ "pandera>=0.17.2", "pre-commit>=3", "pyarrow>=14.0.1", # pandas[parquet] - "pydantic>=1.10,<2", - # Required after pandera-core is retired and we switch to pydantic v2 - #"pydantic>=2.4", - #"pydantic-settings>=2", + "pydantic>=2.4", + "pydantic-settings>=2", "pytest>=7.4", "pytest-cov>=4.1", "pytest-console-scripts>=1.4", @@ -220,10 +218,6 @@ exclude = ["migrations/versions/*"] "test/*" = ["D"] "migrations/*" = ["D", "Q"] -[tool.ruff.pep8-naming] -# Allow Pydantic's `@validator` decorator to trigger class method treatment. -classmethod-decorators = ["pydantic.validator", "pydantic.root_validator"] - [tool.ruff.isort] known-first-party = ["pudl"] diff --git a/src/pudl/__init__.py b/src/pudl/__init__.py index 17048a6e3c..a578d8f855 100644 --- a/src/pudl/__init__.py +++ b/src/pudl/__init__.py @@ -8,6 +8,7 @@ convert, etl, extract, + ferc_to_sqlite, glue, helpers, io_managers, diff --git a/src/pudl/etl/__init__.py b/src/pudl/etl/__init__.py index e62d1003c7..d49ce5387b 100644 --- a/src/pudl/etl/__init__.py +++ b/src/pudl/etl/__init__.py @@ -21,6 +21,7 @@ from pudl.settings import EtlSettings from . import ( + check_foreign_keys, eia_bulk_elec_assets, epacems_assets, glue_assets, diff --git a/src/pudl/etl/analysis_assets.py b/src/pudl/etl/analysis_assets.py deleted file mode 100644 index a605000bb9..0000000000 --- a/src/pudl/etl/analysis_assets.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Derived / analysis assets that aren't simple to construct. - -This is really too large & generic of a category. Should we have an asset group for each -set of related analyses? E.g. - -* mcoe_assets -* service_territory_assets -* heat_rate_assets -* state_demand_assets -* depreciation_assets -* plant_parts_eia_assets -* ferc1_eia_record_linkage_assets - -Not sure what the right organization is but they'll be defined across a bunch of -different modules. Eventually I imagine these would just be the novel derived values, -probably in pretty skinny tables, which get joined / aggregated with other data in the -denormalized tables. 
-""" -import pandas as pd -from dagster import asset - -import pudl - -logger = pudl.logging_helpers.get_logger(__name__) - - -@asset(io_manager_key="pudl_sqlite_io_manager", compute_kind="Python") -def utility_analysis(utils_eia860: pd.DataFrame) -> pd.DataFrame: - """Example of how to create an analysis table that depends on an output view. - - This final dataframe will be written to the database (without a schema). - """ - # Do some analysis on utils_eia860 - return utils_eia860 diff --git a/src/pudl/metadata/classes.py b/src/pudl/metadata/classes.py index a5e8f0be31..a54975753e 100644 --- a/src/pudl/metadata/classes.py +++ b/src/pudl/metadata/classes.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Iterable from functools import lru_cache from pathlib import Path -from typing import Any, Literal +from typing import Annotated, Any, Literal import jinja2 import pandas as pd @@ -15,6 +15,7 @@ import pydantic import sqlalchemy as sa from pandas._libs.missing import NAType +from pydantic import ConfigDict, StringConstraints, ValidationInfo from pydantic.types import DirectoryPath import pudl.logging_helpers @@ -161,13 +162,7 @@ class Base(pydantic.BaseModel): {'fields': ['y']} """ - class Config: - """Custom Pydantic configuration.""" - - validate_all: bool = True - validate_assignment: bool = True - extra: str = "forbid" - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) def dict(self, *args, by_alias=True, **kwargs) -> dict: # noqa: A003 """Return as a dictionary.""" @@ -200,12 +195,17 @@ def __repr_args__(self) -> list[tuple[str, Any]]: # ---- Class attribute types ---- # # NOTE: Using regex=r"^\S(.*\S)*$" to fail on whitespace is too slow -String = pydantic.constr(min_length=1, strict=True, regex=r"^\S+(\s+\S+)*$") +String = Annotated[ + str, StringConstraints(min_length=1, strict=True, pattern=r"^\S+(\s+\S+)*$") +] """Non-empty :class:`str` with no trailing or leading whitespace.""" -SnakeCase = pydantic.constr( - min_length=1, strict=True, regex=r"^[a-z_][a-z0-9_]*(_[a-z0-9]+)*$" -) +SnakeCase = Annotated[ + str, + StringConstraints( + min_length=1, strict=True, pattern=r"^[a-z_][a-z0-9_]*(_[a-z0-9]+)*$" + ), +] """Snake-case variable name :class:`str` (e.g. 
'pudl', 'entity_eia860').""" Bool = pydantic.StrictBool @@ -217,10 +217,10 @@ def __repr_args__(self) -> list[tuple[str, Any]]: Int = pydantic.StrictInt """Any :class:`int`.""" -PositiveInt = pydantic.conint(ge=0, strict=True) +PositiveInt = Annotated[int, pydantic.Field(ge=0, strict=True)] """Positive :class:`int`.""" -PositiveFloat = pydantic.confloat(ge=0, strict=True) +PositiveFloat = Annotated[float, pydantic.Field(ge=0, strict=True)] """Positive :class:`float`.""" Email = pydantic.EmailStr @@ -230,60 +230,13 @@ def __repr_args__(self) -> list[tuple[str, Any]]: """Http(s) URL.""" -class BaseType: - """Base class for custom pydantic types.""" - - @classmethod - def __get_validators__(cls) -> Callable: - """Yield validator methods.""" - yield cls.validate - - -class Date(BaseType): - """Any :class:`datetime.date`.""" - - @classmethod - def validate(cls, value: Any) -> datetime.date: - """Validate as date.""" - if not isinstance(value, datetime.date): - raise TypeError("value is not a date") - return value - - -class Datetime(BaseType): - """Any :class:`datetime.datetime`.""" - - @classmethod - def validate(cls, value: Any) -> datetime.datetime: - """Validate as datetime.""" - if not isinstance(value, datetime.datetime): - raise TypeError("value is not a datetime") - return value - - -class Pattern(BaseType): - """Regular expression pattern.""" - - @classmethod - def validate(cls, value: Any) -> re.Pattern: - """Validate as pattern.""" - if not isinstance(value, str | re.Pattern): - raise TypeError("value is not a string or compiled regular expression") - if isinstance(value, str): - try: - value = re.compile(value) - except re.error: - raise ValueError("string is not a valid regular expression") - return value - - -def StrictList(item_type: type = Any) -> pydantic.ConstrainedList: # noqa: N802 +def StrictList(item_type: type = Any) -> type: # noqa: N802 """Non-empty :class:`list`. Allows :class:`list`, :class:`tuple`, :class:`set`, :class:`frozenset`, :class:`collections.deque`, or generators and casts to a :class:`list`. """ - return pydantic.conlist(item_type=item_type, min_items=1) + return Annotated[list[item_type], pydantic.Field(min_length=1)] # ---- Class attribute validators ---- # @@ -303,7 +256,7 @@ def _validator(*names, fn: Callable) -> Callable: Args: names: Names of attributes to validate. - fn: Validation function (see :meth:`pydantic.validator`). + fn: Validation function (see :meth:`pydantic.field_validator`). Examples: >>> class Class(Base): @@ -313,7 +266,7 @@ def _validator(*names, fn: Callable) -> Callable: Traceback (most recent call last): ValidationError: ... 
""" - return pydantic.validator(*names, allow_reuse=True)(fn) + return pydantic.field_validator(*names)(fn) # ---- Classes: Field ---- # @@ -329,17 +282,20 @@ class FieldConstraints(Base): unique: Bool = False min_length: PositiveInt = None max_length: PositiveInt = None - minimum: Int | Float | Date | Datetime = None - maximum: Int | Float | Date | Datetime = None - pattern: Pattern = None + minimum: Int | Float | datetime.date | datetime.datetime = None + maximum: Int | Float | datetime.date | datetime.datetime = None + pattern: re.Pattern = None # TODO: Replace with String (min_length=1) once "" removed from enums - enum: StrictList(pydantic.StrictStr | Int | Float | Bool | Date | Datetime) = None + enum: StrictList( + pydantic.StrictStr | Int | Float | Bool | datetime.date | datetime.datetime + ) = None _check_unique = _validator("enum", fn=_check_unique) - @pydantic.validator("max_length") - def _check_max_length(cls, value, values): # noqa: N805 - minimum, maximum = values.get("min_length"), value + @pydantic.field_validator("max_length") + @classmethod + def _check_max_length(cls, value, info: ValidationInfo): + minimum, maximum = info.data.get("min_length"), value if minimum is not None and maximum is not None: if type(minimum) is not type(maximum): raise ValueError("must be same type as min_length") @@ -347,9 +303,10 @@ def _check_max_length(cls, value, values): # noqa: N805 raise ValueError("must be greater or equal to min_length") return value - @pydantic.validator("maximum") - def _check_max(cls, value, values): # noqa: N805 - minimum, maximum = values.get("minimum"), value + @pydantic.field_validator("maximum") + @classmethod + def _check_max(cls, value, info: ValidationInfo): + minimum, maximum = info.data.get("minimum"), value if minimum is not None and maximum is not None: if type(minimum) is not type(maximum): raise ValueError("must be same type as minimum") @@ -434,7 +391,8 @@ class Encoder(Base): name: String = None """The name of the code.""" - @pydantic.validator("df") + @pydantic.field_validator("df") + @classmethod def _df_is_encoding_table(cls, df): # noqa: N805 """Verify that the coding table provides both codes and descriptions.""" errors = [] @@ -449,52 +407,56 @@ def _df_is_encoding_table(cls, df): # noqa: N805 raise ValueError(format_errors(*errors, pydantic=True)) return df - @pydantic.validator("ignored_codes") - def _good_and_ignored_codes_are_disjoint(cls, ignored_codes, values): # noqa: N805 + @pydantic.field_validator("ignored_codes") + @classmethod + def _good_and_ignored_codes_are_disjoint(cls, ignored_codes, info: ValidationInfo): """Check that there's no overlap between good and ignored codes.""" - if "df" not in values: + if "df" not in info.data: return ignored_codes errors = [] - overlap = set(values["df"]["code"]).intersection(ignored_codes) + overlap = set(info.data["df"]["code"]).intersection(ignored_codes) if overlap: errors.append(f"Overlap found between good and ignored codes: {overlap}.") if errors: raise ValueError(format_errors(*errors, pydantic=True)) return ignored_codes - @pydantic.validator("code_fixes") - def _good_and_fixable_codes_are_disjoint(cls, code_fixes, values): # noqa: N805 + @pydantic.field_validator("code_fixes") + @classmethod + def _good_and_fixable_codes_are_disjoint(cls, code_fixes, info: ValidationInfo): """Check that there's no overlap between the good and fixable codes.""" - if "df" not in values: + if "df" not in info.data: return code_fixes errors = [] - overlap = set(values["df"]["code"]).intersection(code_fixes) 
+ overlap = set(info.data["df"]["code"]).intersection(code_fixes) if overlap: errors.append(f"Overlap found between good and fixable codes: {overlap}") if errors: raise ValueError(format_errors(*errors, pydantic=True)) return code_fixes - @pydantic.validator("code_fixes") - def _fixable_and_ignored_codes_are_disjoint(cls, code_fixes, values): # noqa: N805 + @pydantic.field_validator("code_fixes") + @classmethod + def _fixable_and_ignored_codes_are_disjoint(cls, code_fixes, info: ValidationInfo): """Check that there's no overlap between the ignored and fixable codes.""" - if "ignored_codes" not in values: + if "ignored_codes" not in info.data: return code_fixes errors = [] - overlap = set(code_fixes).intersection(values["ignored_codes"]) + overlap = set(code_fixes).intersection(info.data["ignored_codes"]) if overlap: errors.append(f"Overlap found between fixable and ignored codes: {overlap}") if errors: raise ValueError(format_errors(*errors, pydantic=True)) return code_fixes - @pydantic.validator("code_fixes") - def _check_fixed_codes_are_good_codes(cls, code_fixes, values): # noqa: N805 + @pydantic.field_validator("code_fixes") + @classmethod + def _check_fixed_codes_are_good_codes(cls, code_fixes, info: ValidationInfo): """Check that every every fixed code is also one of the good codes.""" - if "df" not in values: + if "df" not in info.data: return code_fixes errors = [] - bad_codes = set(code_fixes.values()).difference(values["df"]["code"]) + bad_codes = set(code_fixes.values()).difference(info.data["df"]["code"]) if bad_codes: errors.append( f"Some fixed codes aren't in the list of good codes: {bad_codes}" @@ -594,15 +556,16 @@ class Field(Base): format: Literal["default"] = "default" # noqa: A003 description: String = None unit: String = None - constraints: FieldConstraints = {} - harvest: FieldHarvest = {} - encoder: Encoder = None + constraints: FieldConstraints = FieldConstraints() + harvest: FieldHarvest = FieldHarvest() + encoder: Encoder | None = None - @pydantic.validator("constraints") - def _check_constraints(cls, value, values): # noqa: N805, C901 - if "type" not in values: + @pydantic.field_validator("constraints") + @classmethod + def _check_constraints(cls, value, info: ValidationInfo): # noqa: C901 + if "type" not in info.data: return value - dtype = values["type"] + dtype = info.data["type"] errors = [] for key in ("min_length", "max_length", "pattern"): if getattr(value, key) is not None and dtype != "string": @@ -622,12 +585,13 @@ def _check_constraints(cls, value, values): # noqa: N805, C901 raise ValueError(format_errors(*errors, pydantic=True)) return value - @pydantic.validator("encoder") - def _check_encoder(cls, value, values): # noqa: N805 - if "type" not in values or value is None: + @pydantic.field_validator("encoder") + @classmethod + def _check_encoder(cls, value, info: ValidationInfo): + if "type" not in info.data or value is None: return value errors = [] - dtype = values["type"] + dtype = info.data["type"] if dtype not in ["string", "integer"]: errors.append( "Encoding only supported for string and integer fields, found " @@ -772,9 +736,10 @@ class ForeignKey(Base): _check_unique = _validator("fields_", fn=_check_unique) - @pydantic.validator("reference") - def _check_fields_equal_length(cls, value, values): # noqa: N805 - if "fields_" in values and len(value.fields) != len(values["fields_"]): + @pydantic.field_validator("reference") + @classmethod + def _check_fields_equal_length(cls, value, info: ValidationInfo): + if "fields_" in info.data and 
len(value.fields) != len(info.data["fields_"]): raise ValueError("fields and reference.fields are not equal length") return value @@ -805,20 +770,22 @@ class Schema(Base): "missing_values", "primary_key", "foreign_keys", fn=_check_unique ) - @pydantic.validator("fields_") + @pydantic.field_validator("fields_") + @classmethod def _check_field_names_unique(cls, value): # noqa: N805 _check_unique([f.name for f in value]) return value - @pydantic.validator("primary_key") - def _check_primary_key_in_fields(cls, value, values): # noqa: N805 - if value is not None and "fields_" in values: + @pydantic.field_validator("primary_key") + @classmethod + def _check_primary_key_in_fields(cls, value, info: ValidationInfo): + if value is not None and "fields_" in info.data: missing = [] - names = [f.name for f in values["fields_"]] + names = [f.name for f in info.data["fields_"]] for name in value: if name in names: # Flag primary key fields as required - field = values["fields_"][names.index(name)] + field = info.data["fields_"][names.index(name)] field.constraints.required = True else: missing.append(field.name) @@ -826,14 +793,15 @@ def _check_primary_key_in_fields(cls, value, values): # noqa: N805 raise ValueError(f"names {missing} missing from fields") return value - @pydantic.validator("foreign_keys", each_item=True) - def _check_foreign_key_in_fields(cls, value, values): # noqa: N805 - if value and "fields_" in values: - names = [f.name for f in values["fields_"]] - missing = [x for x in value.fields if x not in names] - if missing: - raise ValueError(f"names {missing} missing from fields") - return value + # TODO[pydantic] Refactor... + # @pydantic.validator("foreign_keys", each_item=True) + # def _check_foreign_key_in_fields(cls, value, info: ValidationInfo): + # if value and "fields_" in info.data: + # names = [f.name for f in info.data["fields_"]] + # missing = [x for x in value.fields if x not in names] + # if missing: + # raise ValueError(f"names {missing} missing from fields") + # return value class License(Base): @@ -1174,7 +1142,7 @@ class Resource(Base): name: SnakeCase title: String = None description: String = None - harvest: ResourceHarvest = {} + harvest: ResourceHarvest = ResourceHarvest() schema_: Schema = pydantic.Field(alias="schema") format_: String = pydantic.Field(alias="format", default=None) mediatype: String = None @@ -1185,7 +1153,7 @@ class Resource(Base): licenses: list[License] = [] sources: list[DataSource] = [] keywords: list[String] = [] - encoder: Encoder = None + encoder: Encoder | None = None field_namespace: Literal[ "eia", "epacems", @@ -1222,9 +1190,10 @@ class Resource(Base): "contributors", "keywords", "licenses", "sources", fn=_check_unique ) - @pydantic.validator("schema_") - def _check_harvest_primary_key(cls, value, values): # noqa: N805 - if values["harvest"].harvest and not value.primary_key: + @pydantic.field_validator("schema_") + @classmethod + def _check_harvest_primary_key(cls, value, info: ValidationInfo): + if info.data["harvest"].harvest and not value.primary_key: raise ValueError("Harvesting requires a primary key") return value @@ -1745,14 +1714,15 @@ class Package(Base): description: String = None keywords: list[String] = [] homepage: HttpUrl = "https://catalyst.coop/pudl" - created: Datetime = datetime.datetime.utcnow() + created: datetime.datetime = datetime.datetime.utcnow() contributors: list[Contributor] = [] sources: list[DataSource] = [] licenses: list[License] = [] resources: StrictList(Resource) profile: String = "tabular-data-package" 
- @pydantic.validator("resources") + @pydantic.field_validator("resources") + @classmethod def _check_foreign_keys(cls, value): # noqa: N805 rnames = [resource.name for resource in value] errors = [] diff --git a/src/pudl/metadata/resources/eia860.py b/src/pudl/metadata/resources/eia860.py index 22ec9424c6..cdab8d5539 100644 --- a/src/pudl/metadata/resources/eia860.py +++ b/src/pudl/metadata/resources/eia860.py @@ -501,8 +501,8 @@ "description": ( """The cost, type, operating status, retirement date, and install year of emissions control equipment reported to EIA. Includes control ids for sulfur dioxide -(SO2), particulate matter, mercury, nitrogen oxide (NOX), and acid (HCl) gas monitoring. -""" +(SO2), particulate matter, mercury, nitrogen oxide (NOX), and acid (HCl) gas +monitoring.""" ), "schema": { "fields": [ @@ -533,8 +533,7 @@ emissions control equipment reported to EIA. Includes control ids for sulfur dioxide (SO2), particulate matter, mercury, nitrogen oxide (NOX), and acid (HCl) gas monitoring. The denormalized version contains plant name, utility id, pudl id, and utility name -columns. -""" +columns.""" ), "schema": { "fields": [ @@ -569,8 +568,7 @@ "description": ( """A table that links EIA boiler IDs to emissions control IDs for NOx, SO2, mercury, and particulate monitoring. The relationship between the IDs is sometimes many -to many. -""" +to many.""" ), "schema": { "fields": [ @@ -617,8 +615,7 @@ "boiler_stack_flue_assn_eia860": { "description": ( """A table that links EIA boiler IDs to EIA stack and/or flue -system IDs. -""" +system IDs.""" ), "schema": { "fields": [ diff --git a/src/pudl/metadata/resources/eia923.py b/src/pudl/metadata/resources/eia923.py index f79bd60360..5de5724284 100644 --- a/src/pudl/metadata/resources/eia923.py +++ b/src/pudl/metadata/resources/eia923.py @@ -18,8 +18,7 @@ complex. Note that a small number of respondents only report annual fuel consumption, and all of -it is reported in December. -""" +it is reported in December.""" ), "fuel_receipts_costs_eia923": ( """Data describing fuel deliveries to power plants, reported in EIA-923 Schedule 2, Part A. @@ -44,8 +43,7 @@ Northeastern US reports essentially no fine-grained data about its natural gas prices. Additional data which we haven't yet integrated is available in a similar format from -2002-2008 via the EIA-423, and going back as far as 1972 from the FERC-423. -""" +2002-2008 via the EIA-423, and going back as far as 1972 from the FERC-423.""" ), "generation_eia923": ( """EIA-923 Monthly Generating Unit Net Generation. From EIA-923 Schedule 3. @@ -62,8 +60,7 @@ incomplete boiler-generator associations. Note that a small number of respondents only report annual net generation, and all of -it is reported in December. -""" +it is reported in December.""" ), "generation_fuel_eia923": ( """EIA-923 Monthly Generation and Fuel Consumption Time Series. From EIA-923 Schedule 3. @@ -83,8 +80,7 @@ generation. Note that a small number of respondents only report annual fuel consumption and net -generation, and all of it is reported in December. -""" +generation, and all of it is reported in December.""" ), "generation_fuel_nuclear_eia923": ( """EIA-923 Monthly Generation and Fuel Consumption Time Series. From EIA-923 Schedule 3. @@ -93,8 +89,7 @@ fuel and prime mover within a nuclear generation unit. 
This data is originally reported alongside similar information for fossil fuel plants, but the nuclear data is reported by (nuclear) generation unit rather than fuel type and prime mover, and so has a -different primary key. -""" +different primary key.""" ), "generation_fuel_combined_eia923": ( """EIA-923 Monthly Generation and Fuel Consumption Time Series. From EIA-923 Schedule 3. @@ -102,8 +97,7 @@ Denormalized, combined data from the ``generation_fuel_eia923`` and ``generation_fuel_nuclear_eia923`` with nuclear generation aggregated from the nuclear generation unit level up to the plant prime mover level, so as to be compatible with -fossil fuel generation data. -""" +fossil fuel generation data.""" ), } @@ -256,8 +250,7 @@ We have not yet taken the time to rigorously clean this data, but it could be linked with both Mining Safety and Health Administration (MSHA) and USGS data to provide more insight into where coal is coming from, and what the employment and geological context -is for those supplies. -""" +is for those supplies.""" ), "schema": { "fields": [ diff --git a/src/pudl/metadata/resources/ferc1_eia_record_linkage.py b/src/pudl/metadata/resources/ferc1_eia_record_linkage.py index e6cb93a781..e1a5f89032 100644 --- a/src/pudl/metadata/resources/ferc1_eia_record_linkage.py +++ b/src/pudl/metadata/resources/ferc1_eia_record_linkage.py @@ -27,8 +27,7 @@ and the total records are labeled as "total". This table includes A LOT of duplicative information about EIA plants. It is primarily -meant for use as an input into the record linkage between FERC1 plants and EIA. -""", +meant for use as an input into the record linkage between FERC1 plants and EIA.""", "schema": { "fields": [ "record_id_eia", @@ -137,8 +136,7 @@ The EIA data associated with each FERC plant record comes from our Plant Parts EIA table. The EIA data in each record represents an aggregation of several slices of an EIA -plant, across both physical characteristics and utility ownership. -""", +plant, across both physical characteristics and utility ownership.""", "schema": { "fields": [ "record_id_ferc1", diff --git a/src/pudl/metadata/resources/glue.py b/src/pudl/metadata/resources/glue.py index fcfe802a8e..3a4898baf6 100644 --- a/src/pudl/metadata/resources/glue.py +++ b/src/pudl/metadata/resources/glue.py @@ -18,8 +18,7 @@ Our version of the crosswalk clarifies some of the column names and removes unmatched rows. The :func:`pudl.etl.glue_assets.epacamd_eia` function doc strings explain -what changes are made from the EPA's version. -""", +what changes are made from the EPA's version.""", "schema": { "fields": [ "report_year", @@ -71,8 +70,7 @@ This table does not have primary keys because the primary keys would have been: plant_id_eia, generator_id, subplant_id and emissions_unit_id_epa, but there are some null records in the generator_id column because ~2 percent of all EPA CAMD records are not -successfully mapped to EIA generators. 
-""", +successfully mapped to EIA generators.""", "schema": { "fields": [ "plant_id_eia", diff --git a/src/pudl/output/ferc1.py b/src/pudl/output/ferc1.py index 5ed3551f3e..4931473071 100644 --- a/src/pudl/output/ferc1.py +++ b/src/pudl/output/ferc1.py @@ -12,7 +12,7 @@ from matplotlib import pyplot as plt from networkx.drawing.nx_agraph import graphviz_layout from pandas._libs.missing import NAType as pandas_NAType -from pydantic import BaseModel, validator +from pydantic import BaseModel, ConfigDict, field_validator, validator import pudl from pudl.transform.ferc1 import ( @@ -1611,22 +1611,22 @@ class XbrlCalculationForestFerc1(BaseModel): seeds: list[NodeId] = [] tags: pd.DataFrame = pd.DataFrame() group_metric_checks: GroupMetricChecks = GroupMetricChecks() + model_config = ConfigDict( + arbitrary_types_allowed=True, ignored_types=(cached_property,) + ) - class Config: - """Allow the class to store a dataframe.""" - - arbitrary_types_allowed = True - keep_untouched = (cached_property,) - + # TODO[pydantic]: refactor to use @field_validator @validator("parent_cols", always=True) + @classmethod def set_parent_cols(cls, v, values) -> list[str]: """A convenience property to generate parent column.""" return [col + "_parent" for col in values["calc_cols"]] - @validator("exploded_calcs") - def unique_associations(cls, v: pd.DataFrame, values) -> pd.DataFrame: + @field_validator("exploded_calcs") + @classmethod + def unique_associations(cls, v: pd.DataFrame, info) -> pd.DataFrame: """Ensure parent-child associations in exploded calculations are unique.""" - pks = values["calc_cols"] + values["parent_cols"] + pks = info.data["calc_cols"] + info.data["parent_cols"] dupes = v.duplicated(subset=pks, keep=False) if dupes.any(): logger.warning( @@ -1638,10 +1638,11 @@ def unique_associations(cls, v: pd.DataFrame, values) -> pd.DataFrame: assert not v.duplicated(subset=pks, keep=False).any() return v - @validator("exploded_calcs") - def calcs_have_required_cols(cls, v: pd.DataFrame, values) -> pd.DataFrame: + @field_validator("exploded_calcs") + @classmethod + def calcs_have_required_cols(cls, v: pd.DataFrame, info) -> pd.DataFrame: """Ensure exploded calculations include all required columns.""" - required_cols = values["parent_cols"] + values["calc_cols"] + ["weight"] + required_cols = info.data["parent_cols"] + info.data["calc_cols"] + ["weight"] missing_cols = [col for col in required_cols if col not in v.columns] if missing_cols: raise ValueError( @@ -1649,24 +1650,27 @@ def calcs_have_required_cols(cls, v: pd.DataFrame, values) -> pd.DataFrame: ) return v[required_cols] - @validator("exploded_calcs") + @field_validator("exploded_calcs") + @classmethod def calc_parents_notna(cls, v: pd.DataFrame) -> pd.DataFrame: """Ensure that parent table_name and xbrl_factoid columns are non-null.""" if v[["table_name_parent", "xbrl_factoid_parent"]].isna().any(axis=None): raise AssertionError("Null parent table name or xbrl_factoid found.") return v - @validator("tags") - def tags_have_required_cols(cls, v: pd.DataFrame, values) -> pd.DataFrame: + @field_validator("tags") + @classmethod + def tags_have_required_cols(cls, v: pd.DataFrame, info) -> pd.DataFrame: """Ensure tagging dataframe contains all required index columns.""" - missing_cols = [col for col in values["calc_cols"] if col not in v.columns] + missing_cols = [col for col in info.data["calc_cols"] if col not in v.columns] if missing_cols: raise ValueError( f"Tagging dataframe was missing expected columns: {missing_cols=}" ) return v - 
@validator("tags") + @field_validator("tags") + @classmethod def tags_cols_notnull(cls, v: pd.DataFrame) -> pd.DataFrame: """Ensure all tags have non-null table_name and xbrl_factoid.""" null_idx_rows = v[v.table_name.isna() | v.xbrl_factoid.isna()] @@ -1679,25 +1683,29 @@ def tags_cols_notnull(cls, v: pd.DataFrame) -> pd.DataFrame: v = v.dropna(subset=["table_name", "xbrl_factoid"]) return v - @validator("tags") - def single_valued_tags(cls, v: pd.DataFrame, values) -> pd.DataFrame: + @field_validator("tags") + @classmethod + def single_valued_tags(cls, v: pd.DataFrame, info) -> pd.DataFrame: """Ensure all tags have unique values.""" - dupes = v.duplicated(subset=values["calc_cols"], keep=False) + dupes = v.duplicated(subset=info.data["calc_cols"], keep=False) if dupes.any(): logger.warning( f"Found {dupes.sum()} duplicate tag records:\n{v.loc[dupes]}" ) return v - @validator("seeds") - def seeds_within_bounds(cls, v: pd.DataFrame, values) -> pd.DataFrame: + @field_validator("seeds") + @classmethod + def seeds_within_bounds(cls, v: pd.DataFrame, info) -> pd.DataFrame: """Ensure that all seeds are present within exploded_calcs index. For some reason this validator is being run before exploded_calcs has been added to the values dictionary, which doesn't make sense, since "seeds" is defined after exploded_calcs in the model. """ - all_nodes = values["exploded_calcs"].set_index(values["parent_cols"]).index + all_nodes = ( + info.data["exploded_calcs"].set_index(info.data["parent_cols"]).index + ) bad_seeds = [seed for seed in v if seed not in all_nodes] if bad_seeds: raise ValueError(f"Seeds missing from exploded_calcs index: {bad_seeds=}") diff --git a/src/pudl/settings.py b/src/pudl/settings.py index d432acd330..33579c1695 100644 --- a/src/pudl/settings.py +++ b/src/pudl/settings.py @@ -2,20 +2,28 @@ import itertools import json from enum import Enum, unique -from typing import ClassVar +from typing import Any, ClassVar import fsspec import pandas as pd import yaml -from dagster import Any, DagsterInvalidDefinitionError, Field -from pydantic import AnyHttpUrl, BaseSettings, root_validator, validator +from dagster import Field as DagsterField from pydantic import BaseModel as PydanticBaseModel +from pydantic import ( + ConfigDict, + field_validator, + model_validator, + root_validator, +) +from pydantic_settings import BaseSettings import pudl import pudl.workspace.setup from pudl.metadata.classes import DataSource from pudl.workspace.datastore import Datastore +logger = pudl.logging_helpers.get_logger(__name__) + @unique class XbrlFormNumber(Enum): @@ -31,11 +39,7 @@ class XbrlFormNumber(Enum): class BaseModel(PydanticBaseModel): """BaseModel with global configuration.""" - class Config: - """Pydantic config.""" - - allow_mutation = False - extra = "forbid" + model_config = ConfigDict(frozen=True, extra="forbid") class GenericDatasetSettings(BaseModel): @@ -50,8 +54,10 @@ class GenericDatasetSettings(BaseModel): disabled: bool = False - @root_validator - def validate_partitions(cls, partitions): # noqa: N805 + # TODO[pydantic]: Refactor to use model_validator + @root_validator(skip_on_failure=True) + @classmethod + def validate_partitions(cls, partitions: Any): # noqa: N805 """Validate the requested data partitions. 
Check that all the partitions defined in the ``working_partitions`` of the @@ -150,7 +156,8 @@ class EpaCemsSettings(GenericDatasetSettings): years: list[int] = data_source.working_partitions["years"] states: list[str] = data_source.working_partitions["states"] - @validator("states") + @field_validator("states") + @classmethod def allow_all_keyword(cls, states): # noqa: N805 """Allow users to specify ['all'] to get all states.""" if states == ["all"]: @@ -202,7 +209,8 @@ class Eia860Settings(GenericDatasetSettings): years: list[int] = data_source.working_partitions["years"] eia860m: bool = True - @validator("eia860m") + @field_validator("eia860m") + @classmethod def check_eia860m_date(cls, eia860m: bool) -> bool: # noqa: N805 """Check 860m date-year is exactly one year after most recent working 860 year. @@ -248,11 +256,12 @@ class EiaSettings(BaseModel): eia923: Immutable pydantic model to validate eia923 settings. """ - eia860: Eia860Settings = None - eia861: Eia861Settings = None - eia923: Eia923Settings = None + eia860: Eia860Settings | None = None + eia861: Eia861Settings | None = None + eia923: Eia923Settings | None = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def default_load_all(cls, values): # noqa: N805 """If no datasets are specified default to all. @@ -269,7 +278,8 @@ def default_load_all(cls, values): # noqa: N805 return values - @root_validator + @model_validator(mode="before") + @classmethod def check_eia_dependencies(cls, values): # noqa: N805 """Make sure the dependencies between the eia datasets are satisfied. @@ -282,15 +292,15 @@ def check_eia_dependencies(cls, values): # noqa: N805 Returns: values (Dict[str, BaseModel]): dataset settings. """ - eia923 = values.get("eia923") - eia860 = values.get("eia860") - if not eia923 and eia860: - values["eia923"] = Eia923Settings(years=eia860.years) + if not values.get("eia923") and values.get("eia860"): + values["eia923"] = Eia923Settings(years=values["eia860"].years) - if eia923 and not eia860: - available_years = Eia860Settings() + if values.get("eia923") and not values.get("eia860"): + available_years = Eia860Settings().years values["eia860"] = Eia860Settings( - years=[year for year in eia923.years if year in available_years] + years=[ + year for year in values["eia923"].years if year in available_years + ] ) return values @@ -311,7 +321,8 @@ class DatasetsSettings(BaseModel): ferc714: Ferc714Settings = None glue: GlueSettings = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def default_load_all(cls, values): # noqa: N805 """If no datasets are specified default to all. @@ -330,8 +341,9 @@ def default_load_all(cls, values): # noqa: N805 return values - @root_validator - def add_glue_settings(cls, values): # noqa: N805 + @model_validator(mode="before") + @classmethod + def add_glue_settings(cls, data: Any): # noqa: N805 """Add glue settings if ferc1 and eia data are both requested. Args: @@ -340,11 +352,11 @@ def add_glue_settings(cls, values): # noqa: N805 Returns: values (Dict[str, BaseModel]): dataset settings. 
""" - ferc1 = bool(values.get("ferc1")) - eia = bool(values.get("eia")) - - values["glue"] = GlueSettings(ferc1=ferc1, eia=eia) - return values + ferc1 = bool(data.get("ferc1")) + eia = bool(data.get("eia")) + if ferc1 and eia: + data["glue"] = GlueSettings(ferc1=ferc1, eia=eia) + return data def get_datasets(self): # noqa: N805 """Gets dictionary of dataset settings.""" @@ -454,7 +466,7 @@ class FercGenericXbrlToSqliteSettings(BaseSettings): disabled: if True, skip processing this dataset. """ - taxonomy: AnyHttpUrl + taxonomy: str years: list[int] disabled: bool = False @@ -471,7 +483,7 @@ class Ferc1XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): years: list[int] = [ year for year in data_source.working_partitions["years"] if year >= 2021 ] - taxonomy: AnyHttpUrl = "https://eCollection.ferc.gov/taxonomy/form1/2022-01-01/form/form1/form-1_2022-01-01.xsd" + taxonomy: str = "https://eCollection.ferc.gov/taxonomy/form1/2022-01-01/form/form1/form-1_2022-01-01.xsd" class Ferc2XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): @@ -485,7 +497,7 @@ class Ferc2XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): years: list[int] = [ year for year in data_source.working_partitions["years"] if year >= 2021 ] - taxonomy: AnyHttpUrl = "https://eCollection.ferc.gov/taxonomy/form2/2022-01-01/form/form2/form-2_2022-01-01.xsd" + taxonomy: str = "https://eCollection.ferc.gov/taxonomy/form2/2022-01-01/form/form2/form-2_2022-01-01.xsd" class Ferc2DbfToSqliteSettings(GenericDatasetSettings): @@ -532,7 +544,7 @@ class Ferc6XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): years: list[int] = [ year for year in data_source.working_partitions["years"] if year >= 2021 ] - taxonomy: AnyHttpUrl = "https://eCollection.ferc.gov/taxonomy/form6/2022-01-01/form/form6/form-6_2022-01-01.xsd" + taxonomy: str = "https://eCollection.ferc.gov/taxonomy/form6/2022-01-01/form/form6/form-6_2022-01-01.xsd" class Ferc60DbfToSqliteSettings(GenericDatasetSettings): @@ -563,7 +575,7 @@ class Ferc60XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): years: list[int] = [ year for year in data_source.working_partitions["years"] if year >= 2021 ] - taxonomy: AnyHttpUrl = "https://eCollection.ferc.gov/taxonomy/form60/2022-01-01/form/form60/form-60_2022-01-01.xsd" + taxonomy: str = "https://eCollection.ferc.gov/taxonomy/form60/2022-01-01/form/form60/form-60_2022-01-01.xsd" class Ferc714XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): @@ -574,8 +586,8 @@ class Ferc714XbrlToSqliteSettings(FercGenericXbrlToSqliteSettings): """ data_source: ClassVar[DataSource] = DataSource.from_id("ferc714") - years: list[int] = [2021] - taxonomy: AnyHttpUrl = "https://eCollection.ferc.gov/taxonomy/form714/2022-01-01/form/form714/form-714_2022-01-01.xsd" + years: list[int] = [2021, 2022] + taxonomy: str = "https://eCollection.ferc.gov/taxonomy/form714/2022-01-01/form/form714/form-714_2022-01-01.xsd" class FercToSqliteSettings(BaseSettings): @@ -584,21 +596,22 @@ class FercToSqliteSettings(BaseSettings): Args: ferc1_dbf_to_sqlite_settings: Settings for converting FERC 1 DBF data to SQLite. ferc1_xbrl_to_sqlite_settings: Settings for converting FERC 1 XBRL data to - SQLite. + SQLite. other_xbrl_forms: List of non-FERC1 forms to convert from XBRL to SQLite. 
""" - ferc1_dbf_to_sqlite_settings: Ferc1DbfToSqliteSettings = None - ferc1_xbrl_to_sqlite_settings: Ferc1XbrlToSqliteSettings = None - ferc2_dbf_to_sqlite_settings: Ferc2DbfToSqliteSettings = None - ferc2_xbrl_to_sqlite_settings: Ferc2XbrlToSqliteSettings = None - ferc6_dbf_to_sqlite_settings: Ferc6DbfToSqliteSettings = None - ferc6_xbrl_to_sqlite_settings: Ferc6XbrlToSqliteSettings = None - ferc60_dbf_to_sqlite_settings: Ferc60DbfToSqliteSettings = None - ferc60_xbrl_to_sqlite_settings: Ferc60XbrlToSqliteSettings = None - ferc714_xbrl_to_sqlite_settings: Ferc714XbrlToSqliteSettings = None - - @root_validator(pre=True) + ferc1_dbf_to_sqlite_settings: Ferc1DbfToSqliteSettings | None = None + ferc1_xbrl_to_sqlite_settings: Ferc1XbrlToSqliteSettings | None = None + ferc2_dbf_to_sqlite_settings: Ferc2DbfToSqliteSettings | None = None + ferc2_xbrl_to_sqlite_settings: Ferc2XbrlToSqliteSettings | None = None + ferc6_dbf_to_sqlite_settings: Ferc6DbfToSqliteSettings | None = None + ferc6_xbrl_to_sqlite_settings: Ferc6XbrlToSqliteSettings | None = None + ferc60_dbf_to_sqlite_settings: Ferc60DbfToSqliteSettings | None = None + ferc60_xbrl_to_sqlite_settings: Ferc60XbrlToSqliteSettings | None = None + ferc714_xbrl_to_sqlite_settings: Ferc714XbrlToSqliteSettings | None = None + + @model_validator(mode="before") + @classmethod def default_load_all(cls, values): # noqa: N805 """If no datasets are specified default to all. @@ -648,13 +661,13 @@ def get_xbrl_dataset_settings( class EtlSettings(BaseSettings): """Main settings validation class.""" - ferc_to_sqlite_settings: FercToSqliteSettings = None - datasets: DatasetsSettings = None + ferc_to_sqlite_settings: FercToSqliteSettings | None = None + datasets: DatasetsSettings | None = None - name: str = None - title: str = None - description: str = None - version: str = None + name: str | None = None + title: str | None = None + description: str | None = None + version: str | None = None # This is list of fsspec compatible paths to publish the output datasets to. publish_destinations: list[str] = [] @@ -671,7 +684,7 @@ def from_yaml(cls, path: str) -> "EtlSettings": """ with fsspec.open(path) as f: yaml_file = yaml.safe_load(f) - return cls.parse_obj(yaml_file) + return cls.model_validate(yaml_file) def _convert_settings_to_dagster_config(d: dict) -> None: @@ -689,13 +702,7 @@ def _convert_settings_to_dagster_config(d: dict) -> None: if isinstance(v, dict): _convert_settings_to_dagster_config(v) else: - try: - d[k] = Field(type(v), default_value=v) - except DagsterInvalidDefinitionError: - # Dagster config accepts a valid dagster types. - # Most of our settings object properties are valid types - # except for fields like taxonomy which are the AnyHttpUrl type. - d[k] = Field(Any, default_value=v) + d[k] = DagsterField(type(v), default_value=v) def create_dagster_config(settings: BaseModel) -> dict: @@ -704,7 +711,7 @@ def create_dagster_config(settings: BaseModel) -> dict: Returns: A dictionary of dagster configuration. 
""" - ds = settings.dict() + ds = settings.model_dump() _convert_settings_to_dagster_config(ds) return ds diff --git a/src/pudl/transform/classes.py b/src/pudl/transform/classes.py index f32071696e..12f11fe09e 100644 --- a/src/pudl/transform/classes.py +++ b/src/pudl/transform/classes.py @@ -69,11 +69,18 @@ from collections.abc import Callable from functools import wraps from itertools import combinations -from typing import Any, Protocol +from typing import Annotated, Any, Protocol, Self import numpy as np import pandas as pd -from pydantic import BaseModel, conset, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationInfo, + field_validator, + model_validator, +) import pudl.logging_helpers import pudl.transform.params.ferc1 @@ -92,11 +99,7 @@ class TransformParams(BaseModel): when applied by their associated function. """ - class Config: - """Prevent parameters from changing part way through.""" - - allow_mutation = False - extra = "forbid" + model_config = ConfigDict(frozen=True, extra="forbid") class MultiColumnTransformParams(TransformParams): @@ -118,16 +121,16 @@ class MultiColumnTransformParams(TransformParams): https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#validation-without-a-model """ - @root_validator - def single_param_type(cls, params): # noqa: N805 + @model_validator(mode="after") + def single_param_type(self: Self, info: ValidationInfo): """Check that all TransformParams in the dictionary are of the same type.""" - param_types = {type(params[col]) for col in params} + param_types = {type(info.data[col]) for col in info.data} if len(param_types) > 1: raise ValueError( "Found multiple parameter types in multi-column transform params: " f"{param_types}" ) - return params + return self ##################################################################################### @@ -424,7 +427,8 @@ class StringCategories(TransformParams): :func:`categorize_strings` to see how it is used. """ - @validator("categories") + @field_validator("categories") + @classmethod def categories_are_disjoint(cls, v): """Ensure that each string to be categorized only appears in one category.""" for cat1, cat2 in combinations(v, 2): @@ -436,7 +440,8 @@ def categories_are_disjoint(cls, v): ) return v - @validator("categories") + @field_validator("categories") + @classmethod def categories_are_idempotent(cls, v): """Ensure that every category contains the string it will map to. @@ -503,17 +508,17 @@ class UnitConversion(TransformParams): from_unit: str = "" # If it's the empty string, no renaming will happen. to_unit: str = "" # If it's the empty string, no renaming will happen. - @root_validator - def both_or_neither_units_are_none(cls, params): + @model_validator(mode="after") + def both_or_neither_units_are_none(self: Self): """Ensure that either both or neither of the units strings are None.""" - if (params["from_unit"] == "" and params["to_unit"] != "") or ( - params["from_unit"] != "" and params["to_unit"] == "" + if (self.from_unit == "" and self.to_unit != "") or ( + self.from_unit != "" and self.to_unit == "" ): raise ValueError( "Either both or neither of from_unit and to_unit must be non-empty. " - f"Got {params['from_unit']=} {params['to_unit']=}." + f"Got {self.from_unit=} {self.to_unit=}." ) - return params + return self def inverse(self) -> "UnitConversion": """Construct a :class:`UnitConversion` that is the inverse of self. 
@@ -572,12 +577,13 @@ class ValidRange(TransformParams): lower_bound: float = -np.inf upper_bound: float = np.inf - @validator("upper_bound") - def upper_bound_gte_lower_bound(cls, v, values): + @field_validator("upper_bound") + @classmethod + def upper_bound_gte_lower_bound(cls, upper_bound: float, info: ValidationInfo): """Require upper bound to be greater than or equal to lower bound.""" - if values["lower_bound"] > v: + if info.data["lower_bound"] > upper_bound: raise ValueError("upper_bound must be greater than or equal to lower_bound") - return v + return upper_bound def nullify_outliers(col: pd.Series, params: ValidRange) -> pd.Series: @@ -622,7 +628,8 @@ class UnitCorrections(TransformParams): unit_conversions: list[UnitConversion] """A list of unit conversions to use to identify errors and correct them.""" - @validator("unit_conversions") + @field_validator("unit_conversions") + @classmethod def no_column_rename(cls, params: list[UnitConversion]) -> list[UnitConversion]: """Ensure that the unit conversions used in corrections don't rename the column. @@ -636,8 +643,8 @@ def no_column_rename(cls, params: list[UnitConversion]) -> list[UnitConversion]: ) return new_conversions - @root_validator - def distinct_domains(cls, params): + @model_validator(mode="after") + def distinct_domains(self: Self): """Verify that all unit conversions map distinct domains to the valid range. If the domains being mapped to the valid range overlap, then it is ambiguous @@ -654,12 +661,12 @@ def distinct_domains(cls, params): corrected to be 2. """ input_vals = pd.Series( - [params["valid_range"].lower_bound, params["valid_range"].upper_bound], + [self.valid_range.lower_bound, self.valid_range.upper_bound], name="dude", ) # We need to make sure that the unit conversion doesn't map the valid range # onto itself either, so add an additional conversion that does nothing: - uc_combos = combinations(params["unit_conversions"] + [UnitConversion()], 2) + uc_combos = combinations(self.unit_conversions + [UnitConversion()], 2) for uc1, uc2 in uc_combos: out1 = convert_units(input_vals, uc1.inverse()) out2 = convert_units(input_vals, uc2.inverse()) @@ -667,11 +674,11 @@ def distinct_domains(cls, params): raise ValueError( "The following pair of unit corrections are incompatible due to " "overlapping domains.\n" - f"{params['valid_range']=}\n" + f"{self.valid_range=}\n" f"{uc1=}\n" f"{uc2=}\n" ) - return params + return self def correct_units(df: pd.DataFrame, params: UnitCorrections) -> pd.DataFrame: @@ -734,7 +741,7 @@ def correct_units(df: pd.DataFrame, params: UnitCorrections) -> pd.DataFrame: class InvalidRows(TransformParams): """Pameters that identify invalid rows to drop.""" - invalid_values: conset(Any, min_items=1) | None = None + invalid_values: Annotated[set[Any], Field(min_length=1)] | None = None """A list of values that should be considered invalid in the selected columns.""" required_valid_cols: list[str] | None = None @@ -753,16 +760,16 @@ class InvalidRows(TransformParams): regex: str | None = None """A regular expression to use as the ``regex`` argument to :meth:`pd.filter`.""" - @root_validator - def one_filter_argument(cls, values): + @model_validator(mode="after") + def one_filter_argument(self: Self): """Validate that only one argument is specified for :meth:`pd.filter`.""" num_args = sum( int(bool(val)) for val in [ - values["required_valid_cols"], - values["allowed_invalid_cols"], - values["like"], - values["regex"], + self.required_valid_cols, + self.allowed_invalid_cols, + self.like, + 
self.regex, ] ) if num_args > 1: @@ -771,7 +778,7 @@ def one_filter_argument(cls, values): f"{num_args} were found." ) - return values + return self def drop_invalid_rows(df: pd.DataFrame, params: InvalidRows) -> pd.DataFrame: diff --git a/src/pudl/transform/ferc1.py b/src/pudl/transform/ferc1.py index 1c497fd370..9ac4ad288c 100644 --- a/src/pudl/transform/ferc1.py +++ b/src/pudl/transform/ferc1.py @@ -15,14 +15,14 @@ from abc import abstractmethod from collections import namedtuple from collections.abc import Mapping -from typing import Any, Literal, Self +from typing import Annotated, Any, Literal, Self import numpy as np import pandas as pd import sqlalchemy as sa from dagster import AssetIn, AssetsDefinition, asset from pandas.core.groupby import DataFrameGroupBy -from pydantic import BaseModel, confloat, validator +from pydantic import BaseModel, Field, field_validator import pudl from pudl.analysis.classify_plants_ferc1 import ( @@ -205,13 +205,13 @@ def rename_dicts_xbrl(self): class WideToTidy(TransformParams): """Parameters for converting a wide table to a tidy table with value types.""" - idx_cols: list[str] | None + idx_cols: list[str] | None = None """List of column names to treat as the table index.""" stacked_column_name: str | None = None """Name of column that will contain the stacked categories.""" - value_types: list[str] | None + value_types: list[str] | None = None """List of names of value types that will end up being the column names. Some of the FERC tables have multiple data types spread across many different @@ -643,7 +643,8 @@ class CombineAxisColumnsXbrl(TransformParams): new_axis_column_name: str | None = None """The name of the combined axis column -- must end with the suffix ``_axis``!.""" - @validator("new_axis_column_name") + @field_validator("new_axis_column_name") + @classmethod def doesnt_end_with_axis(cls, v): """Ensure that new axis column ends in _axis.""" if v is not None and not v.endswith("_axis"): @@ -724,10 +725,10 @@ def combine_axis_columns_xbrl( class IsCloseTolerance(TransformParams): """Info for testing a particular check.""" - isclose_rtol: confloat(ge=0.0) = 1e-5 + isclose_rtol: Annotated[float, Field(ge=0.0)] = 1e-5 """Relative tolerance to use in :func:`np.isclose` for determining equality.""" - isclose_atol: confloat(ge=0.0, le=0.01) = 1e-8 + isclose_atol: Annotated[float, Field(ge=0.0, le=0.01)] = 1e-8 """Absolute tolerance to use in :func:`np.isclose` for determining equality.""" @@ -744,12 +745,12 @@ class CalculationIsCloseTolerance(TransformParams): class MetricTolerances(TransformParams): """Tolerances for all data checks to be preformed within a grouped df.""" - error_frequency: confloat(ge=0.0, le=1.0) = 0.01 - relative_error_magnitude: confloat(ge=0.0) = 0.02 - null_calculated_value_frequency: confloat(ge=0.0, le=1.0) = 0.7 + error_frequency: Annotated[float, Field(ge=0.0, le=1.0)] = 0.01 + relative_error_magnitude: Annotated[float, Field(ge=0.0)] = 0.02 + null_calculated_value_frequency: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7 """Fraction of records with non-null reported values and null calculated values.""" - absolute_error_magnitude: confloat(ge=0.0) = np.inf - null_reported_value_frequency: confloat(ge=0.0, le=1.0) = 1.0 + absolute_error_magnitude: Annotated[float, Field(ge=0.0)] = np.inf + null_reported_value_frequency: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0 # ooof this one is just bad @@ -820,7 +821,7 @@ class GroupMetricChecks(TransformParams): group_metric_tolerances: GroupMetricTolerances = 
GroupMetricTolerances() is_close_tolerance: CalculationIsCloseTolerance = CalculationIsCloseTolerance() - # @root_validator + # @model_validator # def grouped_tol_ge_ungrouped_tol(cls, values): # """Grouped tolerance should always be greater than or equal to ungrouped.""" # group_metric_tolerances = values["group_metric_tolerances"] diff --git a/src/pudl/workspace/datastore.py b/src/pudl/workspace/datastore.py index 73b72a5571..9e52160d2e 100644 --- a/src/pudl/workspace/datastore.py +++ b/src/pudl/workspace/datastore.py @@ -9,13 +9,14 @@ from collections import defaultdict from collections.abc import Iterator from pathlib import Path -from typing import Any, Self +from typing import Annotated, Any, Self from urllib.parse import ParseResult, urlparse import datapackage import requests from google.auth.exceptions import DefaultCredentialsError -from pydantic import BaseSettings, HttpUrl, constr +from pydantic import HttpUrl, StringConstraints +from pydantic_settings import BaseSettings, SettingsConfigDict from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -27,9 +28,12 @@ logger = pudl.logging_helpers.get_logger(__name__) PUDL_YML = Path.home() / ".pudl.yml" -ZenodoDoi = constr( - strict=True, min_length=16, regex=r"(10\.5072|10\.5281)/zenodo.([\d]+)" -) +ZenodoDoi = Annotated[ + str, + StringConstraints( + strict=True, min_length=16, pattern=r"(10\.5072|10\.5281)/zenodo.([\d]+)" + ), +] class ChecksumMismatchError(ValueError): @@ -194,13 +198,7 @@ class ZenodoDoiSettings(BaseSettings): ferc714: ZenodoDoi = "10.5281/zenodo.8326694" # ferc714: ZenodoDoi = "10.5072/zenodo.1237565" phmsagas: ZenodoDoi = "10.5281/zenodo.8346646" - # phmsagas: ZenodoDoi = "10.5072/zenodo.1239253" - - class Config: - """Pydantic config, reads from .env file.""" - - env_prefix = "pudl_zenodo_doi_" - env_file = ".env" + model_config = SettingsConfigDict(env_prefix="pudl_zenodo_doi_", env_file=".env") class ZenodoFetcher: diff --git a/src/pudl/workspace/setup.py b/src/pudl/workspace/setup.py index 6afaa751ea..f4d6d50715 100644 --- a/src/pudl/workspace/setup.py +++ b/src/pudl/workspace/setup.py @@ -4,36 +4,15 @@ import pathlib import shutil from pathlib import Path -from typing import Any -from pydantic import BaseSettings, DirectoryPath -from pydantic.validators import path_validator +from pydantic import DirectoryPath, NewPath +from pydantic_settings import BaseSettings, SettingsConfigDict import pudl.logging_helpers logger = pudl.logging_helpers.get_logger(__name__) - -class MissingPath(Path): - """Validates potential path that doesn't exist.""" - - @classmethod - def __get_validators__(cls) -> Any: - """Validates that path doesn't exist and is path-like.""" - yield path_validator - yield cls.validate - - @classmethod - def validate(cls, value: Path) -> Path: - """Validates that path doesn't exist.""" - if value.exists(): - raise ValueError("path exists") - - return value - - -# TODO: The following could be replaced with NewPath from pydantic v2 -PotentialDirectoryPath = DirectoryPath | MissingPath +PotentialDirectoryPath = DirectoryPath | NewPath class PudlPaths(BaseSettings): @@ -45,11 +24,7 @@ class PudlPaths(BaseSettings): pudl_input: PotentialDirectoryPath pudl_output: PotentialDirectoryPath - - class Config: - """Pydantic config, reads from .env file.""" - - env_file = ".env" + model_config = SettingsConfigDict(env_file=".env") @property def input_dir(self) -> Path: diff --git a/test/conftest.py b/test/conftest.py index 14aef3e05b..643c6fd22f 100644 --- a/test/conftest.py +++ 
b/test/conftest.py @@ -8,7 +8,6 @@ import pytest import sqlalchemy as sa -import yaml from dagster import build_init_resource_context, materialize_to_memory import pudl @@ -112,9 +111,7 @@ def etl_parameters(request, test_dir) -> EtlSettings: etl_settings_yml = Path( test_dir.parent / "src/pudl/package_data/settings/etl_fast.yml" ) - with Path.open(etl_settings_yml, encoding="utf8") as settings_file: - etl_settings_out = yaml.safe_load(settings_file) - etl_settings = EtlSettings().parse_obj(etl_settings_out) + etl_settings = EtlSettings.from_yaml(etl_settings_yml) return etl_settings diff --git a/test/unit/settings_test.py b/test/unit/settings_test.py index 3a966ea93e..89c7224d06 100644 --- a/test/unit/settings_test.py +++ b/test/unit/settings_test.py @@ -212,18 +212,18 @@ class TestGlobalConfig: def test_unknown_dataset(self): """Test unkown dataset fed to DatasetsSettings.""" with pytest.raises(ValidationError): - DatasetsSettings().parse_obj({"unknown_data": "data"}) + DatasetsSettings().model_validate({"unknown_data": "data"}) with pytest.raises(ValidationError): - EiaSettings().parse_obj({"unknown_data": "data"}) + EiaSettings().model_validate({"unknown_data": "data"}) def test_immutability(self): """Test immutability config is working correctly.""" - with pytest.raises(TypeError): + with pytest.raises(ValidationError): settings = DatasetsSettings() settings.eia = EiaSettings() - with pytest.raises(TypeError): + with pytest.raises(ValidationError): settings = EiaSettings() settings.eia860 = Eia860Settings()
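For reference, a minimal sketch (not taken from this patch) of the pydantic v1 -> v2
validator pattern applied throughout the refactor above; the model and field names
below are illustrative only, and pydantic>=2.4 is assumed.

from typing import Annotated

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    StringConstraints,
    ValidationInfo,
    field_validator,
)

# v1: SnakeCase = pydantic.constr(min_length=1, regex=r"...")
SnakeCase = Annotated[
    str, StringConstraints(min_length=1, pattern=r"^[a-z_][a-z0-9_]*$")
]


class Example(BaseModel):
    # v1: class Config: allow_mutation = False; extra = "forbid"
    model_config = ConfigDict(frozen=True, extra="forbid")

    name: SnakeCase
    min_length: Annotated[int, Field(ge=0)] = 0
    max_length: Annotated[int, Field(ge=0)] | None = None

    # v1: @validator("max_length") def _check(cls, value, values): ...
    @field_validator("max_length")
    @classmethod
    def _max_gte_min(cls, value: int | None, info: ValidationInfo) -> int | None:
        """Previously validated fields are read from info.data, not values."""
        if value is not None and value < info.data["min_length"]:
            raise ValueError("max_length must be greater or equal to min_length")
        return value


# Frozen models now raise a ValidationError (not TypeError) on attribute assignment,
# which is why test_immutability above switches the expected exception type.
example = Example(name="entity_eia860", min_length=1, max_length=3)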