pinball update
Christian Lagemann committed Dec 10, 2024
1 parent 8698bf0 commit 0a19230
Showing 29 changed files with 1,991,678 additions and 116 deletions.
5 changes: 4 additions & 1 deletion hydrogym/__init__.py
@@ -1,2 +1,5 @@
from . import distributed, firedrake
from . import distributed, firedrake, torch_env
from .core import CallbackBase, FlowEnv, PDEBase, TransientSolver
from .core_1DEnvs import OneDimEnv, PDESolverBase1D

from .torch_env import Kuramoto_Sivashinsky, Burgers # isort:skip
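
Net effect of this hunk: the new torch_env and core_1DEnvs classes are re-exported at the package root. A minimal import sketch using only names that appear in this diff (nothing is instantiated, since constructor signatures are not part of this commit):

# Names newly importable from the package root after this change.
from hydrogym import Burgers, Kuramoto_Sivashinsky, OneDimEnv, PDESolverBase1D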
184 changes: 153 additions & 31 deletions hydrogym/core.py
@@ -2,6 +2,7 @@
from typing import Any, Callable, Iterable, Tuple, TypeVar, Union

import gym
# import gymnasium as gym
import numpy as np
from numpy.typing import ArrayLike

@@ -33,7 +34,9 @@ class PDEBase(metaclass=abc.ABCMeta):
any information about solving the time-varying equations
"""

MAX_CONTROL = np.inf
# MAX_CONTROL = np.inf
MAX_CONTROL_LOW = -np.inf
MAX_CONTROL_UP = np.inf
DEFAULT_MESH = ""
DEFAULT_DT = np.inf
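
A hedged sketch of what the split bounds allow a concrete flow to declare (the class name, values, and mesh name below are illustrative, not from this commit; a real flow must still implement the abstract PDEBase interface):

# Hypothetical subclass: asymmetric actuation bounds via the new attributes.
class BlowingOnlyFlow(PDEBase):
    MAX_CONTROL_LOW = 0.0  # illustrative: actuator can only push, not pull
    MAX_CONTROL_UP = 2.0
    DEFAULT_MESH = "coarse"
    DEFAULT_DT = 1e-2
    # ... abstract methods of PDEBase omitted in this sketch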

@@ -47,12 +50,13 @@ class PDEBase(metaclass=abc.ABCMeta):

def __init__(self, **config):
self.mesh = self.load_mesh(name=config.get("mesh", self.DEFAULT_MESH))
self.reward_lambda = config.get("reward_lambda", 0.0)
self.initialize_state()

self.reset()

if config.get("restart"):
self.load_checkpoint(config["restart"])
self.load_checkpoint(config["restart"][0])

@property
@abc.abstractmethod
@@ -186,6 +190,7 @@ def advance_time(self, dt: float, act: list[float] = None) -> list[float]:
Returns:
Iterable[ArrayLike]: Updated actuator state
"""

if act is None:
act = self.control_state
self.t += dt
@@ -272,12 +277,14 @@ def __init__(self, flow: PDEBase, dt: float = None):
if dt is None:
dt = flow.DEFAULT_DT
self.dt = dt
self.t = 0.0

def solve(
self,
t_span: Tuple[float, float],
callbacks: Iterable[CallbackBase] = [],
controller: Callable = None,
start_iteration_value: int = 0,
) -> PDEBase:
"""Solve the initial-value problem for the PDE.
@@ -292,6 +299,7 @@ def solve(
PDEBase: The state of the PDE at the end of the solve
"""
for iter, t in enumerate(np.arange(*t_span, self.dt)):
iter = iter + start_iteration_value
if controller is not None:
y = self.flow.get_observations()
u = controller(t, y)
Expand All @@ -305,6 +313,43 @@ def solve(
cb.close()

return flow

def solve_multistep(
self,
num_substeps: int,
callbacks: Iterable[CallbackBase] = [],
controller: Callable = None,
start_iteration_value: int = 0,
) -> PDEBase:
"""Solve the initial-value problem for the PDE.
Args:
t_span (Tuple[float, float]): Tuple of start and end times
callbacks (Iterable[CallbackBase], optional):
List of callbacks to evaluate throughout the solve. Defaults to [].
controller (Callable, optional):
Feedback/forward controller `u = ctrl(t, y)`
Returns:
PDEBase: The state of the PDE at the end of the solve
"""
for iter in range(num_substeps):
iter = iter + start_iteration_value
t = iter * self.dt
if controller is not None:
y = self.flow.get_observations()
u = controller(t, y)
# print('controller output in step:', u, flush=True)
else:
u = None
flow = self.step(iter, control=u)
for cb in callbacks:
cb(iter, t, flow)

for cb in callbacks:
cb.close()

return flow
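
A hedged usage sketch of the new solve_multistep entry point, assuming `solver` is an instance of a concrete TransientSolver subclass whose flow exposes `num_inputs` actuators (the zero-hold controller is illustrative):

import numpy as np

# Advance 10 substeps of size solver.dt under a constant zero control signal.
def hold_zero(t, y):
    return np.zeros(solver.flow.num_inputs)

flow = solver.solve_multistep(
    num_substeps=10,
    callbacks=[],
    controller=hold_zero,
    start_iteration_value=0,
)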

def step(self, iter: int, control: Iterable[float] = None, **kwargs):
"""Advance the transient simulation by one time step
@@ -315,22 +360,47 @@ def step(self, iter: int, control: Iterable[float] = None, **kwargs):
"""
raise NotImplementedError

def reset(self):
def reset(self, t=0.0):
"""Reset variables for the timestepper"""
pass
self.t = 0.0


class FlowEnv(gym.Env):

def __init__(self, env_config: dict):
self.flow: PDEBase = env_config.get("flow")(
**env_config.get("flow_config", {}))


self.solver: TransientSolver = env_config.get("solver")(
self.flow, **env_config.get("solver_config", {}))
self.callbacks: Iterable[CallbackBase] = env_config.get("callbacks", [])
self.rewardLogCallback: Iterable[CallbackBase] = env_config.get("actuation_config", {}).get("rewardLogCallback", None)
self.max_steps: int = env_config.get("max_steps", int(1e6))
self.iter: int = 0
self.q0: self.flow.StateType = self.flow.copy_state()
self.restart_ckpts = env_config.get("flow_config", {}).get("restart", None)
if self.restart_ckpts is None:
self.q0: self.flow.StateType = self.flow.copy_state()

# if len(env_config.get("flow_config", {}).get("restart", None)) > 1:
# self.dummy_flow: PDEBase = env_config.get("flow")(**env_config.get("flow_config", {}))
# self.restart_ckpts = env_config.get("flow_config", {}).get("restart", None)

# print("Restart ckpts:", self.restart_ckpts, flush=True)
# print("len:", len(self.restart_ckpts), flush=True)
# print("0 ckpt:", self.restart_ckpts[0], flush=True)

# self.q0s = []
# for ckpt in range(len(self.restart_ckpts)):
# # self.dummy_flow.mesh = self.dummy_flow.load_mesh(name=env_config.get("flow_config", {}).get("mesh", self.dummy_flow.DEFAULT_MESH))

# print("ckpt:", ckpt, self.restart_ckpts[ckpt],flush=True)
# self.dummy_flow.load_checkpoint(self.restart_ckpts[ckpt])
# self.q0s.append(self.dummy_flow.copy_state())

# print("self.q0s loaded", flush=True)
# # self.q0 = self.q0s[-1]
# self.q0: self.flow.StateType = self.flow.copy_state()

self.observation_space = gym.spaces.Box(
low=-np.inf,
@@ -339,63 +409,115 @@ def __init__(self, env_config: dict):
dtype=float,
)
self.action_space = gym.spaces.Box(
low=-self.flow.MAX_CONTROL,
high=self.flow.MAX_CONTROL,
low=self.flow.MAX_CONTROL_LOW,
high=self.flow.MAX_CONTROL_UP,
shape=(self.flow.num_inputs,),
dtype=float,
)

self.t = 0.
self.dt = env_config.get("solver_config", {}).get("dt", None)
assert self.dt is not None, f"Error: Solver timestep dt ({self.dt}) must not be None"
self.num_sim_substeps_per_actuation = env_config.get("actuation_config", {}).get("num_sim_substeps_per_actuation", None)

if self.num_sim_substeps_per_actuation is not None and self.num_sim_substeps_per_actuation > 1:
assert self.rewardLogCallback is not None, f"Error: If num_sim_substeps_per_actuation ({self.num_sim_substeps_per_actuation}) is set a reward callback function must be given, {self.rewardLogCallback}"
self.reward_aggreation_rule = env_config.get("actuation_config", {}).get("reward_aggreation_rule", None)
assert self.reward_aggreation_rule in ['mean', 'sum', 'median'], f"Error: reward aggregation rule ({self.reward_aggreation_rule}) is not given or not implemented yet"
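
For orientation, a hedged sketch of an env_config that exercises the new actuation options parsed above; the flow/solver classes, checkpoint paths, and logging callback are placeholders, and only the dictionary keys are taken from the code in this commit:

# Hypothetical configuration; key names (including the 'reward_aggreation_rule'
# spelling) match the lookups in __init__ above.
env_config = {
    "flow": MyFlow,            # placeholder PDEBase subclass
    "flow_config": {"restart": ["ckpt_a.h5", "ckpt_b.h5"]},
    "solver": MySolver,        # placeholder TransientSolver subclass
    "solver_config": {"dt": 1e-2},
    "max_steps": 1000,
    "actuation_config": {
        "num_sim_substeps_per_actuation": 10,
        "reward_aggreation_rule": "mean",       # 'mean', 'sum', or 'median'
        "rewardLogCallback": my_reward_logger,  # required once substepping is enabled
    },
}
env = FlowEnv(env_config)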

def constant_action_controller(self, t, y):
return self.action

def set_callbacks(self, callbacks: Iterable[CallbackBase]):
self.callbacks = callbacks

def step(
self,
action: Iterable[ArrayLike] = None
) -> Tuple[ArrayLike, float, bool, dict]:
"""Advance the state of the environment. See gym.Env documentation
self, action: Iterable[ArrayLike] = None
) -> Tuple[ArrayLike, float, bool, dict]:
"""Advance the state of the environment. See gym.Env documentation
Args:
action (Iterable[ArrayLike], optional): Control inputs. Defaults to None.
action (Iterable[ActType], optional): Control inputs. Defaults to None.
Returns:
Tuple[ArrayLike, float, bool, dict]: obs, reward, done, info
Tuple[ObsType, float, bool, dict]: obs, reward, done, info
"""
self.solver.step(self.iter, control=action)
self.iter += 1
t = self.iter * self.solver.dt
for cb in self.callbacks:
cb(self.iter, t, self.flow)
obs = self.flow.get_observations()

reward = self.get_reward()
done = self.check_complete()
info = {}

obs = self.stack_observations(obs)

return obs, reward, done, info
# action = action * self.flow.CONTROL_SCALING
# print('control action', action, flush=True)

if self.num_sim_substeps_per_actuation is not None and self.num_sim_substeps_per_actuation > 1:
self.action = action
self.flow = self.solver.solve_multistep(num_substeps=self.num_sim_substeps_per_actuation, callbacks=[self.rewardLogCallback], controller=self.constant_action_controller, start_iteration_value=self.iter)
if self.reward_aggreation_rule == "mean":
# print('flow_array', self.flow.reward_array,flush=True)
# print('mean flow_array', np.mean(self.flow.reward_array, axis=0),flush=True)
averaged_objective_values = np.mean(self.flow.reward_array, axis=0)
elif self.reward_aggreation_rule == "sum":
averaged_objective_values = np.sum(self.flow.reward_array, axis=0)
elif self.reward_aggreation_rule == "median":
averaged_objective_values = np.median(self.flow.reward_array, axis=0)
else:
raise NotImplementedError(f"The {self.reward_aggreation_rule} function is not implemented yet.")

self.iter += self.num_sim_substeps_per_actuation
self.t += self.dt * self.num_sim_substeps_per_actuation
reward = self.get_reward(averaged_objective_values)
else:
self.solver.step(self.iter, control=action)
self.iter += 1
t = self.iter * self.solver.dt
self.t += self.dt
reward = self.get_reward()

for cb in self.callbacks:
cb(self.iter, self.t, self.flow)
obs = self.flow.get_observations()

done = self.check_complete()
# print('max_steps', self.max_steps, 'current iter:', self.iter, 'done', done, flush=True)
info = {}

obs = self.stack_observations(obs)

return obs, reward, done, info

# TODO: Use this to allow for arbitrary returns from collect_observations
# That are then converted to a list/tuple/ndarray here
def stack_observations(self, obs):
return obs

def get_reward(self):
return -self.solver.dt * self.flow.evaluate_objective()
def get_reward(self, averaged_objective_values=None):
if averaged_objective_values is None:
# return -self.solver.dt * self.flow.evaluate_objective()
return -self.flow.evaluate_objective()
else:
# return -self.solver.dt * self.num_sim_substeps_per_actuation * self.flow.evaluate_objective(averaged_objective_values=averaged_objective_values)
return -self.flow.evaluate_objective(averaged_objective_values=averaged_objective_values)

def check_complete(self):
return self.iter > self.max_steps

def reset(self, t=0.0) -> Union[ArrayLike, Tuple[ArrayLike, dict]]:
self.iter = 0
self.flow.reset(q0=self.q0, t=t)
self.t = 0.

if self.restart_ckpts is not None:
ckpt_index = np.random.randint(0, len(self.restart_ckpts))
self.flow.load_checkpoint(self.restart_ckpts[ckpt_index])
# print("Loaded ckeckpoint:", self.restart_ckpts[ckpt_index], flush=True)

self.flow.reset(q0=self.flow.copy_state() if self.restart_ckpts is not None else self.q0)
self.solver.reset()
info = {}

return self.flow.get_observations()
return self.flow.get_observations(), info

def render(self, mode="human", **kwargs):
self.flow.render(mode=mode, **kwargs)

def close(self):
for cb in self.callbacks:
cb.close()
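
Finally, a hedged rollout sketch reflecting the updated return signatures: reset() now returns an (obs, info) pair while step() keeps the four-tuple; `env` is assumed to be a FlowEnv built from a config like the one sketched earlier:

obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # gym.spaces.Box sampling within the new bounds
    obs, reward, done, info = env.step(action)
env.close()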


