This project uses double deep Q-learning to train an agent in the CartPole-v1 environment of OpenAI's Gym. It takes inspiration from this paper by Google DeepMind.
The model consists of a target network that provides the TD target for training the main network. The two networks are synced at regular intervals to stabilize training. Furthermore, a replay buffer stores the agent's experience and breaks the correlation between consecutive samples when training the main network.
Given:
- $Q_{1}$ : Main network
- $Q_{2}$ : Target network
Integrating double Q-learning helps reduce overestimation bias by decoupling action selection from action evaluation: the main network $Q_1$ selects the greedy action for the next state, while the target network $Q_2$ evaluates it.
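Written out with reward $r$, discount factor $\gamma$ and next state $s'$ (notation assumed here, following the standard formulation), the two targets are:

$$Y_{target} = r + \gamma \max_{a} Q_{2}(s', a)$$

$$Y^{DDQN}_{target} = r + \gamma \, Q_{2}\big(s', \arg\max_{a} Q_{1}(s', a)\big)$$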
`Agent()` class:
- `q_net` : the main network, i.e. the online network ($Q_1$)
- `target_net` : the target network ($Q_2$)
- `sync_net()` : function to sync the weights of $Q_1$ and $Q_2$
- `target_vals()` & `q_vals()` : functions for the forward pass of the target and main network
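A minimal sketch of how these pieces could fit together; the two-layer MLP and its sizes are assumptions for illustration, not necessarily the repo's exact architecture:

```python
import copy
import torch
import torch.nn as nn

class Agent:
    def __init__(self, obs_dim, n_actions, device="cpu"):
        # Q_1: main (online) network, updated by gradient descent
        self.q_net = nn.Sequential(
            nn.Linear(obs_dim, 128), nn.ReLU(),
            nn.Linear(128, n_actions),
        ).to(device)
        # Q_2: target network, updated only by syncing from Q_1
        self.target_net = copy.deepcopy(self.q_net)

    def sync_net(self):
        # Copy Q_1's weights into Q_2 at regular intervals
        self.target_net.load_state_dict(self.q_net.state_dict())

    def q_vals(self, obs):
        # Forward pass of the main network
        return self.q_net(obs)

    @torch.no_grad()
    def target_vals(self, obs):
        # Forward pass of the target network (no gradients needed)
        return self.target_net(obs)
```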
`ReplayBuffer()` class:
- `buffer_size`, `n_env`, `device` : size of the replay buffer from which the mini-batch is sampled, number of parallel environments for sampling experience, and the device available for training
- `add()` : adds observation, next observation, action, reward and dones to the buffer; also implements FIFO behaviour in case the buffer overflows
- `sample()` : randomly samples a mini-batch from the replay buffer
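A common way to implement such a buffer, sketched here for a single environment with NumPy storage (the `obs_dim` argument and the internal layout are assumptions; the repo's `n_env` handling is omitted for brevity):

```python
import numpy as np
import torch

class ReplayBuffer:
    def __init__(self, buffer_size, obs_dim, device="cpu"):
        self.size, self.device = buffer_size, device
        self.pos, self.full = 0, False
        self.obs      = np.zeros((buffer_size, obs_dim), dtype=np.float32)
        self.next_obs = np.zeros((buffer_size, obs_dim), dtype=np.float32)
        self.actions  = np.zeros(buffer_size, dtype=np.int64)
        self.rewards  = np.zeros(buffer_size, dtype=np.float32)
        self.dones    = np.zeros(buffer_size, dtype=np.float32)

    def add(self, obs, next_obs, action, reward, done):
        # Write at the current position; once full, this overwrites
        # the oldest entries first (FIFO)
        i = self.pos
        self.obs[i], self.next_obs[i] = obs, next_obs
        self.actions[i], self.rewards[i], self.dones[i] = action, reward, done
        self.pos = (self.pos + 1) % self.size
        self.full = self.full or self.pos == 0

    def sample(self, batch_size):
        # Uniformly sample indices from the filled part of the buffer
        high = self.size if self.full else self.pos
        idx = np.random.randint(0, high, size=batch_size)
        to_t = lambda x: torch.as_tensor(x[idx], device=self.device)
        return (to_t(self.obs), to_t(self.next_obs), to_t(self.actions),
                to_t(self.rewards), to_t(self.dones))
```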
Other helper functions and variables:
- `ret_e()` : anneals epsilon to adjust the exploration-exploitation trade-off
- `calc_TD_target()` : calculates the TD target ($Y_{target}$)
- `calc_TD_target_DDQN()` : calculates the TD target according to the double Q-learning rule ($Y^{DDQN}_{target}$)
- `e_greedy_policy()` : samples actions according to an ε-greedy policy
- `hyper` : an EasyDict to store hyperparameters
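A minimal sketch of what the two target computations could look like for a batch sampled from the buffer (the signatures, the `gamma` default and the done-masking are assumptions for illustration, not necessarily the repo's exact code):

```python
import torch

@torch.no_grad()
def calc_TD_target(agent, rewards, next_obs, dones, gamma=0.99):
    # Vanilla DQN target: the target network both selects and evaluates
    # the next action via its own max, which is what causes overestimation.
    max_next_q = agent.target_vals(next_obs).max(dim=1).values
    return rewards + gamma * (1.0 - dones) * max_next_q

@torch.no_grad()
def calc_TD_target_DDQN(agent, rewards, next_obs, dones, gamma=0.99):
    # Double DQN target: Q_1 (main net) picks the greedy action,
    # Q_2 (target net) evaluates it.
    best_actions = agent.q_vals(next_obs).argmax(dim=1, keepdim=True)
    next_q = agent.target_vals(next_obs).gather(1, best_actions).squeeze(1)
    return rewards + gamma * (1.0 - dones) * next_q
```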