Skip to content

Commit

Permalink
log success
Browse files Browse the repository at this point in the history
  • Loading branch information
taochenshh committed Mar 16, 2021
1 parent cc0b9fd commit 27adbaa
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 48 deletions.
4 changes: 3 additions & 1 deletion easyrl/engine/basic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def get_train_log(self, optim_infos, traj=None):
k_stats = get_list_stats([inf[key] for inf in optim_infos if key in inf])
for sk, sv in k_stats.items():
log_info[f'{key}/' + sk] = sv

if traj is not None:
actions_stats = get_list_stats(traj.actions)
for sk, sv in actions_stats.items():
Expand All @@ -65,6 +64,9 @@ def get_train_log(self, optim_infos, traj=None):
for sk, sv in ep_returns_stats.items():
log_info['episode_return/' + sk] = sv

if len(self.runner.train_success) > 0:
log_info['episode_success'] = np.mean(self.runner.train_success)

train_log_info = dict()
for key, val in log_info.items():
train_log_info['train/' + key] = val
Expand Down
21 changes: 0 additions & 21 deletions easyrl/engine/ppo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,24 +165,3 @@ def cal_advantages(self, traj):
dones=traj.dones)
return adv

def get_train_log(self, optim_infos, traj):
    """Assemble the ``train/``-prefixed logging dict for one training iteration.

    Args:
        optim_infos: Sequence of dicts of optimisation metrics. The dicts
            may have different keys; each metric is averaged with
            ``np.mean`` over the dicts that contain it. May be empty.
        traj: Rollout data. ``traj.actions``, ``traj.total_steps`` and
            ``traj.episode_returns`` are read.

    Returns:
        dict: ``'train/<name>'`` -> value, holding the averaged optimisation
        metrics, ``rollout_action/*`` statistics, ``optim_time``,
        ``rollout_steps_per_iter`` and ``episode_return/*`` statistics.

    Side effects:
        Every return in ``traj.episode_returns`` is appended to
        ``self.train_ep_return``.
    """
    log_info = dict()
    # Collect keys from *all* info dicts (first-seen order) rather than only
    # optim_infos[0]: this tolerates an empty list and keeps metrics that
    # first show up in a later dict. The ``if key in inf`` filter below
    # already expects dicts with differing keys.
    metric_keys = dict.fromkeys(key for inf in optim_infos for key in inf)
    for key in metric_keys:
        log_info[key] = np.mean([inf[key] for inf in optim_infos if key in inf])
    # Timestamp is taken right after the metric averaging so that the
    # rollout statistics below are not counted in optim_time.
    t1 = time.perf_counter()
    actions_stats = get_list_stats(traj.actions)
    for sk, sv in actions_stats.items():
        log_info['rollout_action/' + sk] = sv
    log_info['optim_time'] = t1 - self.optim_stime
    log_info['rollout_steps_per_iter'] = traj.total_steps
    # traj.episode_returns is an iterable of iterables; flatten it lazily.
    for epr in chain.from_iterable(traj.episode_returns):
        self.train_ep_return.append(epr)
    ep_returns_stats = get_list_stats(self.train_ep_return)
    for sk, sv in ep_returns_stats.items():
        log_info['episode_return/' + sk] = sv

    return {'train/' + key: val for key, val in log_info.items()}
26 changes: 0 additions & 26 deletions easyrl/engine/sac_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,32 +129,6 @@ def train_once(self):
optim_infos.append(optim_info)
return self.get_train_log(optim_infos)

def get_train_log(self, optim_infos):
    """Assemble the ``train/``-prefixed logging dict from optimisation infos.

    Keys containing the substring ``'vec_'`` are treated as vector metrics
    and summarised with ``get_list_stats`` (one ``'<key>/<stat>'`` entry per
    returned statistic). All other keys are treated as scalar metrics and
    averaged with ``np.mean`` over the dicts that contain them.

    Args:
        optim_infos: Sequence of dicts of optimisation metrics. The dicts
            may have different keys. May be empty.

    Returns:
        dict: ``'train/<name>'`` -> value. Scalar metrics come first, then
        vector-metric statistics, then ``optim_time`` (seconds elapsed since
        ``self.optim_stime`` according to ``time.perf_counter``).
    """
    # Dicts are used as insertion-ordered sets: iterating a ``set`` of
    # strings depends on hash randomisation, which made the order of the
    # logged keys differ from run to run.
    vector_keys = dict()
    scalar_keys = dict()
    for oinf in optim_infos:
        for key in oinf:
            if 'vec_' in key:
                vector_keys[key] = None
            else:
                scalar_keys[key] = None

    log_info = dict()
    for key in scalar_keys:
        log_info[key] = np.mean([inf[key] for inf in optim_infos if key in inf])

    for key in vector_keys:
        k_stats = get_list_stats([inf[key] for inf in optim_infos if key in inf])
        for sk, sv in k_stats.items():
            log_info[f'{key}/' + sk] = sv

    t1 = time.perf_counter()
    log_info['optim_time'] = t1 - self.optim_stime
    return {'train/' + key: val for key, val in log_info.items()}

def add_traj_to_memory(self, traj):
obs = traj.obs
actions = traj.actions
Expand Down

0 comments on commit 27adbaa

Please sign in to comment.