config.gin
import gin.tf.external_configurables
import crowd_pbrl.sampling_utils
import crowd_pbrl.trajectory_representation_learning

# Shared macros.
FRAMES_PER_CLIP = 60
FRAME_SHAPE = (84, 84)
STACK_SIZE = 4
N_CPU = 4
LABEL_PER_QUERY = 3
N_WORKERS = 2500
DROP_OUT = 0.
FC_DIM = 512
REPRE_DIM = 128
BATCHSIZE = 32

# Offline trajectory and query sampling.
sample_trajectories.log_path = "D:\\atari_offline_logs"
sample_trajectories.n_trajectory_per_file = 5
sample_trajectories.n_clips_per_trajectory = 6
sample_trajectories.frames_per_clip = %FRAMES_PER_CLIP

sample_queries.n_train_clips = 3
sample_queries.n_query_per_trajectory = 20

# Observation and clip encoders.
ObservationEncoder.fc_dim = %FC_DIM
ObservationEncoder.emb_dim = %REPRE_DIM

ContinuousObservationEncoder.fc_dim = 256
ContinuousObservationEncoder.emb_dim = %REPRE_DIM

ClipEncoder.frames_per_clip = %FRAMES_PER_CLIP
ClipEncoder.emb_dim = %REPRE_DIM
ClipEncoder.l1_weight = 0.01
ClipEncoder.temporal_l2_weight = 0.01

# Trajectory-representation (encoder) training.
train_encoder.n_steps = 300000
# lr is bound to the tf.train.exponential_decay configurable via a gin reference.
train_encoder.lr = @tf.train.exponential_decay
tf.train.exponential_decay.learning_rate = 1e-4
tf.train.exponential_decay.decay_steps = 50000
tf.train.exponential_decay.decay_rate = 0.9
tf.train.exponential_decay.staircase = True
train_encoder.frames_per_clip = %FRAMES_PER_CLIP
train_encoder.logging_freq = 1000

trajectory_representation_learning_input_fn.frames_per_clip = %FRAMES_PER_CLIP
trajectory_representation_learning_input_fn.frame_shape = %FRAME_SHAPE
trajectory_representation_learning_input_fn.batchsize = 256
trajectory_representation_learning_input_fn.ratio = 1

compute_pretrain_repre.frames_per_clip = %FRAMES_PER_CLIP

# Preference answer generation.
generate_bt_answers.n_task_repeat = 3
generate_bt_answers.n_query_per_task = 50
generate_trex_answers.n_task_repeat = 3

# Reward-model learning.
reward_learning_input_fn.frames_per_clip = %FRAMES_PER_CLIP
reward_learning_input_fn.moment_len = 20
reward_learning_input_fn.frame_shape = %FRAME_SHAPE
reward_learning_input_fn.batchsize = 128
reward_learning_input_fn.split = "all"

train_reward_model.frames_per_clip = %FRAMES_PER_CLIP
train_reward_model.logging_freq = 1000

# Reward inference.
compute_clip_reward.frames_per_clip = %FRAMES_PER_CLIP
compute_clip_reward.batchsize = %BATCHSIZE
compute_clip_reward.frame_shape = %FRAME_SHAPE

infer_rewards.frame_shape = %FRAME_SHAPE
infer_rewards.batchsize = %BATCHSIZE
infer_rewards.stack_size = %STACK_SIZE
infer_rewards.log_path = "D:\\atari_offline_logs"

# D4RL trajectory sampling and perturbed expert demonstrations.
sample_d4rl_trajectories.n_d4rl_trajecotry = 250
sample_d4rl_trajectories.n_clips_per_trajectory = 6
sample_d4rl_trajectories.frames_per_clip = %FRAMES_PER_CLIP

generate_perturbed_expert_demo_datasets.atari_log_path = "D:\\atari_offline_logs"
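
# Usage note (a minimal sketch, not part of the repository's scripts): this file
# is consumed with the gin-config library. A hypothetical Python entry point
# would look roughly like the commented lines below; only the bindings above
# come from this file, the wiring of the entry point is an assumption.
#
#   import gin
#   import crowd_pbrl.sampling_utils
#   import crowd_pbrl.trajectory_representation_learning
#
#   # Parse the bindings; gin-configured functions such as train_encoder()
#   # then receive the arguments defined above without explicit passing.
#   gin.parse_config_file("config.gin")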