Skip to content

Commit

Permalink
fixed phase shifting env random target shift
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed Jun 13, 2024
1 parent 1e1520b commit ea4400d
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(

self.n_steps = round(self.duration / self.dt)

self.target = self.get_target()
self.init_target()

self.observation_space = spaces.Dict(
{
Expand All @@ -75,7 +75,7 @@ def __init__(
)
)

def get_target(self):
def init_target(self):
wc = WCModel()
wc.params = self.model.params.copy()
wc.params["duration"] = self.duration + 100.0
Expand All @@ -90,15 +90,17 @@ def get_target(self):

period = np.mean(p_list) * self.dt
self.period = period
self.raw_target = np.stack((wc.exc, wc.inh), axis=1)[0]
self.target_t = wc.t

raw = np.stack((wc.exc, wc.inh), axis=1)[0]
def get_target(self):
if self.random_target_shift:
target_shift = np.random.random() * 2 * np.pi
else:
target_shift = self.target_shift
index = np.round(target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
self.target_time = wc.t[index : index + target.shape[1]]
index = np.round(target_shift * self.period / (2.0 * np.pi) / self.dt).astype(int)
target = self.raw_target[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
self.target_time = self.target_t[index : index + target.shape[1]]
self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi

return target
Expand All @@ -115,6 +117,7 @@ def _get_info(self):

def reset(self, seed=None, options=None):
super().reset(seed=seed, options=options)
self.target = self.get_target()
self.t_i = 0
self.model.clearModelState()

Expand Down

0 comments on commit ea4400d

Please sign in to comment.