Begin merging production changes #197
base: main
```diff
@@ -1 +1,2 @@
 *.msh filter=lfs diff=lfs merge=lfs -text
+*.tensor filter=lfs diff=lfs merge=lfs -text
```
```diff
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Iterable, Tuple, TypeVar, Union

 import gym
+# import gymnasium as gym
 import numpy as np
 from numpy.typing import ArrayLike
```
```diff
@@ -33,12 +34,13 @@ class PDEBase(metaclass=abc.ABCMeta):
     any information about solving the time-varying equations
     """

-    MAX_CONTROL = np.inf
+    # MAX_CONTROL = np.inf
+    MAX_CONTROL_LOW = -np.inf
+    MAX_CONTROL_UP = np.inf
     DEFAULT_MESH = ""
     DEFAULT_DT = np.inf

-    # Timescale used to smooth inputs
-    # (should be less than any meaningful timescale of the system)
+    # Timescale used to smooth inputs (should be less than any meaningful timescale of the system)
     TAU = 0.0

     StateType = TypeVar("StateType")
```
```diff
@@ -47,12 +49,13 @@ class PDEBase(metaclass=abc.ABCMeta):

     def __init__(self, **config):
         self.mesh = self.load_mesh(name=config.get("mesh", self.DEFAULT_MESH))
         self.reward_lambda = config.get("reward_lambda", 0.0)
         self.initialize_state()

         self.reset()

         if config.get("restart"):
-            self.load_checkpoint(config["restart"])
+            self.load_checkpoint(config["restart"][0])
```
Review comment: Why is the checkpoint an iterable now?
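A minimal sketch of one way to resolve the question, assuming `restart` is meant to accept either a single checkpoint path or a sequence of them; the normalization below is a suggestion, not part of the diff:

```python
# Illustrative only: accept a lone path or a sequence of paths, then load
# the first checkpoint. `load_checkpoint` is the method from the PR; the
# isinstance normalization is hypothetical.
restart = config.get("restart")
if restart:
    if isinstance(restart, (str, bytes)):
        restart = [restart]  # wrap a single path so [0] is safe
    self.load_checkpoint(restart[0])
```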
```diff

     @property
     @abc.abstractmethod
```
```diff
@@ -186,6 +189,7 @@ def advance_time(self, dt: float, act: list[float] = None) -> list[float]:
         Returns:
             Iterable[ArrayLike]: Updated actuator state
         """
+
         if act is None:
             act = self.control_state
         self.t += dt
```
```diff
@@ -272,12 +276,14 @@ def __init__(self, flow: PDEBase, dt: float = None):
         if dt is None:
             dt = flow.DEFAULT_DT
         self.dt = dt
+        self.t = 0.0
```
Review comment: Currently the time is owned by the flow, not the solver. I guess you could do it the other way around, but since the "state" (velocity, pressure, etc.) is owned by the flow class, I figured it made more sense to have time there as well. Then it is advanced with … Actually, I think the only place it's needed is for the "callbacks" or "controller", but now that time is a property of the flow it doesn't need to be there at all (or the callback could be passed …).
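A minimal sketch of the ownership the comment describes, assuming the solver should read time from the flow rather than keep its own clock (class bodies abbreviated; names follow the PR):

```python
# Hypothetical sketch, not the PR's code: the flow owns simulation time,
# and the solver delegates to it so there is a single source of truth.
class PDEBase:
    def __init__(self):
        self.t = 0.0  # simulation time lives on the flow

    def advance_time(self, dt: float):
        self.t += dt

class TransientSolver:
    def __init__(self, flow: PDEBase, dt: float):
        self.flow = flow
        self.dt = dt

    @property
    def t(self) -> float:
        # Read-only view of the flow's clock instead of a second counter.
        return self.flow.t
```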
```diff

     def solve(
         self,
         t_span: Tuple[float, float],
         callbacks: Iterable[CallbackBase] = [],
         controller: Callable = None,
+        start_iteration_value: int = 0,
     ) -> PDEBase:
         """Solve the initial-value problem for the PDE.
```
```diff
@@ -292,6 +298,43 @@ def solve(
             PDEBase: The state of the PDE at the end of the solve
         """
         for iter, t in enumerate(np.arange(*t_span, self.dt)):
+            iter = iter + start_iteration_value
```
Review comment: It looks like …

Review comment: I agree - what's the distinction supposed to be here?
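If offsetting the counter is the only goal, a smaller change would avoid both the rebinding and the shadowing of the builtin `iter`; a self-contained sketch with illustrative values:

```python
import numpy as np

t_span = (0.0, 1.0)        # illustrative values, not from the PR
dt = 0.1
start_iteration_value = 5

# enumerate() already supports an offset via its start argument, and
# naming the counter `i` avoids shadowing the builtin iter().
for i, t in enumerate(np.arange(*t_span, dt), start=start_iteration_value):
    print(i, t)  # i runs 5, 6, ... while t still covers t_span
```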
```diff
             if controller is not None:
                 y = self.flow.get_observations()
                 u = controller(t, y)
             else:
                 u = None
             flow = self.step(iter, control=u)
             for cb in callbacks:
                 cb(iter, t, flow)

         for cb in callbacks:
             cb.close()

         return flow
```
```diff
+    def solve_multistep(
+        self,
+        num_substeps: int,
+        callbacks: Iterable[CallbackBase] = [],
+        controller: Callable = None,
+        start_iteration_value: int = 0,
+    ) -> PDEBase:
+        """Solve the initial-value problem for the PDE.
+
+        Args:
+            num_substeps (int): Number of solver substeps to advance
+            callbacks (Iterable[CallbackBase], optional):
+                List of callbacks to evaluate throughout the solve. Defaults to [].
+            controller (Callable, optional):
+                Feedback/forward controller `u = ctrl(t, y)`
+
+        Returns:
+            PDEBase: The state of the PDE at the end of the solve
+        """
+        for iter in range(num_substeps):
+            iter = iter + start_iteration_value
+            t = iter * self.dt
+            if controller is not None:
+                y = self.flow.get_observations()
+                u = controller(t, y)
```
```diff
@@ -317,20 +360,25 @@ def step(self, iter: int, control: Iterable[float] = None, **kwargs):

     def reset(self):
         """Reset variables for the timestepper"""
-        pass
+        self.t = 0.0
```
Review comment: See above - time should be owned by the flow, and the reset method for the flow can also reset the time.
```diff

 class FlowEnv(gym.Env):

     def __init__(self, env_config: dict):
         self.flow: PDEBase = env_config.get("flow")(
             **env_config.get("flow_config", {}))

         self.solver: TransientSolver = env_config.get("solver")(
             self.flow, **env_config.get("solver_config", {}))
         self.callbacks: Iterable[CallbackBase] = env_config.get("callbacks", [])
+        self.rewardLogCallback: Iterable[CallbackBase] = env_config.get(
+            "actuation_config", {}).get("rewardLogCallback", None)
```

Review comment: Should use snake case for consistent style.
```diff
         self.max_steps: int = env_config.get("max_steps", int(1e6))
         self.iter: int = 0
-        self.q0: self.flow.StateType = self.flow.copy_state()
+        self.restart_ckpts = env_config.get("flow_config", {}).get("restart", None)
+        if self.restart_ckpts is None:
+            self.q0: self.flow.StateType = self.flow.copy_state()

         self.observation_space = gym.spaces.Box(
             low=-np.inf,
```
```diff
@@ -339,12 +387,31 @@ def __init__(self, env_config: dict):
             dtype=float,
         )
         self.action_space = gym.spaces.Box(
-            low=-self.flow.MAX_CONTROL,
-            high=self.flow.MAX_CONTROL,
+            low=self.flow.MAX_CONTROL_LOW,
+            high=self.flow.MAX_CONTROL_UP,
             shape=(self.flow.num_inputs,),
             dtype=float,
         )

+        self.t = 0.
+        self.dt = env_config.get("solver_config", {}).get("dt", None)
+        assert self.dt is not None, f"Error: Solver timestep dt ({self.dt}) must not be None"
+        self.num_sim_substeps_per_actuation = env_config.get(
+            "actuation_config", {}).get("num_sim_substeps_per_actuation", None)
+
+        if self.num_sim_substeps_per_actuation is not None and self.num_sim_substeps_per_actuation > 1:
+            assert self.rewardLogCallback is not None, \
+                f"Error: If num_sim_substeps_per_actuation ({self.num_sim_substeps_per_actuation}) " \
+                f"is set a reward callback function must be given, {self.rewardLogCallback}"
+        self.reward_aggreation_rule = env_config.get("actuation_config", {}).get(
+            "reward_aggreation_rule", None)
+        assert self.reward_aggreation_rule in [
+            'mean', 'sum', 'median'
+        ], f"Error: reward aggregation rule ({self.reward_aggreation_rule}) is not given or not implemented yet"
```

Review comment: What's the purpose of the …

Review comment: (There is a typo in the variable name `reward_aggreation_rule`.)

Review comment: Could this all happen in some kind of Wrapper? All of this seems to add a lot of complexity to an otherwise straightforward class.

Review comment: Wrappers can deal with those aggregation rules and also action scaling.

Review comment: A reward aggregation rule is implemented and mandatory now; a default could possibly be set for backward compatibility?
```diff

+    def constant_action_controller(self, t, y):
+        return self.action

     def set_callbacks(self, callbacks: Iterable[CallbackBase]):
         self.callbacks = callbacks
```
```diff
@@ -355,19 +422,43 @@ def step(
         """Advance the state of the environment. See gym.Env documentation

         Args:
-            action (Iterable[ArrayLike], optional): Control inputs. Defaults to None.
+            action (Iterable[ActType], optional): Control inputs. Defaults to None.

         Returns:
-            Tuple[ArrayLike, float, bool, dict]: obs, reward, done, info
+            Tuple[ObsType, float, bool, dict]: obs, reward, done, info
         """
-        self.solver.step(self.iter, control=action)
-        self.iter += 1
-        t = self.iter * self.solver.dt
+        action = action * self.flow.CONTROL_SCALING
+
+        if self.num_sim_substeps_per_actuation is not None and self.num_sim_substeps_per_actuation > 1:
```
Review comment: If I understand correctly …
```diff
+            self.action = action
+            self.flow = self.solver.solve_multistep(
+                num_substeps=self.num_sim_substeps_per_actuation,
+                callbacks=[self.rewardLogCallback],
+                controller=self.constant_action_controller,
+                start_iteration_value=self.iter)
+            if self.reward_aggreation_rule == "mean":
+                averaged_objective_values = np.mean(self.flow.reward_array, axis=0)
```
Review comment: This …

Review comment: An attribute of the corresponding wrapper, then.
```diff
+            elif self.reward_aggreation_rule == "sum":
+                averaged_objective_values = np.sum(self.flow.reward_array, axis=0)
+            elif self.reward_aggreation_rule == "median":
+                averaged_objective_values = np.median(self.flow.reward_array, axis=0)
+            else:
+                raise NotImplementedError(
+                    f"The {self.reward_aggreation_rule} function is not implemented yet."
+                )
+
+            self.iter += self.num_sim_substeps_per_actuation
+            self.t += self.dt * self.num_sim_substeps_per_actuation
+            reward = self.get_reward(averaged_objective_values)
+        else:
+            self.solver.step(self.iter, control=action)
+            self.iter += 1
+            self.t += self.dt
+            reward = self.get_reward()
```
Review comment: Is the …
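Since the rule is already validated in `__init__`, the branch chain (and its arguably unreachable `NotImplementedError`) could collapse into a lookup table; an illustrative sketch, keeping the PR's names including the `reward_aggreation_rule` typo:

```python
import numpy as np

# Hypothetical refactor: map rule names to NumPy reducers once.
AGGREGATORS = {"mean": np.mean, "sum": np.sum, "median": np.median}

def aggregate_rewards(reward_array, rule):
    # `rule` has been checked against AGGREGATORS' keys at construction time
    return AGGREGATORS[rule](reward_array, axis=0)
```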
```diff
         for cb in self.callbacks:
-            cb(self.iter, t, self.flow)
+            cb(self.iter, self.t, self.flow)
         obs = self.flow.get_observations()

-        reward = self.get_reward()
         done = self.check_complete()
         info = {}
```
```diff
@@ -380,8 +471,15 @@ def step(
     def stack_observations(self, obs):
         return obs

-    def get_reward(self):
-        return -self.solver.dt * self.flow.evaluate_objective()
+    def get_reward(self, averaged_objective_values=None):
+        if averaged_objective_values is None:
+            # return -self.solver.dt * self.flow.evaluate_objective()
+            return -self.flow.evaluate_objective()
+        else:
+            # return -self.solver.dt * self.num_sim_substeps_per_actuation \
+            #     * self.flow.evaluate_objective(averaged_objective_values=averaged_objective_values)
+            return -self.flow.evaluate_objective(
+                averaged_objective_values=averaged_objective_values)

     def check_complete(self):
         return self.iter > self.max_steps
```
```diff
@@ -391,6 +489,7 @@ def reset(self, t=0.0) -> Union[ArrayLike, Tuple[ArrayLike, dict]]:
         self.flow.reset(q0=self.q0, t=t)
         self.solver.reset()

+        # Previously: return self.flow.get_observations(), info

         return self.flow.get_observations()

     def render(self, mode="human", **kwargs):
```
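For reference, the commented-out return matches the newer gymnasium-style convention, where `reset()` returns an `(obs, info)` pair; a hedged sketch of what that signature could look like here (the PR itself keeps the legacy single return value):

```python
# gymnasium-style reset: returns (obs, info). Illustrative only; the PR
# keeps the legacy gym convention of returning the observation alone.
def reset(self, *, seed=None, options=None):
    super().reset(seed=seed)
    self.flow.reset(q0=self.q0, t=0.0)
    self.solver.reset()
    return self.flow.get_observations(), {}
```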
```diff
@@ -0,0 +1,6 @@
+from core_1DEnvs import OneDimEnv, PDESolverBase1D
+
+from .kuramoto_sivashinsky import Kuramoto_Sivashinsky
+from .burgers import Burgers
+
+__all__ = ["OneDimEnv", "PDESolverBase1D", "Kuramoto_Sivashinsky", "Burgers"]
```
```diff
@@ -0,0 +1 @@
+from .flow import Burgers
```
Review comment: This doesn't seem to be used for anything. Also not related to the physics modeled by the `PDEBase`.