-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_csv.py
74 lines (57 loc) · 2.12 KB
/
export_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import pandas as pd
import numpy as np
import wandb
import seaborn as sns
import matplotlib.pyplot as plt
def smooth(data, sm=1):
    """Exponentially smooth a sequence of numbers.

    The first output is ``data[0]``; every later output is
    ``sm * data[i] + (1 - sm) * previous_output``.  With the default
    ``sm=1`` the previous output is given zero weight.

    Args:
        data: Indexable, sliceable sequence of numbers (for example a list
            or a 1-D array).
        sm: Weight applied to the current sample; the previous smoothed
            value is weighted by ``1 - sm``.

    Returns:
        A list with one smoothed value per input element.  Empty input
        yields an empty list.
    """
    # Use len() rather than truthiness so array-like inputs are accepted too.
    if len(data) == 0:
        return []

    # Seed the running average with the first sample.
    moving_averages = [data[0]]
    for value in data[1:]:
        # Blend the current sample with the most recent smoothed value.
        moving_averages.append(sm * value + (1 - sm) * moving_averages[-1])
    return moving_averages
# Export evaluation reward curves from Weights & Biases: for every entry in
# main_envs, collect the history of each matching run and write one CSV file.
api = wandb.Api()
entity, project = "zhiyuanli", "mujoco"
main_envs = ['HumanoidStandup-v2 17x1', 'Humanoid-v2 17x1']

# Maps the first "_"-separated token of a run name to the label stored in the
# "algo" column.
algo_by_prefix = {
    'temporal': 'BPTA',
    'ar': 'ARMAPPO',
    'mappo': 'MAPPO',
    'happo': 'HAPPO',
}

for env in main_envs:
    # Each entry has the form "<group> <agent_conf>".
    env_parts = env.split(' ')
    group = env_parts[0]
    agent_conf = env_parts[1]
    runs = api.runs(
        f"{entity}/{project}",
        filters={"$and": [
            {"tags": "exp"},
            {"state": "finished"},
            {"group": group},
            {"config.agent_conf": agent_conf},
            {"config.agent_obsk": 0},
        ]},
    )
    summary_list = []
    for run in runs:
        # Keep only the history rows in which every column has a value.
        history = run.history().dropna()

        prefix = run.name.split('_')[0]
        if prefix not in algo_by_prefix:
            raise NotImplementedError(
                f"unrecognised run name prefix {prefix!r} in run {run.name!r}"
            )
        algo = algo_by_prefix[prefix]

        # Select each column as a 1-D array.  Squeezing a one-row history
        # would produce a 0-d array, and a DataFrame built only from scalar
        # values raises ValueError.
        runs_df = pd.DataFrame({
            "eval_average_episode_rewards": history["eval_average_episode_rewards"].to_numpy(),
            "algo": algo,
            "step": history["_step"].to_numpy(),
            "scenario": env,
        })
        summary_list.append(runs_df)

    # One file per environment, containing the rows of all its runs.
    runs_df = pd.concat(summary_list)
    runs_df.to_csv(f'{env}.csv')