fixed phaseshifting env observation space
1b15 committed May 29, 2024

1 parent 6109825 commit 4ebd2aa
Showing 2 changed files with 43 additions and 47 deletions.
examples/example-6.2-rl-phaseshifting.ipynb: 59 changes (16 additions & 43 deletions)

Large diffs are not rendered by default.

The hunks below belong to the second changed file (27 additions & 4 deletions; its name is not shown in this view).
@@ -58,8 +58,9 @@ def __init__(
 
         self.observation_space = spaces.Dict(
             {
-                "exc": spaces.Box(0, 1, shape=(1,), dtype=float),
-                "inh": spaces.Box(0, 1, shape=(1,), dtype=float),
+                "exc": spaces.Box(0, 1, shape=(self.period_n,), dtype=float),
+                "inh": spaces.Box(0, 1, shape=(self.period_n,), dtype=float),
+                "target_phase": spaces.Box(0, 2 * np.pi, shape=(1,), dtype=float),
             }
         )

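The observation space now exposes a sliding window of period_n samples for each population instead of a single scalar, plus the current target phase. A minimal sketch of the resulting space, assuming the gymnasium spaces API and a made-up window length:

    import numpy as np
    from gymnasium import spaces

    period_n = 120  # hypothetical window length: one oscillation period in samples

    obs_space = spaces.Dict(
        {
            "exc": spaces.Box(0, 1, shape=(period_n,), dtype=float),
            "inh": spaces.Box(0, 1, shape=(period_n,), dtype=float),
            "target_phase": spaces.Box(0, 2 * np.pi, shape=(1,), dtype=float),
        }
    )

    sample = obs_space.sample()
    assert sample["exc"].shape == (period_n,)
    assert obs_space.contains(sample)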
@@ -80,17 +81,26 @@ def get_target(self):
         p_list = []
         for i in range(3, len(peaks)):
             p_list.append(peaks[i] - peaks[i - 1])
+
+        self.period_n = np.ceil(np.mean(p_list)).astype(int)
+
         period = np.mean(p_list) * self.dt
         self.period = period
+
         raw = np.stack((wc.exc, wc.inh), axis=1)[0]
         index = np.round(self.target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
         target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
         self.target_time = wc.t[index : index + target.shape[1]]
+        self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi
 
         return target
 
     def _get_obs(self):
-        return {"exc": self.model.exc[0], "inh": self.model.inh[0]}
+        return {
+            "exc": self.exc_history,
+            "inh": self.inh_history,
+            "target_phase": np.array([self.target_phase[self.t_i]]),
+        }
 
     def _get_info(self):
         return {"t": self.t_i * self.dt}
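get_target() now also stores the oscillation period in samples (period_n) and maps the target trace onto a phase in [0, 2*pi), which _get_obs() then serves one value at a time. A self-contained sketch of that bookkeeping on a toy signal; dt, the sine wave, and its period are invented, and find_peaks is assumed to come from scipy.signal as in the surrounding example:

    import numpy as np
    from scipy.signal import find_peaks

    dt = 0.1                                         # hypothetical integration step
    t = np.arange(0, 500, dt)
    exc = 0.5 + 0.4 * np.sin(2 * np.pi * t / 23.0)   # toy oscillation, period 23

    peaks, _ = find_peaks(exc)                       # indices of local maxima
    p_list = [peaks[i] - peaks[i - 1] for i in range(3, len(peaks))]

    period_n = np.ceil(np.mean(p_list)).astype(int)  # period length in samples
    period = np.mean(p_list) * dt                    # period length in time units

    phase = (t % period) / period * 2 * np.pi        # phase in [0, 2*pi) per cycle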
@@ -101,16 +111,25 @@ def reset(self, seed=None, options=None):
         self.model.clearModelState()
 
         self.model.params = self.params.copy()
 
+        # init history window
+        self.model.params["duration"] = self.period_n * self.dt
+        self.model.exc = np.array([[self.x_init]])
+        self.model.inh = np.array([[self.y_init]])
+        self.model.run()
+        self.exc_history = self.model.exc[0]
+        self.inh_history = self.model.inh[0]
+
+        # reset duration parameter
+        self.model.params = self.params.copy()
 
         observation = self._get_obs()
         info = self._get_info()
         return observation, info
 
     def _loss(self, obs, action):
         control_loss = np.sqrt(
-            (self.target[0, self.t_i] - obs["exc"].item()) ** 2 + (self.target[1, self.t_i] - obs["inh"].item()) ** 2
+            (self.target[0, self.t_i] - obs["exc"][-1]) ** 2 + (self.target[1, self.t_i] - obs["inh"][-1]) ** 2
         )
         control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
         return control_loss + control_strength_loss
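Because the observation entries are now whole windows rather than scalars, reset() warms up the model for one period to fill them, and the loss reads the newest sample with [-1] where it previously used .item(). A toy recomputation of the updated loss, with every number invented:

    import numpy as np

    t_i = 1
    target = np.array([[0.6, 0.7],          # exc target per step (toy values)
                       [0.2, 0.3]])         # inh target per step
    obs = {"exc": np.array([0.50, 0.65]),   # observation windows; [-1] is the
           "inh": np.array([0.25, 0.28])}   # newest simulated sample
    action = np.array([0.1, -0.05])
    control_strength_loss_scale = 0.01

    control_loss = np.sqrt(
        (target[0, t_i] - obs["exc"][-1]) ** 2 + (target[1, t_i] - obs["inh"][-1]) ** 2
    )
    loss = control_loss + np.abs(action).sum() * control_strength_loss_scale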
@@ -122,6 +141,10 @@ def step(self, action):
         self.model.params["inh_ext"] = np.array([inh])
         self.model.run(continue_run=True)
 
+        # shift observation window
+        self.exc_history = np.concatenate((self.exc_history[-self.period_n + 1 :], self.model.exc[0]))
+        self.inh_history = np.concatenate((self.inh_history[-self.period_n + 1 :], self.model.inh[0]))
+
         observation = self._get_obs()
 
         reward = -self._loss(observation, action)

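Each step drops the oldest sample and appends the newly simulated chunk, so the window stays at period_n samples as long as the model produces one sample per step, which the fixed observation-space shape assumes. A toy demonstration of the same slicing:

    import numpy as np

    period_n = 5
    history = np.arange(period_n, dtype=float)   # window: [0. 1. 2. 3. 4.]
    new_chunk = np.array([5.0])                  # one new sample per step

    history = np.concatenate((history[-period_n + 1:], new_chunk))
    print(history)                               # [1. 2. 3. 4. 5.]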