Bug fixes + minor clean up of DeepTDLearning and DeepQLearning
Summary:
A minor cleanup, as suggested by Yi.

1) We take `target_update_freq` as an input parameter of the DeepQLearning class but never pass it on to DeepTDLearning, so the base class always fell back to its own default. This change forwards it.

2) We never used any keyword arguments for DeepTDLearning, so `**kwargs` is removed from its input arguments. Since DeepTDLearning and its parent class PolicyLearner are base classes, I am not sure whether `**kwargs` will be needed here in the future.

Also updated the docstring of the DeepQLearning class.

Reviewed By: xuruiyang, yiwan-rl

Differential Revision: D56493906

fbshipit-source-id: 64655e6a7605e1515aafe296406557e90cc55439
jb3618 authored and facebook-github-bot committed Apr 27, 2024
1 parent 8a36e3a commit 25b01ac
Showing 2 changed files with 11 additions and 4 deletions.
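
For context on fix 1: when a subclass accepts a parameter but never forwards it to its base class, the base class silently falls back on its own default, no matter what the caller passed. A minimal sketch of the bug pattern, using simplified stand-in classes rather than Pearl's actual code:

class Base:
    def __init__(self, target_update_freq: int = 10) -> None:
        # The target network update frequency actually used during training.
        self.target_update_freq = target_update_freq

class Agent(Base):
    def __init__(self, target_update_freq: int = 10) -> None:
        # Bug pattern: the parameter is accepted but not forwarded,
        # so Base always sees its default of 10.
        super().__init__()
        # Fix: super().__init__(target_update_freq=target_update_freq)

agent = Agent(target_update_freq=50)
print(agent.target_update_freq)  # 10 with the bug, 50 once forwarded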
First changed file (DeepQLearning):
@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
 from pearl.action_representation_modules.action_representation_module import (
@@ -94,8 +94,13 @@ class and uses `act` and `learn_batch` methods of that class. We only implement
             Note: This is an alternative to specifying a `network_type`. If provided, the
             specified `network_type` is ignored and the input `network_instance` is used for
             learning. Allows for custom implementations of Q-value networks.
-        """
+            **kwargs: Additional arguments to be passed when using `TwoTowerNetwork`
+                class as the QValueNetwork. This includes {state_output_dim (int),
+                action_output_dim (int), state_hidden_dims (List[int]),
+                action_hidden_dims (List[int])}, all of which are used to instantiate a
+                `TwoTowerNetwork` object.
+        """
         super(DeepQLearning, self).__init__(
             exploration_module=(
                 exploration_module
@@ -108,11 +113,14 @@ class and uses `act` and `learn_batch` methods of that class. We only implement
             hidden_dims=hidden_dims,
             learning_rate=learning_rate,
             soft_update_tau=soft_update_tau,
+            is_conservative=is_conservative,
+            conservative_alpha=conservative_alpha,
             network_type=network_type,
             action_representation_module=action_representation_module,
             discount_factor=discount_factor,
             training_rounds=training_rounds,
             batch_size=batch_size,
+            target_update_freq=target_update_freq,
             network_instance=network_instance,
             **kwargs,
         )
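The docstring added above documents the `**kwargs` pass-through for two-tower Q-networks. A hypothetical usage sketch follows; the parameter names `state_dim` and `action_space` are assumptions here, while the remaining arguments come from the diff and docstring above:

policy_learner = DeepQLearning(
    state_dim=observation_dim,   # assumed constructor parameter
    action_space=action_space,   # assumed constructor parameter
    hidden_dims=[64, 64],
    network_type=TwoTowerNetwork,
    # Forwarded via **kwargs to instantiate the TwoTowerNetwork:
    state_output_dim=32,
    action_output_dim=32,
    state_hidden_dims=[64],
    action_hidden_dims=[64],
)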
Second changed file (DeepTDLearning):
@@ -63,15 +63,14 @@ def __init__(
         target_update_freq: int = 10,
         soft_update_tau: float = 0.1,
         is_conservative: bool = False,
-        conservative_alpha: float = 2.0,
+        conservative_alpha: Optional[float] = 2.0,
         network_type: Type[QValueNetwork] = VanillaQValueNetwork,
         state_output_dim: Optional[int] = None,
         action_output_dim: Optional[int] = None,
         state_hidden_dims: Optional[List[int]] = None,
         action_hidden_dims: Optional[List[int]] = None,
         network_instance: Optional[QValueNetwork] = None,
         action_representation_module: Optional[ActionRepresentationModule] = None,
-        **kwargs: Any,
     ) -> None:
         """Constructs a DeepTDLearning based policy learner. DeepTDLearning is the base class
         for all value based (i.e. temporal difference learning based) algorithms.
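One subtlety behind fix 2: DeepQLearning can still write `**kwargs` in its `super().__init__(...)` call even though DeepTDLearning no longer declares `**kwargs`, because dictionary unpacking binds against the parent's explicitly named parameters (state_output_dim, action_output_dim, and so on). A minimal sketch of that behavior, with simplified classes rather than the actual Pearl code:

from typing import Any, Optional

class Parent:
    # No **kwargs: only explicitly named parameters are accepted.
    def __init__(self, state_output_dim: Optional[int] = None) -> None:
        self.state_output_dim = state_output_dim

class Child(Parent):
    def __init__(self, **kwargs: Any) -> None:
        # Fine as long as every key in kwargs matches a Parent parameter
        # (or kwargs is empty); an unknown key raises TypeError.
        super().__init__(**kwargs)

Child(state_output_dim=32)  # ok: binds to Parent's named parameter
Child()                     # ok: empty kwargs unpack to nothing
# Child(foo=1)              # TypeError: unexpected keyword argument 'foo'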
