-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_results_csv.py
100 lines (79 loc) · 2.72 KB
/
plot_results_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import matplotlib.pyplot as plt
import pandas as pd
import wandb
def plot_dqn(csv_path="RL/delivery_plot.csv", linewidth=0.2):
    """Plot the DQN score curve read from a CSV export.

    The CSV must contain a ``Step`` column and a ``delivery run - score``
    column; the latter is renamed to ``score`` before plotting.

    Args:
        csv_path: Path to the CSV file. A relative path is resolved against
            the current working directory, so the default only works when
            the script is launched from the directory containing ``RL/``.
        linewidth: Line width of the score curve (default 0.2, the value
            that used to be hard-coded).
    """
    df = pd.read_csv(csv_path)
    df = df.rename(columns={'delivery run - score': 'score'})
    # Coerce to float so pandas plots the column as numeric data.
    df['score'] = df['score'].astype(float)
    df.plot(x='Step', y='score', linewidth=linewidth)
    plt.grid(axis="y")
    plt.show()
def plot_cma_es():
    """Plot best and median fitness per generation for two CMA-ES W&B runs.

    Fetches each run's history from the ``jacopodona/highway_CMA`` project
    and draws one figure per run, titled so the figures can be told apart.
    All figures are displayed together by the final ``plt.show()``.
    """
    api = wandb.Api()
    # Same runs and plotting order as before: "cma es 3" first, then "cma es 1".
    runs = [
        ("CMA-ES 3", "/jacopodona/highway_CMA/runs/pw5opfhv"),
        ("CMA-ES 1", "/jacopodona/highway_CMA/runs/nnqio6qf"),
    ]
    for title, run_path in runs:
        df = api.run(run_path).history()
        # DataFrame.plot opens a new figure and already draws a legend for the
        # two y-series, so only the title and grid need to be added here.
        df.plot(x="Generation", y=["Best Fitness", "Median Fitness"], title=title)
        plt.grid(axis="y")
    plt.show()
def _plot_fitness_metric(curves, metric, ylabel):
    """Draw one line per ``(label, df)`` pair for ``metric`` and show it.

    Args:
        curves: Iterable of ``(label, df)`` pairs, where ``df`` is a run
            history DataFrame with a ``_step`` column and a ``metric`` column.
        metric: Name of the column plotted on the y axis.
        ylabel: Y-axis label.
    """
    # Start a fresh figure explicitly so successive metrics never share axes,
    # even when matplotlib runs in interactive (non-blocking show) mode.
    plt.figure()
    for label, df in curves:
        plt.plot(df["_step"], df[metric], label=label)
    # The W&B ``_step`` counter is used as the generation index
    # (assumes one history row per generation -- confirm against the logger).
    plt.xlabel("Generation")
    plt.ylabel(ylabel)
    plt.grid(axis="y")
    plt.legend()
    plt.show()


def plot_neat():
    """Compare the NEAT runs labelled "10 envs", "5 envs" and "3 envs".

    Shows two figures one after the other: mean fitness, then fitness
    standard deviation, each with one curve per run.
    """
    api = wandb.Api()
    curves = [
        ("10 envs", api.run("/pappol/neat-testing/runs/uv6wtuy9").history()),
        ("5 envs", api.run("/pappol/neat-testing/runs/9xaqx4a1").history()),
        ("3 envs", api.run("/pappol/neat-testing/runs/mklzsq8s").history()),
    ]
    _plot_fitness_metric(curves, "mean_fitness", "Mean Fitness")
    _plot_fitness_metric(curves, "std_fitness", "Std Fitness")


def plot_neat_500():
    """Plot mean and std fitness of the NEAT run labelled "500 gen".

    Shows two figures one after the other: mean fitness, then fitness
    standard deviation.
    """
    api = wandb.Api()
    curves = [("500 gen", api.run("/pappol/neat-solver/runs/80ebzqza").history())]
    _plot_fitness_metric(curves, "mean_fitness", "Mean Fitness")
    _plot_fitness_metric(curves, "std_fitness", "Std Fitness")
def main():
    """Read the algorithm name from the command line and run its plotter."""
    plotters = {
        "dqn": plot_dqn,
        "cma-es": plot_cma_es,
        "neat": plot_neat,
        "neat_500": plot_neat_500,
    }
    parser = argparse.ArgumentParser(description="Plotting script for DQN, CMA-ES, and NEAT")
    parser.add_argument(
        "algorithm",
        choices=list(plotters),
        help="Choose which algorithm to plot",
    )
    chosen = parser.parse_args().algorithm
    # argparse already rejects any value outside ``choices``, so the lookup
    # cannot miss.
    plotters[chosen]()


if __name__ == "__main__":
    main()