From 40432cd008ec23efab129d1c3b4a3029d125b332 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Sun, 29 Sep 2024 22:09:15 -0400 Subject: [PATCH] Add `pydantic =1` support --- .github/workflows/ci.yaml | 3 + Makefile | 2 +- descent/train.py | 143 ++++++++++++++++++++++++-------------- devtools/envs/base.yaml | 1 + 4 files changed, 97 insertions(+), 52 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f864da7..9b422b6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,6 +25,9 @@ jobs: make lint make test make docs-build + + mamba install --name descent --yes "pydantic <2" + make test - name: CodeCov uses: codecov/codecov-action@v3.1.1 diff --git a/Makefile b/Makefile index 06c3d00..5222376 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ format: $(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_DIR) test: - $(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/ + $(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-append --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/ docs-build: $(CONDA_ENV_RUN) mkdocs build diff --git a/descent/train.py b/descent/train.py index 0c7d938..0aaea7e 100644 --- a/descent/train.py +++ b/descent/train.py @@ -26,48 +26,53 @@ def _unflatten_tensors( return tensors -class _PotentialKey(pydantic.BaseModel): - """ - - TODO: Needed until interchange upgrades to pydantic >=2 - """ +if pydantic.__version__.startswith("1."): + _PotentialKey = openff.interchange.models.PotentialKey + PotentialKeyList = list[_PotentialKey] +else: - id: str - mult: int | None = None - associated_handler: str | None = None - bond_order: float | None = None - - def __hash__(self) -> int: - return hash((self.id, self.mult, self.associated_handler, self.bond_order)) + class _PotentialKey(pydantic.BaseModel): + """ - def __eq__(self, other: object) -> bool: - import openff.interchange.models + TODO: Needed until interchange upgrades to pydantic >=2 + """ - return ( - isinstance(other, (_PotentialKey, openff.interchange.models.PotentialKey)) - and self.id == other.id - and self.mult == other.mult - and self.associated_handler == other.associated_handler - and self.bond_order == other.bond_order - ) + id: str + mult: int | None = None + associated_handler: str | None = None + bond_order: float | None = None + + def __hash__(self) -> int: + return hash((self.id, self.mult, self.associated_handler, self.bond_order)) + + def __eq__(self, other: object) -> bool: + import openff.interchange.models + + return ( + isinstance( + other, (_PotentialKey, openff.interchange.models.PotentialKey) + ) + and self.id == other.id + and self.mult == other.mult + and self.associated_handler == other.associated_handler + and self.bond_order == other.bond_order + ) + def _convert_keys(value: typing.Any) -> typing.Any: + if not isinstance(value, list): + return value -def _convert_keys(value: typing.Any) -> typing.Any: - if not isinstance(value, list): + value = [ + _PotentialKey(**v.dict()) + if isinstance(v, openff.interchange.models.PotentialKey) + else v + for v in value + ] return value - value = [ - _PotentialKey(**v.dict()) - if isinstance(v, openff.interchange.models.PotentialKey) - else v - for v in value + PotentialKeyList = typing.Annotated[ + list[_PotentialKey], pydantic.BeforeValidator(_convert_keys) ] - return value - - -PotentialKeyList = typing.Annotated[ - list[_PotentialKey], pydantic.BeforeValidator(_convert_keys) -] class AttributeConfig(pydantic.BaseModel): @@ -89,17 +94,35 @@ class AttributeConfig(pydantic.BaseModel): "none indicates no constraint.", ) - @pydantic.model_validator(mode="after") - def _validate_keys(self): - """Ensure that the keys in `scales` and `limits` match `cols`.""" + if pydantic.__version__.startswith("1."): - if any(key not in self.cols for key in self.scales): - raise ValueError("cannot scale non-trainable parameters") + @pydantic.root_validator + def _validate_keys(cls, values): + cols = values.get("cols") - if any(key not in self.cols for key in self.limits): - raise ValueError("cannot clamp non-trainable parameters") + scales = values.get("scales") + limits = values.get("limits") - return self + if any(key not in cols for key in scales): + raise ValueError("cannot scale non-trainable parameters") + if any(key not in cols for key in limits): + raise ValueError("cannot clamp non-trainable parameters") + + return values + + else: + + @pydantic.model_validator(mode="after") + def _validate_keys(self): + """Ensure that the keys in `scales` and `limits` match `cols`.""" + + if any(key not in self.cols for key in self.scales): + raise ValueError("cannot scale non-trainable parameters") + + if any(key not in self.cols for key in self.limits): + raise ValueError("cannot clamp non-trainable parameters") + + return self class ParameterConfig(AttributeConfig): @@ -118,18 +141,36 @@ class ParameterConfig(AttributeConfig): "If ``None``, no parameters will be excluded.", ) - @pydantic.model_validator(mode="after") - def _validate_include_exclude(self): - """Ensure that the keys in `include` and `exclude` are disjoint.""" + if pydantic.__version__.startswith("1."): + + @pydantic.root_validator + def _validate_include_exclude(cls, values): + include = values.get("include") + exclude = values.get("exclude") + + if include is not None and exclude is not None: + include = {*include} + exclude = {*exclude} + + if include & exclude: + raise ValueError("cannot include and exclude the same parameter") + + return values + + else: + + @pydantic.model_validator(mode="after") + def _validate_include_exclude(self): + """Ensure that the keys in `include` and `exclude` are disjoint.""" - if self.include is not None and self.exclude is not None: - include = {*self.include} - exclude = {*self.exclude} + if self.include is not None and self.exclude is not None: + include = {*self.include} + exclude = {*self.exclude} - if include & exclude: - raise ValueError("cannot include and exclude the same parameter") + if include & exclude: + raise ValueError("cannot include and exclude the same parameter") - return self + return self class Trainable: diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 56fa11a..d6c6ca3 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -57,5 +57,6 @@ dependencies: - mkdocs-literate-nav - mkdocstrings - mkdocstrings-python + - griffe <1 - black - mike