Skip to content

Commit

Permalink
Merge pull request GazzolaLab#373 from skim0119/typing/timestepper
Browse files Browse the repository at this point in the history
Typing: `timestepper` fix pytest and cleanup integration routine
  • Loading branch information
skim0119 authored May 7, 2024
2 parents b152c20 + 42bfd09 commit 98e9d0e
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 576 deletions.
24 changes: 12 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#* Variables
PYTHON := python3
PYTHONPATH := `pwd`
AUTOFLAKE8_ARGS := -r --exclude '__init__.py' --keep-pass-after-docstring
AUTOFLAKE_ARGS := -r
#* Poetry
.PHONY: poetry-download
poetry-download:
Expand Down Expand Up @@ -47,19 +47,19 @@ flake8:
poetry run flake8 --version
poetry run flake8 elastica tests

.PHONY: autoflake8-check
autoflake8-check:
poetry run autoflake8 --version
poetry run autoflake8 $(AUTOFLAKE8_ARGS) elastica tests examples
poetry run autoflake8 --check $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake8-format
autoflake8-format:
poetry run autoflake8 --version
poetry run autoflake8 --in-place $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-format
autoflake-format:
poetry run autoflake --version
poetry run autoflake --in-place $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: format-codestyle
format-codestyle: black flake8
format-codestyle: black autoflake-format

.PHONY: mypy
mypy:
Expand All @@ -78,7 +78,7 @@ test_coverage_xml:
NUMBA_DISABLE_JIT=1 poetry run pytest --cov=elastica --cov-report=xml

.PHONY: check-codestyle
check-codestyle: black-check flake8 autoflake8-check
check-codestyle: black-check flake8 autoflake-check

.PHONY: formatting
formatting: format-codestyle
Expand Down
11 changes: 6 additions & 5 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
Basic coordinating for multiple, smaller systems that have an independently integrable
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Iterable, Callable, AnyStr
from typing import Iterable, Callable, AnyStr, Type
from elastica.typing import SystemType

import numpy as np

Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self):
# We need to initialize our mixin classes
super(BaseSystemCollection, self).__init__()
# List of system types/bases that are allowed
self.allowed_sys_types = (RodBase, RigidBodyBase, SurfaceBase)
self.allowed_sys_types: tuple[Type, ...] = (RodBase, RigidBodyBase, SurfaceBase)
# List of systems to be integrated
self._systems = []
# Flag Finalize: Finalizing twice will cause an error,
Expand Down Expand Up @@ -98,11 +99,11 @@ def insert(self, idx, system):
def __str__(self):
return str(self._systems)

def extend_allowed_types(self, additional_types):
def extend_allowed_types(self, additional_types: list[Type, ...]):
self.allowed_sys_types += additional_types

def override_allowed_types(self, allowed_types):
self.allowed_sys_types = allowed_types
def override_allowed_types(self, allowed_types: list[Type, ...]):
self.allowed_sys_types = tuple(allowed_types)

def _get_sys_idx_if_valid(self, sys_to_be_added):
from numpy import int_ as npint
Expand Down
1 change: 0 additions & 1 deletion elastica/py.typed
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

4 changes: 3 additions & 1 deletion elastica/systems/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from elastica._rotations import _rotate
from elastica.rod.data_structures import _RodSymplecticStepperMixin
from elastica.rod.rod_base import RodBase


class BaseStatefulSystem:
Expand Down Expand Up @@ -355,8 +356,9 @@ def make_simple_system_with_positions_directors(
)


class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin):
class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin, RodBase):
def __init__(self, start_position, end_position, start_director):
self.ring_rod_flag = False # TODO:
self.a = 0.5
self.b = 1
self.c = 2
Expand Down
4 changes: 2 additions & 2 deletions elastica/systems/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def external_forces(self) -> NDArray: ...
@property
def external_torques(self) -> NDArray: ...

def update_internal_forces_and_torques(self, time: np.floating) -> None: ...


class SymplecticSystemProtocol(SystemProtocol, Protocol):
"""
Expand All @@ -64,8 +66,6 @@ def dynamic_rates(
self, time: np.floating, prefac: np.floating
) -> tuple[NDArray]: ...

def update_internal_forces_and_torques(self, time: np.floating) -> None: ...


class ExplicitSystemProtocol(SystemProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
Expand Down
111 changes: 52 additions & 59 deletions elastica/timestepper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__doc__ = """Timestepping utilities to be used with Rod and RigidBody classes"""

from typing import Tuple, List, Callable, Type
from elastica.typing import SystemType
from typing import Tuple, List, Callable, Type, Any, overload
from elastica.typing import SystemType, SystemCollectionType, SteppersOperatorsType

import numpy as np
from tqdm import tqdm
Expand All @@ -10,57 +10,52 @@

from .symplectic_steppers import PositionVerlet, PEFRL
from .explicit_steppers import RungeKutta4, EulerForward
from .protocol import StepperProtocol, SymplecticStepperProtocol

from .tag import SymplecticStepperTag, ExplicitStepperTag
from .protocol import StepperProtocol, StatefulStepperProtocol
from .protocol import MethodCollectorProtocol


# TODO: Both extend_stepper_interface and integrate should be in separate file.
# __init__ is probably not an ideal place to have these scripts.
# Deprecated: Remove in the future version
# Many script still uses this method to control timestep. Keep it for backward compatibility
def extend_stepper_interface(
Stepper: StepperProtocol, System: SystemType
) -> Tuple[Callable, Tuple[Callable]]:

# StepperMethodCollector: Type[MethodCollectorProtocol]
# SystemStepper: Type[StepperProtocol]
if isinstance(Stepper.Tag, SymplecticStepperTag):
from elastica.timestepper.symplectic_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
SymplecticStepperMethods,
)

StepperMethodCollector = SymplecticStepperMethods
elif isinstance(Stepper.Tag, ExplicitStepperTag): # type: ignore[no-redef]
from elastica.timestepper.explicit_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
ExplicitStepperMethods,
)

StepperMethodCollector = ExplicitStepperMethods
else:
raise NotImplementedError(
"Only explicit and symplectic steppers are supported, given stepper is {}".format(
Stepper.__class__.__name__
)
)

# Check if system is a "collection" of smaller systems
if is_system_a_collection(System):
SystemStepper = _SystemCollectionStepper
else:
SystemStepper = _SystemInstanceStepper

stepper_methods: Tuple[Callable] = StepperMethodCollector(Stepper).step_methods()
do_step_method: Callable = SystemStepper.do_step
stepper: StepperProtocol, system_collection: SystemCollectionType
) -> Tuple[
Callable[
[StepperProtocol, SystemCollectionType, np.floating, np.floating], np.floating
],
SteppersOperatorsType,
]:
try:
stepper_methods: SteppersOperatorsType = stepper.steps_and_prefactors
do_step_method: Callable = stepper.do_step # type: ignore[attr-defined]
except AttributeError as e:
raise NotImplementedError(f"{stepper} stepper is not supported.") from e
return do_step_method, stepper_methods


@overload
def integrate(
stepper: StepperProtocol,
systems: SystemType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...


@overload
def integrate(
StatefulStepper: StatefulStepperProtocol,
System: SystemType,
stepper: StepperProtocol,
systems: SystemCollectionType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...


def integrate(
stepper: StepperProtocol,
systems: SystemType | SystemCollectionType,
final_time: float,
n_steps: int = 1000,
restart_time: float = 0.0,
Expand All @@ -70,9 +65,9 @@ def integrate(
Parameters
----------
StatefulStepper : StatefulStepperProtocol
stepper : StepperProtocol
Stepper algorithm to use.
System : SystemType
systems : SystemType | SystemCollectionType
The elastica-system to simulate.
final_time : float
Total simulation time. The timestep is determined by final_time / n_steps.
Expand All @@ -86,17 +81,15 @@ def integrate(
assert final_time > 0.0, "Final time is negative!"
assert n_steps > 0, "Number of integration steps is negative!"

# Extend the stepper's interface after introspecting the properties
# of the system. If system is a collection of small systems (whose
# states cannot be aggregated), then stepper now loops over the system
# state
do_step, stages_and_updates = extend_stepper_interface(StatefulStepper, System)

dt = np.float64(float(final_time) / n_steps)
time = restart_time
dt = np.float_(float(final_time) / n_steps)
time = np.float_(restart_time)

for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = do_step(StatefulStepper, stages_and_updates, System, time, dt)
if is_system_a_collection(systems):
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step(systems, time, dt) # type: ignore[arg-type]
else:
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step_single_instance(systems, time, dt) # type: ignore[arg-type]

print("Final time of simulation is : ", time)
return time
return float(time)
Loading

0 comments on commit 98e9d0e

Please sign in to comment.