Skip to content

Commit

Permalink
Migrated entry points (#27)
Browse files Browse the repository at this point in the history
Migrated entry points from `project` to the `template`.

PR #27
  • Loading branch information
santacodes authored Jul 22, 2024
2 parents f7038c5 + 510955a commit 85869e9
Show file tree
Hide file tree
Showing 5 changed files with 771 additions and 8 deletions.
27 changes: 21 additions & 6 deletions {{cookiecutter.project_name}}/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ classifiers = [
{%- if cookiecutter.backend == "hatch" %}
dynamic = ["version"]
{%- endif %}
dependencies = ["pybamm"]
dependencies = ["pybamm",]

[project.optional-dependencies]
dev = [
"pytest >=6",
"pytest-cov >=3",
"nox",
"nox[uv]",
"pre-commit",
]
docs = [
Expand All @@ -76,6 +76,13 @@ Homepage = "{{ cookiecutter.url }}"
"Bug Tracker" = "{{ cookiecutter.url }}/issues"
Discussions = "{{ cookiecutter.url }}/discussions"
Changelog = "{{ cookiecutter.url }}/releases"

[project.entry-points."parameter_sets"]
Chen2020 = "{{ cookiecutter.__project_slug }}.parameters.input.Chen2020:get_parameter_values"

[project.entry-points."models"]
SPM = "{{ cookiecutter.__project_slug }}.models.input.SPM:SPM"

{# keep this line here for newline #}
{%- if cookiecutter.backend == "hatch" %}
[tool.hatch]
Expand All @@ -88,8 +95,8 @@ envs.default.dependencies = [
{# keep this line here for newline #}
{%- if cookiecutter.mypy %}
[tool.mypy]
python_version = "3.8"
strict = true
python_version = "3.11"
strict = false
warn_return_any = false
show_error_codes = true
enable_error_code = [
Expand All @@ -99,6 +106,9 @@ enable_error_code = [
]
disallow_untyped_defs = false
disallow_untyped_calls = false
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
{%- endif %}

[tool.coverage]
Expand All @@ -112,7 +122,7 @@ select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
#"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
Expand All @@ -123,7 +133,7 @@ select = [
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
#"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
Expand All @@ -138,6 +148,11 @@ unfixable = [
"T20", # Removes print statements
"F841", # Removes unused variables
]
ignore = [
"E741", # Ambiguous variable name
"E501", # Line too long
"PLR2004", # Magic value used in comparison
]
line-length = 100
exclude = []
flake8-unused-arguments.ignore-variadic-names = true
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@
{%- endif %}

from ._version import version as __version__
import pybamm
from .entry_point import Model, parameter_sets, models
{# keep this line here for newline #}
{%- if cookiecutter.mypy %}
__all__: tuple[str] = ("__version__",)
__all__: list[str] = [
{%- else %}
__all__ = ("__version__",)
__all__ = [
{%- endif %}
"__version__",
"pybamm",
"parameter_sets",
"Model",
"models",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
This code is adopted from the PyBaMM project under the BSD-3-Clause
Copyright (c) 2018-2024, the PyBaMM team.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""


import importlib.metadata
import sys
import textwrap
from collections.abc import Mapping
from typing import Callable

class EntryPoint(Mapping):
"""
Dict-like interface for accessing parameter sets and models through entry points in cookiecutter template.
Access via :py:data:`pybamm_cookiecutter.parameter_sets` for parameter_sets
Access via :py:data:`pybamm_cookiecutter.Model` for Models
Examples
--------
Listing available parameter sets:
>>> import pybamm_cookiecutter
>>> list(pybamm_cookiecutter.parameter_sets)
['Chen2020', ...]
>>> list(pybamm_cookiecutter.models)
['SPM', ...]
Get the docstring for a parameter set/model:
>>> print(pybamm_cookiecutter.parameter_sets.get_docstring("Ai2020"))
<BLANKLINE>
Parameters for the Enertech cell (Ai2020), from the papers :footcite:t:`Ai2019`,
:footcite:t:`rieger2016new` and references therein.
...
>>> print(pybamm_cookiecutter.models.get_docstring("SPM"))
<BLANKLINE>
Single Particle Model (SPM) model of a lithium-ion battery, from :footcite:t:`Marquis2019`. This class differs from the :class:`pybamm.lithium_ion.SPM` model class in that it shows the whole model in a single class. This comes at the cost of flexibility in combining different physical effects, and in general the main SPM class should be used instead.
...
See also: :ref:`adding-parameter-sets`
"""

_instances = 0
def __init__(self, group):
"""Dict of entry points for parameter sets or models, lazily load entry points as"""
if not hasattr(self, 'initialized'): # Ensure __init__ is called once per instance
self.initialized = True
EntryPoint._instances += 1
self._all_entries = dict()
self.group = group
for entry_point in self.get_entries(self.group):
self._all_entries[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
"""Wrapper for the importlib version logic"""
if sys.version_info < (3, 10): # pragma: no cover
return importlib.metadata.entry_points()[group_name]
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls, group):
"""Ensure only two instances of entry points exist, one for parameter sets and the other for models"""
if EntryPoint._instances < 2:
cls.instance = super().__new__(cls)
return cls.instance

def __getitem__(self, key) -> dict:
return self._load_entry_point(key)()

def _load_entry_point(self, key) -> Callable:
"""Check that ``key`` is a registered ``parameter_sets`` or ``models` ,
and return the entry point for the parameter set/model, loading it needed."""
if key not in self._all_entries:
raise KeyError(f"Unknown parameter set or model: {key}")
ps = self._all_entries[key]
try:
ps = self._all_entries[key] = ps.load()
except AttributeError:
pass
return ps

def __iter__(self):
return self._all_entries.__iter__()

def __len__(self) -> int:
return len(self._all_entries)

def get_docstring(self, key):
"""Return the docstring for the ``key`` parameter set or model"""
return textwrap.dedent(self._load_entry_point(key).__doc__)

def __getattribute__(self, name):
try:
return super().__getattribute__(name)
except AttributeError as error:
raise error

#: Singleton Instance of :class:ParameterSets """
parameter_sets = EntryPoint(group="parameter_sets")

#: Singleton Instance of :class:ModelEntryPoints"""
models = EntryPoint(group="models")

def Model(model:str):
"""
Returns the loaded model object
Parameters
----------
model : str
The model name or author name of the model mentioned at the model entry point.
Returns
-------
pybamm.model
Model object of the initialised model.
Examples
--------
Listing available models:
>>> import pybamm_cookiecutter
>>> list(pybamm_cookiecutter.models)
['SPM', ...]
>>> pybamm_cookiecutter.Model('Author/Year')
<pybamm_cookiecutter.models.input.SPM.SPM object>
"""
return models[model]
Loading

0 comments on commit 85869e9

Please sign in to comment.