Skip to content

Commit

Permalink
Merge branch 'discrete_CQL' of https://github.com/BY571/rl into discrete_CQL
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Nov 6, 2023
2 parents f9427ef + e8847c4 commit dcba00c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 10 deletions.
3 changes: 1 addition & 2 deletions examples/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer = make_cql_optimizer(cfg, loss_module)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

Expand All @@ -92,7 +91,7 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

sampling_start = time.time()
start_time = sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start

Expand Down
6 changes: 1 addition & 5 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5183,12 +5183,8 @@ def _create_mock_actor(
action_spec = OneHotDiscreteTensorSpec(action_dim)
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
# elif action_spec_type == "nd_bounded":
# action_spec = BoundedTensorSpec(
# -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
# )
else:
raise ValueError(f"Wrong {action_spec_type}")
raise ValueError(f"Wrong action spec type: {action_spec_type}")

module = nn.Linear(obs_dim, action_dim)
if is_nn_module:
Expand Down
6 changes: 3 additions & 3 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,9 +964,9 @@ def __init__(
raise ValueError(self.ACTION_SPEC_ERROR)
if action_space is None:
warnings.warn(
"action_space was not specified. DiscreteCQLLoss will default to 'one-hot'."
"This behaviour will be deprecated soon and a space will have to be passed."
"Check the DiscreteCQLLoss documentation to see how to pass the action space. "
"action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. "
"This behaviour will be deprecated soon and a space will have to be passed. "
"Check the DiscreteCQLLoss documentation to see how to pass the action space."
)
action_space = "one-hot"
self.action_space = _find_action_space(action_space)
Expand Down

0 comments on commit dcba00c

Please sign in to comment.