Commit f5bd8f2
Ray dependency and seq_length improvements (#253)
* Ray is now an optional dependency.
* seq_length cleanup; added a warning when the deprecated seq_len config name is used for RNN networks.
ViktorM authored Sep 26, 2023
1 parent 66ce12f commit f5bd8f2
Showing 13 changed files with 85 additions and 52 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -305,6 +305,9 @@ Additional environment supported properties and functions
* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.
* Removed the Ray dependency for use cases where it is not required.
* Added a warning for using the deprecated 'seq_len' instead of 'seq_length' in configs with RNN networks.


1.6.0

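For orientation, here is a hedged sketch of where the renamed key sits in a training config. Real rl_games configs are YAML; the Python rendering and the exact nesting shown here are assumptions based on the code changes further down.

```python
# Hypothetical config fragment (Python rendering of a YAML config), values illustrative:
# the 'rnn' block belongs to the network definition, 'seq_length' to the training config.
params = {
    'network': {
        'rnn': {'name': 'lstm', 'units': 256, 'layers': 1},
    },
    'config': {
        'seq_length': 16,    # preferred key
        # 'seq_len': 16,     # deprecated name: now triggers a warning (see a2c_common.py below)
    },
}
```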
1 change: 0 additions & 1 deletion pyproject.toml
@@ -16,7 +16,6 @@ tensorboardX = "^2.5"
PyYAML = "^6.0"
psutil = "^5.9.0"
setproctitle = "^1.2.2"
ray = "^1.11.0"
opencv-python = "^4.5.5"
wandb = "^0.12.11"

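Since ray is no longer pulled in automatically, users who want the Ray-backed vectorized env install it themselves; a minimal availability probe, not taken from the repo, might look like this:

```python
# Assumed helper, not from the repo: probe for the now-optional ray package.
import importlib.util

HAS_RAY = importlib.util.find_spec('ray') is not None
if not HAS_RAY:
    print('ray is not installed; install it separately to use the Ray vec-env backend')
```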
6 changes: 3 additions & 3 deletions rl_games/algos_torch/a2c_continuous.py
@@ -40,7 +40,7 @@ def __init__(self, base_name, params):
'horizon_length' : self.horizon_length,
'num_actors' : self.num_actors,
'num_actions' : self.actions_num,
'seq_len' : self.seq_len,
'seq_length' : self.seq_length,
'normalize_value' : self.normalize_value,
'network' : self.central_value_config['network'],
'config' : self.central_value_config,
@@ -52,7 +52,7 @@ def __init__(self, base_name, params):
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

self.use_experimental_cv = self.config.get('use_experimental_cv', True)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length)
if self.normalize_value:
self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std

@@ -98,7 +98,7 @@ def calc_gradients(self, input_dict):
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len
batch_dict['seq_length'] = self.seq_length

if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
7 changes: 4 additions & 3 deletions rl_games/algos_torch/a2c_discrete.py
@@ -43,7 +43,7 @@ def __init__(self, base_name, params):
'horizon_length' : self.horizon_length,
'num_actors' : self.num_actors,
'num_actions' : self.actions_num,
'seq_len' : self.seq_len,
'seq_length' : self.seq_length,
'normalize_value' : self.normalize_value,
'network' : self.central_value_config['network'],
'config' : self.central_value_config,
@@ -55,7 +55,7 @@ def __init__(self, base_name, params):
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

self.use_experimental_cv = self.config.get('use_experimental_cv', False)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length)

if self.normalize_value:
self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std
@@ -127,11 +127,12 @@ def calc_gradients(self, input_dict):
}
if self.use_action_masks:
batch_dict['action_masks'] = input_dict['action_masks']

rnn_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len
batch_dict['seq_length'] = self.seq_length
batch_dict['bptt_len'] = self.bptt_len
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
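For reference, the RNN branch of calc_gradients above attaches a few extra keys to batch_dict. A standalone sketch with dummy tensors; the 'obs'/'actions' entries and all shapes are assumptions, only the RNN-specific keys mirror the code:

```python
# Dummy tensors only; shows the extra keys the RNN branch adds to batch_dict.
import torch

B, T = 4, 8                                         # sequences per minibatch, seq_length
batch_dict = {
    'obs': torch.randn(B * T, 16),                  # assumed base entries
    'actions': torch.randint(0, 3, (B * T,)),
}
batch_dict['rnn_states'] = [torch.zeros(1, B, 64)]  # one tensor per recurrent state
batch_dict['seq_length'] = T                        # renamed from the deprecated seq_len
batch_dict['bptt_len'] = T                          # defaults to seq_length; currently unused
batch_dict['dones'] = torch.zeros(B * T, dtype=torch.uint8)  # only when zero_rnn_on_done
```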
19 changes: 9 additions & 10 deletions rl_games/algos_torch/central_value.py
@@ -12,11 +12,11 @@
class CentralValueTrain(nn.Module):

def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions,
seq_len, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
seq_length, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
nn.Module.__init__(self)

self.ppo_device = ppo_device
self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len
self.num_agents, self.horizon_length, self.num_actors, self.seq_length = num_agents, horizon_length, num_actors, seq_length
self.normalize_value = normalize_value
self.num_actions = num_actions
self.state_shape = state_shape
@@ -78,8 +78,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
self.rnn_states = self.model.get_default_rnn_state()
self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
total_agents = self.num_actors #* self.num_agents
num_seqs = self.horizon_length // self.seq_len
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
num_seqs = self.horizon_length // self.seq_length
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [ torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]

self.local_rank = 0
@@ -100,7 +100,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
config['print_stats'] = False
config['lr_schedule'] = None

self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_length)

def update_lr(self, lr):
if self.multi_gpu:
@@ -167,9 +167,9 @@ def _preproc_obs(self, obs_batch):
def pre_step_rnn(self, n):
if not self.is_rnn:
return
if n % self.seq_len == 0:
if n % self.seq_length == 0:
for s, mb_s in zip(self.rnn_states, self.mb_rnn_states):
mb_s[n // self.seq_len,:,:,:] = s
mb_s[n // self.seq_length,:,:,:] = s

def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True):
if not self.is_rnn:
@@ -183,7 +183,6 @@ def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True):
def forward(self, input_dict):
return self.model(input_dict)


def get_value(self, input_dict):
self.eval()
obs_batch = input_dict['states']
@@ -245,7 +244,7 @@ def calc_gradients(self, batch):

batch_dict = {'obs' : obs_batch,
'actions' : actions_batch,
'seq_length' : self.seq_len,
'seq_length' : self.seq_length,
'dones' : dones_batch}
if self.is_rnn:
batch_dict['rnn_states'] = batch['rnn_states']
@@ -284,5 +283,5 @@ def calc_gradients(self, batch):
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

self.optimizer.step()

return loss
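pre_step_rnn above saves the recurrent state once per seq_length window; a standalone sketch of that snapshot pattern with assumed shapes:

```python
# Standalone sketch of the pre_step_rnn bookkeeping (shapes are made up):
# one hidden state per actor is stored at the start of every seq_length window.
import torch

horizon_length, seq_length, total_agents, hidden = 32, 8, 4, 64
rnn_state = torch.zeros(1, total_agents, hidden)             # current recurrent hidden state
mb_rnn_states = torch.zeros(horizon_length // seq_length,    # one slot per sequence
                            1, total_agents, hidden)

for n in range(horizon_length):
    if n % seq_length == 0:
        mb_rnn_states[n // seq_length] = rnn_state           # snapshot at sequence start
    # ... env step happens here and rnn_state is updated by the model ...
```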
26 changes: 18 additions & 8 deletions rl_games/algos_torch/network_builder.py
@@ -3,20 +3,17 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import numpy as np
from rl_games.algos_torch.d2rl import D2RLNet
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.layers import symexp, symlog


def _create_initializer(func, **kwargs):
return lambda v : func(v, **kwargs)


class NetworkBuilder:
def __init__(self, **kwargs):
pass
@@ -196,6 +193,7 @@ def __init__(self, params, **kwargs):
input_shape = kwargs.pop('input_shape')
self.value_size = kwargs.pop('value_size', 1)
self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)

NetworkBuilder.BaseNetwork.__init__(self)
self.load(params)
self.actor_cnn = nn.Sequential()
@@ -306,9 +304,9 @@ def __init__(self, params, **kwargs):
def forward(self, obs_dict):
obs = obs_dict['obs']
states = obs_dict.get('rnn_states', None)
seq_length = obs_dict.get('seq_length', 1)
dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)

if self.has_cnn:
# for obs shape 4
# input expected shape (B, W, H, C)
@@ -325,6 +323,8 @@ def forward(self, obs_dict):
c_out = c_out.contiguous().view(c_out.size(0), -1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)

if not self.is_rnn_before_mlp:
a_out_in = a_out
c_out_in = c_out
@@ -359,9 +359,11 @@ def forward(self, obs_dict):
c_out = c_out.transpose(0,1)
a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)

if self.rnn_ln:
a_out = self.a_layer_norm(a_out)
c_out = self.c_layer_norm(c_out)

if type(a_states) is not tuple:
a_states = (a_states,)
c_states = (c_states,)
@@ -398,6 +400,8 @@ def forward(self, obs_dict):
out = out.flatten(1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
out_in = out
@@ -703,13 +707,16 @@ def forward(self, obs_dict):
dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)
states = obs_dict.get('rnn_states', None)
seq_length = obs_dict.get('seq_length', 1)

out = obs
out = self.cnn(out)
out = out.flatten(1)
out = self.flatten_act(out)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
out_in = out
@@ -769,20 +776,23 @@ def load(self, params):
self.is_multi_discrete = 'multi_discrete'in params['space']
self.value_activation = params.get('value_activation', 'None')
self.normalization = params.get('normalization', None)

if self.is_continuous:
self.space_config = params['space']['continuous']
self.fixed_sigma = self.space_config['fixed_sigma']
elif self.is_discrete:
self.space_config = params['space']['discrete']
elif self.is_multi_discrete:
self.space_config = params['space']['multi_discrete']
self.space_config = params['space']['multi_discrete']

self.has_rnn = 'rnn' in params
if self.has_rnn:
self.rnn_units = params['rnn']['units']
self.rnn_layers = params['rnn']['layers']
self.rnn_name = params['rnn']['name']
self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
self.rnn_ln = params['rnn'].get('layer_norm', False)

self.has_cnn = True
self.permute_input = params['cnn'].get('permute_input', True)
self.conv_depths = params['cnn']['conv_depths']
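The forward() changes above only move where seq_length is read, since it matters solely on the RNN path. The reshape that seq_length drives looks roughly like the following sketch with assumed sizes (T plays the role of seq_length):

```python
# Sketch of the (B*T, C) -> (T, B, C) -> (B*T, H) round trip in the RNN branch.
import torch

B, T, C, H = 8, 16, 32, 64                     # sequences, seq_length, features, rnn units
rnn = torch.nn.GRU(C, H)

flat = torch.randn(B * T, C)                   # output of the MLP/CNN trunk
seq = flat.reshape(B, T, C).transpose(0, 1)    # time-major layout for the recurrent layer
h0 = torch.zeros(1, B, H)
out, hT = rnn(seq, h0)                         # out: (T, B, H)
out = out.transpose(0, 1).contiguous().reshape(B * T, H)   # back to a flat batch
```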
25 changes: 18 additions & 7 deletions rl_games/common/a2c_common.py
@@ -202,9 +202,16 @@ def __init__(self, base_name, params):
self.rewards_shaper = config['reward_shaper']
self.num_agents = self.env_info.get('agents', 1)
self.horizon_length = config['horizon_length']
self.seq_len = self.config.get('seq_length', 4)
self.bptt_len = self.config.get('bptt_length', self.seq_len) # not used right now. Didn't show that it is usefull

# seq_length is used only with rnn policy and value functions
if 'seq_len' in config:
print('WARNING: seq_len is deprecated, use seq_length instead')

self.seq_length = self.config.get('seq_length', 4)
print('seq_length:', self.seq_length)
self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is usefull
self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True)

self.normalize_advantage = config['normalize_advantage']
self.normalize_rms_advantage = config.get('normalize_rms_advantage', False)
self.normalize_input = self.config['normalize_input']
@@ -229,7 +236,7 @@ def __init__(self, base_name, params):
self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
self.obs = None
self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation

self.batch_size = self.horizon_length * self.num_actors * self.num_agents
self.batch_size_envs = self.horizon_length * self.num_actors
@@ -463,8 +470,8 @@ def init_tensors(self):
self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]

total_agents = self.num_agents * self.num_actors
num_seqs = self.horizon_length // self.seq_len
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
num_seqs = self.horizon_length // self.seq_length
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_rnn_from_model(self, model):
@@ -792,9 +799,9 @@ def play_steps_rnn(self):
step_time = 0.0

for n in range(self.horizon_length):
if n % self.seq_len == 0:
if n % self.seq_length == 0:
for s, mb_s in zip(self.rnn_states, mb_rnn_states):
mb_s[n // self.seq_len,:,:,:] = s
mb_s[n // self.seq_length,:,:,:] = s

if self.has_central_value:
self.central_value_net.pre_step_rnn(n)
@@ -804,6 +811,7 @@ def play_steps_rnn(self):
res_dict = self.get_masked_action_values(self.obs, masks)
else:
res_dict = self.get_action_values(self.obs)

self.rnn_states = res_dict['rnn_states']
self.experience_buffer.update_data('obses', n, self.obs['obs'])
self.experience_buffer.update_data('dones', n, self.dones.byte())
@@ -860,15 +868,18 @@ def play_steps_rnn(self):
mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)
mb_returns = mb_advs + mb_values
batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)

batch_dict['returns'] = swap_and_flatten01(mb_returns)
batch_dict['played_frames'] = self.batch_size
states = []
for mb_s in mb_rnn_states:
t_size = mb_s.size()[0] * mb_s.size()[2]
h_size = mb_s.size()[3]
states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size))

batch_dict['rnn_states'] = states
batch_dict['step_time'] = step_time

return batch_dict


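The divisibility assert above is the constraint the renamed seq_length has to satisfy; a worked example with assumed hyperparameters:

```python
# Worked example of the sequence/minibatch arithmetic above (numbers are illustrative).
horizon_length = 128
num_actors = 16
num_agents = 1
minibatch_size = 512
seq_length = 8

batch_size = horizon_length * num_actors * num_agents      # 2048 transitions per update
num_minibatches = batch_size // minibatch_size             # 4
games_num = minibatch_size // seq_length                   # 64 whole sequences per minibatch
num_seqs = horizon_length // seq_length                    # 16 RNN-state snapshots per actor
assert (horizon_length * num_actors * num_agents // num_minibatches) % seq_length == 0
```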
20 changes: 12 additions & 8 deletions rl_games/common/datasets.py
@@ -2,20 +2,23 @@
import copy
from torch.utils.data import Dataset


class PPODataset(Dataset):
def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len):

def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length):

self.is_rnn = is_rnn
self.seq_len = seq_len
self.seq_length = seq_length
self.batch_size = batch_size
self.minibatch_size = minibatch_size
self.device = device
self.length = self.batch_size // self.minibatch_size
self.is_discrete = is_discrete
self.is_continuous = not is_discrete
total_games = self.batch_size // self.seq_len
self.num_games_batch = self.minibatch_size // self.seq_len
total_games = self.batch_size // self.seq_length
self.num_games_batch = self.minibatch_size // self.seq_length
self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_len)
self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)

self.special_names = ['rnn_states']

@@ -34,9 +37,10 @@ def __len__(self):
def _get_item_rnn(self, idx):
gstart = idx * self.num_games_batch
gend = (idx + 1) * self.num_games_batch
start = gstart * self.seq_len
end = gend * self.seq_len
self.last_range = (start, end)
start = gstart * self.seq_length
end = gend * self.seq_length
self.last_range = (start, end)

input_dict = {}
for k,v in self.values_dict.items():
if k not in self.special_names:
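A short, self-contained illustration of the index bookkeeping in PPODataset above, with assumed sizes: RNN minibatches are sliced in whole sequences of seq_length steps, so a sequence is never split across minibatches.

```python
# Illustrative sizes only; mirrors how PPODataset groups the flat buffer into sequences.
import torch

batch_size, minibatch_size, seq_length = 64, 32, 8
total_games = batch_size // seq_length               # 8 sequences in the buffer
num_games_batch = minibatch_size // seq_length       # 4 sequences per minibatch

flat_indexes = torch.arange(total_games * seq_length).reshape(total_games, seq_length)

idx = 1                                              # second RNN minibatch
start = idx * num_games_batch * seq_length           # 32
end = (idx + 1) * num_games_batch * seq_length       # 64
print(flat_indexes.view(-1)[start:end])              # transitions 32..63, i.e. sequences 4..7
```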
3 changes: 2 additions & 1 deletion rl_games/common/vecenv.py
@@ -1,4 +1,3 @@
import ray
from rl_games.common.ivecenv import IVecEnv
from rl_games.common.env_configurations import configurations
from rl_games.common.tr_helpers import dicts_to_dict_with_arrays
@@ -102,6 +101,8 @@ def __init__(self, config_name, num_actors, **kwargs):
self.num_actors = num_actors
self.use_torch = False
self.seed = kwargs.pop('seed', None)

import ray
self.remote_worker = ray.remote(RayWorker)
self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]

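Moving `import ray` from module level into the constructor is what makes ray optional: other env backends can be built without ray installed, and the import error only surfaces if a Ray-backed env is actually requested. A rough sketch of the pattern with an illustrative class (not the repo's RayVecEnv):

```python
# Rough sketch of the deferred-import pattern; class and names are illustrative.
class LazyRayBackend:
    def __init__(self, num_workers):
        import ray                       # only fails if this backend is actually used
        if not ray.is_initialized():
            ray.init()
        self.ray = ray
        self.num_workers = num_workers

    def shutdown(self):
        self.ray.shutdown()
```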
(Diffs for the remaining changed files are not shown.)
