Skip to content

Commit

Permalink
Added pre-commit-hook for formatting code
Browse files Browse the repository at this point in the history
* Removed unnecessary comments
  • Loading branch information
david-zwicker committed Aug 17, 2024
1 parent 20bd997 commit 71a1658
Show file tree
Hide file tree
Showing 35 changed files with 82 additions and 72 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.1
hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-format
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
project = "py-pde"
module_name = "pde"
author = "Zwicker Group"
copyright = f"{date.today().year}, {author}" # @ReservedAssignment # noqa: A001
copyright = f"{date.today().year}, {author}" # noqa: A001
html_logo = "_images/logo_small.png"

# Determine the version from the actual package
Expand Down
14 changes: 7 additions & 7 deletions pde/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import contextlib

# import all other modules that should occupy the main name space
from .fields import * # @UnusedWildImport
from .grids import * # @UnusedWildImport
from .pdes import * # @UnusedWildImport
from .solvers import * # @UnusedWildImport
from .storage import * # @UnusedWildImport
from .fields import *
from .grids import *
from .pdes import *
from .solvers import *
from .storage import *
from .tools.parameters import Parameter
from .trackers import * # @UnusedWildImport
from .visualization import * # @UnusedWildImport
from .trackers import *
from .visualization import *

with contextlib.suppress(ImportError):
from .tools.modelrunner import *
Expand Down
8 changes: 4 additions & 4 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.label = label
self._logger = logging.getLogger(self.__class__.__name__)

def __init_subclass__(cls, **kwargs): # @NoSelf
def __init_subclass__(cls, **kwargs):
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)

Expand Down Expand Up @@ -359,7 +359,7 @@ def assert_field_compatible(
Determines whether it is acceptable that `other` is an instance of
:class:`~pde.fields.ScalarField`.
"""
from .scalar import ScalarField # @Reimport
from .scalar import ScalarField

# check whether they are the same class
is_scalar = accept_scalar and isinstance(other, ScalarField)
Expand Down Expand Up @@ -489,7 +489,7 @@ def _binary_operation(

if isinstance(other, FieldBase):
# right operator is a field
from .scalar import ScalarField # @Reimport
from .scalar import ScalarField

# determine the dtype of the result of the operation
dtype = np.result_type(self.data, other.data)
Expand Down Expand Up @@ -539,7 +539,7 @@ def _binary_operation_inplace(
"""
if isinstance(other, FieldBase):
# right operator is a field
from .scalar import ScalarField # @Reimport
from .scalar import ScalarField

if scalar_second:
# right operator must be a scalar
Expand Down
4 changes: 1 addition & 3 deletions pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,7 @@ def plot(
kind = [kind] * num_panels
reference = [
field.plot(kind=knd, ax=ax, action="none", **kwargs, **sp_args)
for field, knd, ax, sp_args in zip( # @UnusedVariable
self.fields, kind, axs, subplot_args
)
for field, knd, ax, sp_args in zip(self.fields, kind, axs, subplot_args)
]

# return the references for all subplots
Expand Down
4 changes: 2 additions & 2 deletions pde/fields/vectorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def dot(
:class:`~pde.fields.scalar.ScalarField` or
:class:`~pde.fields.vectorial.VectorField`: result of applying the operator
"""
from .tensorial import Tensor2Field # @Reimport
from .tensorial import Tensor2Field

# check input
self.grid.assert_grid_compatible(other.grid)
Expand Down Expand Up @@ -253,7 +253,7 @@ def outer_product(
Returns:
:class:`~pde.fields.tensorial.Tensor2Field`: result of the operation
"""
from .tensorial import Tensor2Field # @Reimport
from .tensorial import Tensor2Field

self.assert_field_compatible(other)

Expand Down
6 changes: 3 additions & 3 deletions pde/grids/_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def broadcast(self, data: TData) -> TData:
Returns:
The same data, but on all nodes
"""
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.bcast(data, root=0) # type: ignore

Expand All @@ -747,7 +747,7 @@ def gather(self, data: TData) -> list[TData] | None:
None on all nodes, except the main node, which receives an ordered list with
the data from all nodes.
"""
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.gather(data, root=0)

Expand All @@ -761,7 +761,7 @@ def allgather(self, data: TData) -> list[TData]:
Returns:
list: data from all nodes.
"""
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
from mpi4py.MPI import COMM_WORLD

return COMM_WORLD.allgather(data)

Expand Down
8 changes: 4 additions & 4 deletions pde/grids/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self) -> None:
self.axes = [self.c.axes[i] for i in self._axes_described]
self.axes_symmetric = [self.c.axes[i] for i in self.axes_symmetric] # type: ignore

def __init_subclass__(cls, **kwargs) -> None: # @NoSelf
def __init_subclass__(cls, **kwargs) -> None:
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)
if cls is not GridBase:
Expand Down Expand Up @@ -1066,7 +1066,7 @@ def get_boundary_conditions(
PeriodicityError:
If the boundaries are not compatible with the periodic axes of the grid.
"""
from .boundaries import Boundaries # @Reimport
from .boundaries import Boundaries

if self._mesh is None:
# get boundary conditions for a simple grid that is not part of a mesh
Expand Down Expand Up @@ -1226,7 +1226,7 @@ def register_operator(factor_func_arg: OperatorFactory):

@hybridmethod # type: ignore
@property
def operators(cls) -> set[str]: # @NoSelf
def operators(cls) -> set[str]:
"""set: all operators defined for this class"""
result = set()
# add all customly defined operators
Expand Down Expand Up @@ -1574,7 +1574,7 @@ def integrate(

else:
# we are in a parallel run, so we need to gather the sub-integrals from all
from mpi4py.MPI import COMM_WORLD # @UnresolvedImport
from mpi4py.MPI import COMM_WORLD

integral_full = np.empty_like(integral)
COMM_WORLD.Allreduce(integral, integral_full)
Expand Down
2 changes: 1 addition & 1 deletion pde/grids/boundaries/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(self, grid: GridBase, axis: int, upper: bool, *, rank: int = 0):

self._logger = logging.getLogger(self.__class__.__name__)

def __init_subclass__(cls, **kwargs): # @NoSelf
def __init_subclass__(cls, **kwargs):
"""Register all subclasses to reconstruct them later."""
super().__init_subclass__(**kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pde/grids/cylindrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def slice(self, indices: Sequence[int]) -> CartesianGrid | PolarSymGrid:

if indices[0] == 0:
# return a radial grid
from .spherical import PolarSymGrid # @Reimport
from .spherical import PolarSymGrid

return PolarSymGrid(self.radius, self.shape[0])

Expand Down
4 changes: 2 additions & 2 deletions pde/grids/operators/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from ...tools.typing import OperatorType
from ..boundaries import Boundaries
from ..cartesian import CartesianGrid
from .common import make_derivative as _make_derivative # @UnusedImport
from .common import make_derivative2 as _make_derivative2 # @UnusedImport
from .common import make_derivative as _make_derivative
from .common import make_derivative2 as _make_derivative2
from .common import make_general_poisson_solver, uniform_discretization

# The `make_derivative?` methods are imported for backward compatibility. Their usage is
Expand Down
2 changes: 1 addition & 1 deletion pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def solve(
the current node is not the main MPI node.
"""
from ..solvers import Controller
from ..solvers.base import SolverBase # @Reimport
from ..solvers.base import SolverBase

# create solver instance
if callable(solver):
Expand Down
2 changes: 1 addition & 1 deletion pde/pdes/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..fields import ScalarField
from ..grids.base import GridBase
from ..grids.boundaries.axes import BoundariesData # @UnusedImport
from ..grids.boundaries.axes import BoundariesData
from ..tools.docstrings import fill_in_docstring


Expand Down
6 changes: 3 additions & 3 deletions pde/pdes/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numba as nb
import numpy as np
from numba.typed import Dict as NumbaDict # @UnresolvedImport
from numba.typed import Dict as NumbaDict
from sympy import Symbol
from sympy.core.function import UndefinedFunction

Expand Down Expand Up @@ -329,13 +329,13 @@ def _compile_rhs_single(
# extend the signature
signature += tuple(state.grid.axes)
# inject the spatial coordinates into the expression for the rhs
extra_args = tuple( # @UnusedVariable
extra_args = tuple(
state.grid.cell_coords[..., i] for i in range(state.grid.num_axes)
)

else:
# expression only depends on the actual variables
extra_args = () # @UnusedVariable
extra_args = ()

# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
Expand Down
4 changes: 2 additions & 2 deletions pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, pde: PDEBase, *, backend: BackendType = "auto"):
self.info["pde_class"] = self.pde.__class__.__name__
self._logger = logging.getLogger(self.__class__.__name__)

def __init_subclass__(cls, **kwargs): # @NoSelf
def __init_subclass__(cls, **kwargs):
"""Register all subclassess to reconstruct them later."""
super().__init_subclass__(**kwargs)
if not isabstract(cls):
Expand Down Expand Up @@ -113,7 +113,7 @@ def from_name(cls, name: str, pde: PDEBase, **kwargs) -> SolverBase:
return solver_class(pde, **kwargs)

@classproperty
def registered_solvers(cls) -> list[str]: # @NoSelf
def registered_solvers(cls) -> list[str]:
"""list of str: the names of the registered solvers"""
return sorted(cls._subclasses.keys())

Expand Down
8 changes: 4 additions & 4 deletions pde/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def extract_field(
:class:`MemoryStorage`: a storage instance that contains the data for the
single field
"""
from .memory import MemoryStorage # @Reimport
from .memory import MemoryStorage

if self._field is None:
self._init_field()
Expand Down Expand Up @@ -435,7 +435,7 @@ def extract_time_range(
Returns:
:class:`MemoryStorage`: a storage instance that contains the extracted data.
"""
from .memory import MemoryStorage # @Reimport
from .memory import MemoryStorage

# get the time bracket
try:
Expand Down Expand Up @@ -502,7 +502,7 @@ def apply(
raise TypeError("The user function must return a field")

if out is None:
from .memory import MemoryStorage # @Reimport
from .memory import MemoryStorage

out = MemoryStorage(field_obj=transformed)

Expand All @@ -517,7 +517,7 @@ def apply(

# make sure that a storage is returned, even when no fields are present
if out is None:
from .memory import MemoryStorage # @Reimport
from .memory import MemoryStorage

out = MemoryStorage()

Expand Down
2 changes: 1 addition & 1 deletion pde/tools/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def make_unserializer(method: SerializerMethod) -> Callable:
return yaml.full_load

if method == "yaml_unsafe":
import yaml # @Reimport
import yaml

Check warning on line 284 in pde/tools/cache.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/cache.py#L284

Added line #L284 was not covered by tests

return yaml.unsafe_load

Expand Down
8 changes: 3 additions & 5 deletions pde/tools/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def compile_func(func):
# partial function instead of replacing the constants in the sympy expression
# directly since sympy does not work well with numpy arrays.
if constants:
const_values = tuple(self.consts[c] for c in constants) # @UnusedVariable
const_values = tuple(self.consts[c] for c in constants)

if prepare_compilation:
func = jit(func)
Expand Down Expand Up @@ -1099,13 +1099,11 @@ def evaluate(
# extend the signature
signature += tuple(grid.axes)
# inject the spatial coordinates into the expression for the rhs
extra_args = tuple( # @UnusedVariable
grid.cell_coords[..., i] for i in range(grid.num_axes)
)
extra_args = tuple(grid.cell_coords[..., i] for i in range(grid.num_axes))

else:
# expression only depends on the actual variables
extra_args = () # @UnusedVariable
extra_args = ()

# check whether all variables are accounted for
extra_vars = set(expr.vars) - set(signature)
Expand Down
12 changes: 6 additions & 6 deletions pde/tools/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from numba.core.types import npytypes, scalars
from numba.extending import overload, register_jitable
from numba.typed import Dict as NumbaDict # @UnresolvedImport
from numba.typed import Dict as NumbaDict

from .. import config
from ..tools.misc import decorator_arguments
Expand Down Expand Up @@ -128,12 +128,12 @@ def f():
"multithreading_threshold": config["numba.multithreading_threshold"],
"fastmath": config["numba.fastmath"],
"debug": config["numba.debug"],
"using_svml": nb.config.USING_SVML, # @UndefinedVariable
"using_svml": nb.config.USING_SVML,
"threading_layer": threading_layer,
"omp_num_threads": os.environ.get("OMP_NUM_THREADS"),
"mkl_num_threads": os.environ.get("MKL_NUM_THREADS"),
"num_threads": nb.config.NUMBA_NUM_THREADS, # @UndefinedVariable
"num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS, # @UndefinedVariable
"num_threads": nb.config.NUMBA_NUM_THREADS,
"num_threads_default": nb.config.NUMBA_DEFAULT_NUM_THREADS,
"cuda_available": cuda_available,
"roc_available": roc_available,
}
Expand Down Expand Up @@ -203,7 +203,7 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF
return nb.jit(signature, **kwargs)(function) # type: ignore


if nb.config.DISABLE_JIT: # @UndefinedVariable
if nb.config.DISABLE_JIT:
# dummy function that creates a ctypes pointer
def address_as_void_pointer(addr):
"""Returns a void pointer from a given memory address.
Expand Down Expand Up @@ -321,7 +321,7 @@ def random_seed(seed: int = 0) -> None:
seed (int): Sets random seed
"""
np.random.seed(seed)
if not nb.config.DISABLE_JIT: # @UndefinedVariable
if not nb.config.DISABLE_JIT:
_random_seed_compiled(seed)


Expand Down
4 changes: 2 additions & 2 deletions pde/tools/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_progress_bar_class(fancy: bool = True):
# try using notebook progress bar
try:
# check whether progress bar can use a widget
import ipywidgets # @UnusedImport
import ipywidgets
except ImportError:
# widgets are not available => use standard tqdm
progress_bar_class = tqdm.tqdm
Expand Down Expand Up @@ -130,7 +130,7 @@ def show(self):
def in_jupyter_notebook() -> bool:
"""Checks whether we are in a jupyter notebook."""
try:
from IPython import display, get_ipython # @UnusedImport
from IPython import display, get_ipython
except ImportError:
return False

Expand Down
Loading

0 comments on commit 71a1658

Please sign in to comment.