-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrandomized_metric_helper.py
122 lines (96 loc) · 4.84 KB
/
randomized_metric_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import wandb
import os
import pickle
import numpy as np
from argparse import ArgumentParser
from pathlib import Path
from stable_baselines3 import PPO
from tools import initialize_env, sample_trajectory, sample_random_trajectory
"""
This file is a helper for evaluation.
It samples random lines and random trajectories to later use for computing # transitions and # regions over these lines and trajectories.
"""
def sample_random_lines(runs, num_samples):
    """Sample random evaluation lines through states visited by the best policies.

    For each sample, one run is drawn uniformly at random, its best checkpoint
    is loaded, a deterministic trajectory is rolled out, and one state from
    that trajectory is kept. Two families of line segments are then built:
      - "origin": from the zero vector to each sampled state.
      - "mean":   from the mean of all sampled states to each sampled state.

    Args:
        runs: sequence of wandb run objects exposing ``.config["env"]``,
            ``.config["save_dir"]`` and ``.id``.
        num_samples: number of states to sample (= lines per family).

    Returns:
        dict with keys "origin" and "mean", each a list of
        (start_point, end_point) tuples of np.ndarray.
    """
    sampled_points = []
    for _ in range(num_samples):
        # np.random.randint's `high` bound is EXCLUSIVE, so use len(runs);
        # the original `len(runs) - 1` could never select the last run.
        idx = np.random.randint(low=0, high=len(runs))
        run = runs[idx]
        env_name = run.config["env"]
        checkpoints_dir = run.config["save_dir"]
        run_folder_path = Path(checkpoints_dir) / f"car1d_{run.id}"
        # Paths to the best checkpoint and its normalization statistics.
        best_policy_path = run_folder_path / 'best.zip'
        best_stats_path = run_folder_path / 'best_stats.pth'
        # Load the policy and an environment configured for input standardization.
        best_expert = PPO.load(best_policy_path)
        best_env = initialize_env(env_name, best_stats_path)
        # Roll out one deterministic trajectory from the final (best) policy.
        best_traj, _ = sample_trajectory(best_env, best_expert, is_det=True)
        best_states = best_traj['states']
        # Uniformly sample one state from the trajectory (again, `high` is
        # exclusive — the original off-by-one skipped the final state).
        point_idx = np.random.randint(low=0, high=len(best_states))
        sampled_points.append(best_states[point_idx])
    input_lines_dict = {"origin": [], "mean": []}
    mean_sample = np.mean(sampled_points, axis=0)
    for point in sampled_points:
        input_lines_dict["origin"].append((np.zeros_like(point), point))
        input_lines_dict["mean"].append((mean_sample, point))
    return input_lines_dict
def load_random_trajectories(sweep_dir):
    """Load all pickled random trajectories stored under ``sweep_dir/random_trajectories``.

    Args:
        sweep_dir: base directory of the sweep (str or Path).

    Returns:
        list of trajectory objects in index order
        (``random_traj_0.traj``, ``random_traj_1.traj``, ...).

    Raises:
        ValueError: if the ``random_trajectories`` directory does not exist.
    """
    # Accept both str and Path inputs for the base directory.
    random_trajectories_dir = Path(sweep_dir) / "random_trajectories"
    if not random_trajectories_dir.exists():
        raise ValueError(f"Directory {random_trajectories_dir} is Empty. Must sample and save random trajectories first.")
    num_random_trajectories = len(os.listdir(random_trajectories_dir))
    random_trajectories = []
    for i in range(num_random_trajectories):
        path = random_trajectories_dir / f"random_traj_{i}.traj"
        if path.exists():
            # Context manager closes the handle promptly; the original
            # `pickle.load(open(...))` leaked the file descriptor.
            with open(path, "rb") as f:
                random_trajectories.append(pickle.load(f))
    return random_trajectories
def load_random_lines(sweep_dir):
    """Load the pickled random-line families stored under ``sweep_dir/random_lines``.

    Args:
        sweep_dir: base directory of the sweep (str or Path).

    Returns:
        dict mapping each available family name ("origin", "mean") to its
        unpickled list of lines; families with no file on disk are omitted.

    Raises:
        ValueError: if the ``random_lines`` directory does not exist.
    """
    # Accept both str and Path inputs for the base directory.
    random_lines_dir = Path(sweep_dir) / "random_lines"
    if not random_lines_dir.exists():
        raise ValueError(f"Directory {random_lines_dir} is Empty. Must sample and save random lines first.")
    random_lines_dict = {}
    for name in ["origin", "mean"]:
        path = random_lines_dir / f"{name}.lines"
        if path.exists():
            # Context manager closes the handle promptly; the original
            # `pickle.load(open(...))` leaked the file descriptor.
            with open(path, "rb") as f:
                random_lines_dict[name] = pickle.load(f)
    return random_lines_dict
def main(args):
    """Sample and persist random lines and random trajectories for evaluation.

    Resolves the set of runs either from a single wandb run (``--run_id``) or
    from a sweep (``--sweep_id``), samples random lines through best-policy
    states, rolls out fully random trajectories, and pickles everything under
    ``args.save_dir/<id>``.

    Args:
        args: parsed argparse namespace with attributes entity, project_name,
            run_id, sweep_id, save_dir, num_random_lines,
            num_random_trajectories.

    Raises:
        ValueError: if neither run_id nor sweep_id (or both) is provided.
    """
    api = wandb.Api()
    if args.run_id is not None and args.sweep_id is None:  # single-run evaluation
        runs = api.runs(path=f"{args.entity}/{args.project_name}", filters={"config.run_name": args.run_id})
        base_path = Path(args.save_dir) / args.run_id
    elif args.sweep_id is not None and args.run_id is None:  # sweep evaluation
        sweep = api.sweep(f"{args.entity}/{args.project_name}/sweeps/{args.sweep_id}")
        runs = sweep.runs
        base_path = Path(args.save_dir) / args.sweep_id
    else:
        raise ValueError("ID of training run/runs required.")
    # Deterministic ordering so repeated invocations process runs identically.
    runs = sorted(runs, key=lambda x: x.id)
    env_name = runs[0].config["env"]

    # Sample random lines and save one pickle per line family.
    input_lines_dict = sample_random_lines(runs=runs, num_samples=args.num_random_lines)
    lines_dir = base_path / "random_lines"
    lines_dir.mkdir(parents=True, exist_ok=True)
    for name in ["origin", "mean"]:
        # `with` ensures the pickle is flushed and the handle closed; the
        # original `pickle.dump(..., open(...))` leaked write handles.
        with open(lines_dir / f"{name}.lines", "wb") as f:
            pickle.dump(input_lines_dict[name], f)

    # Sample fully random trajectories and save one pickle each.
    traj_dir = base_path / "random_trajectories"
    traj_dir.mkdir(parents=True, exist_ok=True)
    for i in range(args.num_random_trajectories):
        random_traj, _ = sample_random_trajectory(env_name)
        with open(traj_dir / f"random_traj_{i}.traj", "wb") as f:
            pickle.dump(random_traj, f)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--entity", help="Weights & Biases entity", type=str)
parser.add_argument("--project_name", help="Weights & Biases project name", required=True, type=str)
parser.add_argument("--sweep_id", type=str, required=None)
parser.add_argument("--run_id", type=str, default=None)
parser.add_argument("--save_dir", type=str, default="randomized_data")
parser.add_argument("--num_random_lines", type=int, default=100)
parser.add_argument("--num_random_trajectories", type=int, default=10)
args = parser.parse_args()
main(args)