Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
2 parents 5064558 + 3741f9e commit 63a597b
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 48 deletions.
6 changes: 1 addition & 5 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
evaluation_interval = cfg.logger.log_interval
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
collected_frames = 0

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
device:
env_per_collector: 1
env_per_collector: 8
reset_at_each_iter: False

# replay buffer
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
frames_per_batch: 1000
reset_at_each_iter: False
device:
env_per_collector: 1
env_per_collector: 8
num_workers: 1

# replay buffer
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,7 @@ def update(sampled_tensordict, update_actor, prb=prb):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
delayed_updates = cfg.optim.policy_update_delay
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def make_environment(cfg, logger, device):
)
eval_env = TransformedEnv(
ParallelEnv(
cfg.collector.env_per_collector,
1,
EnvCreator(partial),
serial_for_single=True,
device=device,
Expand Down

0 comments on commit 63a597b

Please sign in to comment.