diff --git a/README.md b/README.md index 2bad1ad..07b7bd0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ============ [![Read the Docs](https://img.shields.io/readthedocs/blobrl?style=for-the-badge)](https://blobrl.readthedocs.io/en/latest/?badge=latest) -[![Build Status](https://img.shields.io/travis/french-ai/reinforcement?branch=master.svg&style=for-the-badge)](https://travis-ci.org/french-ai/reinforcement) +[![Build Status](https://img.shields.io/travis/french-ai/reinforcement/master.svg?=master&style=for-the-badge)](https://travis-ci.org/french-ai/reinforcement) [![CodeFactor](https://www.codefactor.io/repository/github/french-ai/reinforcement/badge?style=for-the-badge)](https://www.codefactor.io/repository/github/french-ai/reinforcement) [![Codecov](https://img.shields.io/codecov/c/github/french-ai/reinforcement?style=for-the-badge)](https://codecov.io/gh/french-ai/reinforcement) [![Discord](https://img.shields.io/badge/discord-chat-7289DA.svg?logo=Discord&style=for-the-badge)](https://discord.gg/f5MZP2K) diff --git a/TODO.md b/TODO.md index 9588dd1..db6ad41 100644 --- a/TODO.md +++ b/TODO.md @@ -20,6 +20,7 @@ - [x] Random Agent - [x] Constant Agent + - [x] Deep Q Network (Mnih *et al.*, [2013](https://arxiv.org/abs/1312.5602)) - [ ] Deep Recurrent Q Network (Hausknecht *et al.*, [2015](https://arxiv.org/abs/1507.06527)) - [ ] Persistent Advantage Learning (Bellamare *et al.*, [2015](https://arxiv.org/abs/1512.04860)) @@ -30,34 +31,41 @@ - [x] Categorical Deep Q Network (Bellamare *et al.*, [2017](https://arxiv.org/abs/1707.06887)) - [ ] Quantile Regression DQN (Dabney et al, [2017](https://arxiv.org/abs/1710.10044)) + - [ ] Rainbow (Hessel *et al.*, [2017](https://arxiv.org/abs/1710.02298)) - [ ] Quantile Regression Deep Q Network (Dabney *et al.*, [2017](https://arxiv.org/abs/1710.10044)) + - [ ] Soft Actor-Critic (Haarnoja et al, [2018](https://arxiv.org/abs/1801.01290)) + - [ ] Vanilla Policy Gradient 
([2000](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)) + - [ ] Deep Deterministic Policy Gradient (Lillicrap et al, [2015](https://arxiv.org/abs/1509.02971)) - [ ] Twin Delayed DDPG (Fujimoto et al, [2018](https://arxiv.org/abs/1802.09477)) + - [ ] Trust Region Policy Optimization (Schulman *et al.*, [2015](https://arxiv.org/abs/1502.05477)) - [ ] Proximal Policy Optimizations (Schulman *et al.*, [2017](https://arxiv.org/abs/1707.06347)) + - [ ] A2C (Mnih et al, [2016](https://arxiv.org/abs/1602.01783)) - [ ] A3C (Mnih et al, [2016](https://arxiv.org/abs/1602.01783)) + - [ ] Hindsight Experience Replay (Andrychowicz et al, [2017](https://arxiv.org/abs/1707.01495)) # Network -- [ ] base network support discrete action space -- [ ] base network support continuous action space -- [ ] base network support discrete observation space -- [ ] base network support continuous observation space -- [ ] simple network support discrete/continuous action/observation space -- [ ] c51 network support discrete/continuous action/observation space -- [ ] base dueling network support discrete/continuous action/observation space -- [ ] simple dueling network support discrete/continuous action/observation space +- [x] base network support discrete action space +- [x] base network support continuous action space +- [x] base network support discrete observation space +- [x] base network support continuous observation space +- [x] simple network support discrete/continuous action/observation space +- [x] c51 network support discrete action/observation space +- [x] base dueling network support discrete/continuous action/observation space +- [x] simple dueling network support discrete/continuous action/observation space # Explorations list diff --git a/blobrl/agents/__init__.py b/blobrl/agents/__init__.py index 4ab51f5..4215af5 100644 --- a/blobrl/agents/__init__.py +++ b/blobrl/agents/__init__.py @@ -3,5 +3,4 @@ from 
.agent_random import AgentRandom from .dqn import DQN from .double_dqn import DoubleDQN -from .dueling_dqn import DuelingDQN from .categorical_dqn import CategoricalDQN diff --git a/blobrl/agents/agent_constant.py b/blobrl/agents/agent_constant.py index a002143..cdafb72 100644 --- a/blobrl/agents/agent_constant.py +++ b/blobrl/agents/agent_constant.py @@ -25,14 +25,7 @@ def __init__(self, observation_space, action_space, device=None): :param action_space: Space for init action size :type observation_space: gym.Space """ - super().__init__(device) - if not isinstance(action_space, Space): - raise TypeError("action_space need to be instance of gym.spaces.Space, not :" + str(type(action_space))) - if not isinstance(observation_space, Space): - raise TypeError( - "observation_space need to be instance of gym.spaces.Space, not :" + str(type(observation_space))) - self.action_space = action_space - self.observation_space = observation_space + super().__init__(observation_space, action_space, device) self.action = self.action_space.sample() diff --git a/blobrl/agents/agent_interface.py b/blobrl/agents/agent_interface.py index efe1abd..2117505 100644 --- a/blobrl/agents/agent_interface.py +++ b/blobrl/agents/agent_interface.py @@ -2,15 +2,30 @@ import torch +from gym.spaces import Space + class AgentInterface(metaclass=abc.ABCMeta): - def __init__(self, device): + def __init__(self, observation_space, action_space, device): """ + :param device: torch device to run agent + :type: torch.device + :param observation_space: Space for init observation size + :type observation_space: gym.Space :param device: torch device to run agent :type: torch.device """ + + if not isinstance(action_space, Space): + raise TypeError("action_space need to be instance of gym.spaces.Space, not :" + str(type(action_space))) + if not isinstance(observation_space, Space): + raise TypeError( + "observation_space need to be instance of gym.spaces.Space, not :" + str(type(observation_space))) + 
self.action_space = action_space + self.observation_space = observation_space + if device is None: device = torch.device("cpu") if not isinstance(device, torch.device): diff --git a/blobrl/agents/agent_random.py b/blobrl/agents/agent_random.py index 291512e..484d612 100644 --- a/blobrl/agents/agent_random.py +++ b/blobrl/agents/agent_random.py @@ -2,7 +2,6 @@ import pickle import torch -from gym.spaces import Space from blobrl.agents import AgentInterface @@ -25,14 +24,7 @@ def __init__(self, observation_space, action_space, device=None): :param action_space: Space for init action size :type observation_space: gym.Space """ - super().__init__(device) - if not isinstance(action_space, Space): - raise TypeError("action_space need to be instance of gym.spaces.Space, not :" + str(type(action_space))) - if not isinstance(observation_space, Space): - raise TypeError( - "observation_space need to be instance of gym.spaces.Space, not :" + str(type(observation_space))) - self.action_space = action_space - self.observation_space = observation_space + super().__init__(observation_space, action_space, device) def get_action(self, observation): """ Return action randomly choice in action_space diff --git a/blobrl/agents/categorical_dqn.py b/blobrl/agents/categorical_dqn.py index 496f2d5..1a32a0f 100644 --- a/blobrl/agents/categorical_dqn.py +++ b/blobrl/agents/categorical_dqn.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from gym.spaces import Discrete, Space, flatdim, flatten +from gym.spaces import flatten from blobrl.agents import DQN from blobrl.memories import ExperienceReplay @@ -10,8 +10,8 @@ class CategoricalDQN(DQN): - def __init__(self, action_space, observation_space, memory=ExperienceReplay(), neural_network=None, num_atoms=51, - r_min=-10, r_max=10, step_train=2, batch_size=32, gamma=0.99, + def __init__(self, observation_space, action_space, memory=ExperienceReplay(), network=None, num_atoms=51, + r_min=-10, r_max=10, 
step_train=1, batch_size=32, gamma=1.0, optimizer=None, greedy_exploration=None, device=None): """ @@ -20,7 +20,7 @@ def __init__(self, action_space, observation_space, memory=ExperienceReplay(), n :param action_space: :param observation_space: :param memory: - :param neural_network: + :param network: :param num_atoms: :param r_min: :param r_max: @@ -30,25 +30,16 @@ def __init__(self, action_space, observation_space, memory=ExperienceReplay(), n :param optimizer: :param greedy_exploration: """ - loss = None - - if not isinstance(action_space, Discrete): - raise TypeError( - "action_space need to be instance of gym.spaces.Space.Discrete, not :" + str(type(action_space))) - if not isinstance(observation_space, Space): - raise TypeError( - "observation_space need to be instance of gym.spaces.Space.Discrete, not :" + str( - type(observation_space))) - - if neural_network is None and optimizer is None: - neural_network = C51Network(observation_shape=flatdim(observation_space), - action_shape=flatdim(action_space)) + if network is None and optimizer is None: + network = C51Network(observation_space=observation_space, + action_space=action_space) num_atoms = 51 - optimizer = optim.Adam(neural_network.parameters()) + optimizer = optim.Adam(network.parameters()) - super().__init__(action_space, observation_space, memory, neural_network, step_train, batch_size, gamma, loss, - optimizer, greedy_exploration, device=device) + super().__init__(observation_space=observation_space, action_space=action_space, memory=memory, + network=network, step_train=step_train, batch_size=batch_size, gamma=gamma, + loss=None, optimizer=optimizer, greedy_exploration=greedy_exploration, device=device) self.num_atoms = num_atoms self.r_min = r_min @@ -63,70 +54,77 @@ def get_action(self, observation): :param observation: stat of environment :type observation: gym.Space """ - observation = torch.tensor([flatten(self.observation_space, observation)], device=self.device) + if not 
self.greedy_exploration.be_greedy(self.step) and self.with_exploration: + return self.action_space.sample() - prediction = self.neural_network.forward(observation).detach()[0] - q_values = prediction * self.z - q_values = torch.sum(q_values, dim=1) + observation = torch.tensor([flatten(self.observation_space, observation)], device=self.device).float() - return torch.argmax(q_values).detach().item() + prediction = self.network.forward(observation) - def train(self): - """ + def return_values(values): + if isinstance(values, list): + return [return_values(v) for v in values] - """ - self.batch_size = 3 + q_values = values * self.z + q_values = torch.sum(q_values, dim=2) + return torch.argmax(q_values).detach().item() + + return return_values(prediction) + + def apply_loss(self, next_prediction, prediction, actions, rewards, next_observations, dones, len_space): + if isinstance(next_prediction, list): + [self.apply_loss(n, p, a, rewards, next_observations, dones, c) for n, p, a, c in + zip(next_prediction, prediction, actions.permute(1, 0, *[i for i in range(2, len(actions.shape))]), + len_space)] + else: - observations, actions, rewards, next_observations, dones = self.memory.sample(self.batch_size, - device=self.device) + q_values_next = next_prediction * self.z + q_values_next = torch.sum(q_values_next, dim=2) - actions = actions.to(torch.long) - actions = F.one_hot(actions, num_classes=self.action_space.n) + actions = F.one_hot(actions.long(), num_classes=len_space) - predictions_next = self.neural_network.forward(next_observations).detach() - q_values_next = predictions_next * self.z - q_values_next = torch.sum(q_values_next, dim=2) + actions_next = torch.argmax(q_values_next, dim=1) + actions_next = F.one_hot(actions_next, num_classes=len_space) - actions_next = torch.argmax(q_values_next, dim=1) - actions_next = actions_next.to(torch.long) - actions_next = F.one_hot(actions_next, num_classes=self.action_space.n) + dones = dones.view(-1, 1) - dones = 
dones.view(-1, 1) + tz = rewards.view(-1, 1) + self.gamma * self.z * (1 - dones) + tz = tz.clamp(self.r_min, self.r_max) + b = (tz - self.r_min) / self.delta_z - tz = torch.clamp(rewards.view(-1, 1) + self.gamma * self.z * (1 - dones), self.r_min, self.r_max) - b = (tz - self.r_min) / self.delta_z + l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) - l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) + l[(u > 0) * (l == u)] -= 1 + u[(l < (self.num_atoms - 1)) * (l == u)] += 1 - m_prob = torch.zeros((self.batch_size, self.action_space.n, self.num_atoms), device=self.device) + m_prob = torch.zeros((self.batch_size, len_space, self.num_atoms), device=self.device) - predictions_next = predictions_next[actions_next == 1, :] + predictions_next = next_prediction[actions_next == 1, :] - offset = torch.linspace(0, (self.batch_size - 1) * self.num_atoms, self.batch_size, device=self.device).view(-1, - 1) - offset = offset.expand(self.batch_size, self.num_atoms) + offset = torch.linspace(0, (self.batch_size - 1) * self.num_atoms, self.batch_size, + device=self.device).view(-1, + 1) + offset = offset.expand(self.batch_size, self.num_atoms) - u_index = (u + offset).view(-1).to(torch.int64) - l_index = (l + offset).view(-1).to(torch.int64) + u_index = (u + offset).view(-1).to(torch.int64) + l_index = (l + offset).view(-1).to(torch.int64) - predictions_next = (dones + (1 - dones) * predictions_next) + predictions_next = (dones + (1 - dones) * predictions_next) - m_prob_action = m_prob[actions == 1, :].view(-1) - m_prob_action.index_add_(0, u_index, (predictions_next * (u - b)).view(-1)) - m_prob_action.index_add_(0, l_index, (predictions_next * (b - l)).view(-1)) + m_prob_action = m_prob[actions == 1, :].view(-1) + m_prob_action.index_add_(0, u_index, (predictions_next * (u - b)).view(-1)) + m_prob_action.index_add_(0, l_index, (predictions_next * (b - l)).view(-1)) - m_prob[actions == 1, :] = m_prob_action.view(-1, self.num_atoms) + m_prob[actions == 1, :] = 
m_prob_action.view(-1, self.num_atoms) - self.optimizer.zero_grad() - predictions = self.neural_network.forward(observations) - loss = - predictions.log() * m_prob - loss.sum((1, 2)).mean().backward() + self.optimizer.zero_grad() - self.optimizer.step() + loss = - prediction.log() * m_prob + loss.sum((1, 2)).mean().backward(retain_graph=True) def __str__(self): return 'CategoricalDQN-' + str(self.observation_space) + "-" + str(self.action_space) + "-" + str( - self.neural_network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( + self.network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( self.step) + "-" + str(self.batch_size) + "-" + str(self.gamma) + "-" + str(self.loss) + "-" + str( self.optimizer) + "-" + str(self.greedy_exploration) + "-" + str(self.num_atoms) + "-" + str( self.r_min) + "-" + str(self.r_max) + "-" + str(self.delta_z) + "-" + str(self.z) diff --git a/blobrl/agents/double_dqn.py b/blobrl/agents/double_dqn.py index 994882a..35b55e4 100644 --- a/blobrl/agents/double_dqn.py +++ b/blobrl/agents/double_dqn.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from gym.spaces import flatdim +from gym.spaces import Discrete, MultiDiscrete from blobrl.agents import DQN from blobrl.memories import ExperienceReplay @@ -14,8 +14,8 @@ class DoubleDQN(DQN): """ from 'Deep Reinforcement Learning with Double Q-learning' in https://arxiv.org/pdf/1509.06461.pdf """ - def __init__(self, action_space, observation_space, memory=ExperienceReplay(), neural_network=None, step_copy=500, - step_train=2, batch_size=32, gamma=0.99, loss=None, optimizer=None, greedy_exploration=None, + def __init__(self, observation_space, action_space, memory=ExperienceReplay(), network=None, step_copy=500, + step_train=1, batch_size=32, gamma=1.0, loss=None, optimizer=None, greedy_exploration=None, device=None): """ @@ -24,7 +24,7 @@ def __init__(self, action_space, observation_space, 
memory=ExperienceReplay(), n :param action_space: :param observation_space: :param memory: - :param neural_network: + :param network: :param step_copy: :param step_train: :param batch_size: @@ -33,17 +33,17 @@ def __init__(self, action_space, observation_space, memory=ExperienceReplay(), n :param optimizer: :param greedy_exploration: """ - super().__init__(action_space, observation_space, memory, neural_network, step_train, batch_size, gamma, loss, + super().__init__(observation_space, action_space, memory, network, step_train, batch_size, gamma, loss, optimizer, greedy_exploration, device=device) - self.neural_network_target = deepcopy(self.neural_network) + self.network_target = deepcopy(self.network) self.copy_online_to_target() self.step_copy = step_copy if optimizer is None: - self.optimizer = optim.Adam(self.neural_network.parameters()) + self.optimizer = optim.Adam(self.network.parameters()) - self.neural_network_target.to(self.device) + self.network_target.to(self.device) def learn(self, observation, action, reward, next_observation, done) -> None: """ learn from parameters @@ -72,25 +72,45 @@ def train(self): observations, actions, rewards, next_observations, dones = self.memory.sample(self.batch_size, device=self.device) - actions_next = torch.argmax(self.neural_network.forward(next_observations).detach(), dim=1) - actions_next_one_hot = F.one_hot(actions_next.to(torch.int64), num_classes=self.action_space.n) - q_next = self.neural_network_target.forward(next_observations).detach() * actions_next_one_hot + next_prediction = self.network.forward(next_observations) + prediction = self.network.forward(observations) + target_next_prediction = self.network_target.forward(next_observations) + + if isinstance(self.action_space, Discrete): + self.apply_loss(next_prediction, prediction, actions, rewards, next_observations, dones, + self.action_space.n, target_next_prediction) + # find space for one_hot encore action in apply loss + elif 
isinstance(self.action_space, MultiDiscrete): + self.apply_loss(next_prediction, prediction, actions, rewards, next_observations, dones, + self.action_space.nvec, target_next_prediction) + self.optimizer.step() - q = rewards + self.gamma * torch.max(q_next, dim=1)[0] * (1 - dones) + def apply_loss(self, next_prediction, prediction, actions, rewards, next_observations, dones, len_space, + target_next_prediction): + if isinstance(next_prediction, list): + [self.apply_loss(n, p, a, rewards, next_observations, dones, c, t) for n, p, a, c, t in + zip(next_prediction, prediction, actions.permute(1, 0, *[i for i in range(2, len(actions.shape))]), + len_space, target_next_prediction)] + else: - actions_one_hot = F.one_hot(actions.to(torch.int64), num_classes=self.action_space.n) - q_predict = torch.max(self.neural_network.forward(observations) * actions_one_hot, dim=1)[0] + actions_next = torch.argmax(next_prediction.detach(), dim=1) + actions_next_one_hot = F.one_hot(actions_next.to(torch.int64), num_classes=len_space) + q_next = target_next_prediction.detach() * actions_next_one_hot - self.optimizer.zero_grad() - loss = self.loss(q_predict, q) - loss.backward() - self.optimizer.step() + q = rewards + self.gamma * torch.max(q_next, dim=1)[0] * (1 - dones) + + actions_one_hot = F.one_hot(actions.to(torch.int64), num_classes=len_space) + q_predict = torch.max(prediction * actions_one_hot, dim=1)[0] + + self.optimizer.zero_grad() + loss = self.loss(q_predict, q) + loss.backward(retain_graph=True) def copy_online_to_target(self): """ """ - self.neural_network_target.load_state_dict(self.neural_network.state_dict()) + self.network_target.load_state_dict(self.network.state_dict()) def save(self, file_name, dire_name="."): """ Save agent at dire_name/file_name @@ -105,8 +125,8 @@ def save(self, file_name, dire_name="."): dict_save = dict() dict_save["observation_space"] = pickle.dumps(self.observation_space) dict_save["action_space"] = pickle.dumps(self.action_space) - 
dict_save["neural_network_class"] = pickle.dumps(type(self.neural_network)) - dict_save["neural_network"] = self.neural_network.state_dict() + dict_save["network_class"] = pickle.dumps(type(self.network)) + dict_save["network"] = self.network.state_dict() dict_save["step_train"] = pickle.dumps(self.step_train) dict_save["batch_size"] = pickle.dumps(self.batch_size) dict_save["gamma"] = pickle.dumps(self.gamma) @@ -130,14 +150,14 @@ def load(cls, file_name, dire_name=".", device=None): """ dict_save = torch.load(os.path.abspath(os.path.join(dire_name, file_name))) - neural_network = pickle.loads(dict_save["neural_network_class"])( - observation_shape=flatdim(pickle.loads(dict_save["observation_space"])), - action_shape=flatdim(pickle.loads(dict_save["action_space"]))) - neural_network.load_state_dict(dict_save["neural_network"]) + network = pickle.loads(dict_save["network_class"])( + observation_space=pickle.loads(dict_save["observation_space"]), + action_space=pickle.loads(dict_save["action_space"])) + network.load_state_dict(dict_save["network"]) double_dqn = DoubleDQN(observation_space=pickle.loads(dict_save["observation_space"]), action_space=pickle.loads(dict_save["action_space"]), - neural_network=neural_network, + network=network, step_train=pickle.loads(dict_save["step_train"]), batch_size=pickle.loads(dict_save["batch_size"]), gamma=pickle.loads(dict_save["gamma"]), @@ -152,6 +172,6 @@ def load(cls, file_name, dire_name=".", device=None): def __str__(self): return 'DoubleDQN-' + str(self.observation_space) + "-" + str(self.action_space) + "-" + str( - self.neural_network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( + self.network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( self.step) + "-" + str(self.batch_size) + "-" + str(self.gamma) + "-" + str(self.loss) + "-" + str( self.optimizer) + "-" + str(self.greedy_exploration) + "-" + str(self.step_copy) diff --git a/blobrl/agents/dqn.py b/blobrl/agents/dqn.py 
index 24827b0..47a3c2a 100644 --- a/blobrl/agents/dqn.py +++ b/blobrl/agents/dqn.py @@ -1,11 +1,10 @@ import os import pickle -from abc import ABCMeta import torch import torch.nn.functional as F import torch.optim as optim -from gym.spaces import Discrete, Space, flatdim, flatten +from gym.spaces import Discrete, MultiDiscrete, flatten from blobrl.agents import AgentInterface from blobrl.explorations import GreedyExplorationInterface, EpsilonGreedy @@ -13,69 +12,69 @@ from blobrl.networks import SimpleNetwork -class DQN(AgentInterface, metaclass=ABCMeta): +class DQN(AgentInterface): def enable_exploration(self): - self.trainable = True + self.with_exploration = True def disable_exploration(self): - self.trainable = False + self.with_exploration = False - def __init__(self, action_space, observation_space, memory=ExperienceReplay(), neural_network=None, - step_train=2, batch_size=32, gamma=0.99, loss=None, optimizer=None, greedy_exploration=None, - device=None): + def __init__(self, observation_space, action_space, memory=None, network=None, step_train=1, batch_size=32, + gamma=1.00, loss=None, optimizer=None, greedy_exploration=None, device=None): """ - :param device: torch device to run agent - :type: torch.device + :param action_space: :param observation_space: :param memory: - :param neural_network: + :param network: :param step_train: :param batch_size: :param gamma: :param loss: :param optimizer: :param greedy_exploration: + :param device: torch device to run agent + :type: torch.device """ - if not isinstance(action_space, Discrete): + if not isinstance(action_space, (Discrete, MultiDiscrete)): raise TypeError( - "action_space need to be instance of gym.spaces.Space.Discrete, not :" + str(type(action_space))) - if not isinstance(observation_space, Space): - raise TypeError( - "observation_space need to be instance of gym.spaces.Space.Discrete, not :" + str( - type(observation_space))) - - if neural_network is None and optimizer is not None: - raise 
TypeError("If neural_network is None, optimizer need to be None not " + str(type(optimizer))) + "action_space need to be instance of Discrete or MultiDiscrete, not :" + str(type(action_space))) - if neural_network is None: - neural_network = SimpleNetwork(observation_shape=flatdim(observation_space), - action_shape=flatdim(action_space)) - if not isinstance(neural_network, torch.nn.Module): - raise TypeError("neural_network need to be instance of torch.nn.Module, not :" + str(type(neural_network))) + if memory is None: + memory = ExperienceReplay() if not isinstance(memory, MemoryInterface): raise TypeError( "memory need to be instance of blobrls.memories.MemoryInterface, not :" + str(type(memory))) if loss is not None and not isinstance(loss, torch.nn.Module): - raise TypeError("loss need to be instance of blobrls.memories.MemoryInterface, not :" + str(type(loss))) + raise TypeError("loss need to be instance of torch.nn.Module, not :" + str(type(loss))) if optimizer is not None and not isinstance(optimizer, optim.Optimizer): raise TypeError( - "optimizer need to be instance of blobrls.memories.MemoryInterface, not :" + str(type(optimizer))) + "optimizer need to be instance of torch.optim.Optimizer, not :" + str(type(optimizer))) + + if network is None and optimizer is not None: + raise TypeError("If network is None, optimizer need to be None not " + str(type(optimizer))) if greedy_exploration is not None and not isinstance(greedy_exploration, GreedyExplorationInterface): raise TypeError( "greedy_exploration need to be instance of blobrls.explorations.GreedyExplorationInterface, not :" + str( type(greedy_exploration))) + if network is None: + network = SimpleNetwork(observation_space=observation_space, + action_space=action_space) + if not isinstance(network, torch.nn.Module): + raise TypeError("network need to be instance of torch.nn.Module, not :" + str(type(network))) + + super().__init__(observation_space=observation_space, action_space=action_space, 
device=device) + + self.network = network + self.network.to(self.device) - self.observation_space = observation_space - self.action_space = action_space - self.neural_network = neural_network self.memory = memory self.step_train = step_train @@ -90,7 +89,7 @@ def __init__(self, action_space, observation_space, memory=ExperienceReplay(), n self.loss = loss if optimizer is None: - self.optimizer = optim.Adam(self.neural_network.parameters()) + self.optimizer = optim.Adam(self.network.parameters(), lr=0.01) else: self.optimizer = optimizer @@ -99,10 +98,7 @@ def __init__(self, action_space, observation_space, memory=ExperienceReplay(), n else: self.greedy_exploration = greedy_exploration - self.trainable = True - - super().__init__(device) - self.neural_network.to(self.device) + self.with_exploration = True def get_action(self, observation): """ Return action choice by the agents @@ -110,14 +106,20 @@ def get_action(self, observation): :param observation: stat of environment :type observation: gym.Space """ - if not self.greedy_exploration.be_greedy(self.step) and self.trainable: + if not self.greedy_exploration.be_greedy(self.step) and self.with_exploration: return self.action_space.sample() - observation = torch.tensor([flatten(self.observation_space, observation)], device=self.device) + observation = torch.tensor([flatten(self.observation_space, observation)], device=self.device).float() + + q_values = self.network.forward(observation) - q_values = self.neural_network.forward(observation) + def return_values(values): + if isinstance(values, list): + return [return_values(v) for v in values] - return torch.argmax(q_values).detach().item() + return torch.argmax(values).detach().item() + + return return_values(q_values) def learn(self, observation, action, reward, next_observation, done) -> None: """ learn from parameters @@ -134,7 +136,8 @@ def learn(self, observation, action, reward, next_observation, done) -> None: :param done: if env is finished :type done: bool 
""" - self.memory.append(observation, action, reward, next_observation, done) + self.memory.append([flatten(self.observation_space, observation)], action, reward, + [flatten(self.observation_space, next_observation)], done) self.step += 1 if (self.step % self.step_train) == 0: @@ -150,18 +153,37 @@ def train(self): observations, actions, rewards, next_observations, dones = self.memory.sample(self.batch_size, device=self.device) - q = rewards + self.gamma * torch.max(self.neural_network.forward(next_observations), dim=1)[0].detach() * ( - 1 - dones) + next_prediction = self.network.forward(next_observations) - actions_one_hot = F.one_hot(actions.to(torch.int64), num_classes=self.action_space.n) - q_values_predict = self.neural_network.forward(observations) * actions_one_hot - q_predict = torch.max(q_values_predict, dim=1) + prediction = self.network.forward(observations) - self.optimizer.zero_grad() - loss = self.loss(q_predict[0], q) - loss.backward() + if isinstance(self.action_space, Discrete): + self.apply_loss(next_prediction, prediction, actions, rewards, next_observations, dones, + self.action_space.n) + # find space for one_hot encore action in apply loss + elif isinstance(self.action_space, MultiDiscrete): + self.apply_loss(next_prediction, prediction, actions, rewards, next_observations, dones, + self.action_space.nvec) self.optimizer.step() + def apply_loss(self, next_prediction, prediction, actions, rewards, next_observations, dones, len_space): + if isinstance(next_prediction, list): + [self.apply_loss(n, p, a, rewards, next_observations, dones, c) for n, p, a, c in + zip(next_prediction, prediction, actions.permute(1, 0, *[i for i in range(2, len(actions.shape))]), + len_space)] + else: + + q = rewards + self.gamma * next_prediction.max(1)[0].detach() * ( + 1 - dones) + + actions_one_hot = F.one_hot(actions.to(torch.int64), num_classes=len_space) + q_values_predict = prediction * actions_one_hot + q_predict = torch.max(q_values_predict, dim=1) + + 
self.optimizer.zero_grad() + loss = self.loss(q_predict[0], q) + loss.backward(retain_graph=True) + def save(self, file_name, dire_name="."): """ Save agent at dire_name/file_name @@ -175,8 +197,8 @@ def save(self, file_name, dire_name="."): dict_save = dict() dict_save["observation_space"] = pickle.dumps(self.observation_space) dict_save["action_space"] = pickle.dumps(self.action_space) - dict_save["neural_network_class"] = pickle.dumps(type(self.neural_network)) - dict_save["neural_network"] = self.neural_network.cpu().state_dict() + dict_save["network_class"] = pickle.dumps(type(self.network)) + dict_save["network"] = self.network.cpu().state_dict() dict_save["step_train"] = pickle.dumps(self.step_train) dict_save["batch_size"] = pickle.dumps(self.batch_size) dict_save["gamma"] = pickle.dumps(self.gamma) @@ -199,14 +221,14 @@ def load(cls, file_name, dire_name=".", device=None): """ dict_save = torch.load(os.path.abspath(os.path.join(dire_name, file_name))) - neural_network = pickle.loads(dict_save["neural_network_class"])( - observation_shape=flatdim(pickle.loads(dict_save["observation_space"])), - action_shape=flatdim(pickle.loads(dict_save["action_space"]))) - neural_network.load_state_dict(dict_save["neural_network"]) + network = pickle.loads(dict_save["network_class"])( + observation_space=pickle.loads(dict_save["observation_space"]), + action_space=pickle.loads(dict_save["action_space"])) + network.load_state_dict(dict_save["network"]) return DQN(observation_space=pickle.loads(dict_save["observation_space"]), action_space=pickle.loads(dict_save["action_space"]), - neural_network=neural_network, + network=network, step_train=pickle.loads(dict_save["step_train"]), batch_size=pickle.loads(dict_save["batch_size"]), gamma=pickle.loads(dict_save["gamma"]), @@ -217,6 +239,6 @@ def load(cls, file_name, dire_name=".", device=None): def __str__(self): return 'DQN-' + str(self.observation_space) + "-" + str(self.action_space) + "-" + str( - self.neural_network) + "-" 
+ str(self.memory) + "-" + str(self.step_train) + "-" + str( + self.network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( self.step) + "-" + str(self.batch_size) + "-" + str(self.gamma) + "-" + str(self.loss) + "-" + str( self.optimizer) + "-" + str(self.greedy_exploration) diff --git a/blobrl/agents/dueling_dqn.py b/blobrl/agents/dueling_dqn.py deleted file mode 100644 index 7dcd38d..0000000 --- a/blobrl/agents/dueling_dqn.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import pickle - -import torch -from gym.spaces import flatdim - -from blobrl.agents import DoubleDQN -from blobrl.memories import ExperienceReplay -from blobrl.networks import SimpleDuelingNetwork, BaseDuelingNetwork - - -class DuelingDQN(DoubleDQN): - """ from 'Dueling Network Architectures for Deep Reinforcement Learning' in https://arxiv.org/abs/1511.06581 """ - - def __init__(self, action_space, observation_space, memory=ExperienceReplay(), neural_network=None, step_copy=500, - step_train=2, batch_size=32, gamma=0.99, loss=None, optimizer=None, greedy_exploration=None, - device=None): - """ - - :param device: torch device to run agent - :type: torch.device - :param action_space: - :param observation_space: - :param memory: - :param neural_network: - :param step_copy: - :param step_train: - :param batch_size: - :param gamma: - :param loss: - :param optimizer: - :param greedy_exploration: - """ - - if neural_network is None: - neural_network = SimpleDuelingNetwork(observation_shape=flatdim(observation_space), - action_shape=flatdim(action_space)) - - if not isinstance(neural_network, BaseDuelingNetwork): - raise TypeError("neural_network need to be instance of blobrl.agents.BaseDuelingNetwork, not :" + str( - type(neural_network))) - - super().__init__(action_space, observation_space, memory=memory, neural_network=neural_network, - step_copy=step_copy, step_train=step_train, batch_size=batch_size, gamma=gamma, loss=loss, - optimizer=optimizer, 
greedy_exploration=greedy_exploration, device=device) - - @classmethod - def load(cls, file_name, dire_name=".", device=None): - """ load agent form dire_name/file_name - - :param device: torch device to run agent - :type: torch.device - :param file_name: name of file for load - :type file_name: string - :param dire_name: name of directory where we would load it - :type file_name: string - """ - dict_save = torch.load(os.path.abspath(os.path.join(dire_name, file_name))) - - neural_network = pickle.loads(dict_save["neural_network_class"])( - observation_shape=flatdim(pickle.loads(dict_save["observation_space"])), - action_shape=flatdim(pickle.loads(dict_save["action_space"]))) - neural_network.load_state_dict(dict_save["neural_network"]) - - dueling_dqn = DuelingDQN(observation_space=pickle.loads(dict_save["observation_space"]), - action_space=pickle.loads(dict_save["action_space"]), - neural_network=neural_network, - step_train=pickle.loads(dict_save["step_train"]), - batch_size=pickle.loads(dict_save["batch_size"]), - gamma=pickle.loads(dict_save["gamma"]), - loss=pickle.loads(dict_save["loss"]), - optimizer=pickle.loads(dict_save["optimizer"]), - greedy_exploration=pickle.loads(dict_save["greedy_exploration"]), - device=device) - - dueling_dqn.step_copy = pickle.loads(dict_save["step_copy"]) - - return dueling_dqn - - def __str__(self): - return 'DuelingDQN-' + str(self.observation_space) + "-" + str(self.action_space) + "-" + str( - self.neural_network) + "-" + str(self.memory) + "-" + str(self.step_train) + "-" + str( - self.step) + "-" + str(self.batch_size) + "-" + str(self.gamma) + "-" + str(self.loss) + "-" + str( - self.optimizer) + "-" + str(self.greedy_exploration) + "-" + str(self.step_copy) diff --git a/blobrl/memories/experience_replay.py b/blobrl/memories/experience_replay.py index 6bcc120..37ff4de 100644 --- a/blobrl/memories/experience_replay.py +++ b/blobrl/memories/experience_replay.py @@ -6,7 +6,7 @@ class ExperienceReplay(MemoryInterface): - 
def __init__(self, max_size=300): + def __init__(self, max_size=5000): """ :param max_size: @@ -38,28 +38,8 @@ def extend(self, observations, actions, rewards, next_observations, dones): :param next_observations: :param dones: """ - datas = np.array((observations, actions, rewards, next_observations, dones)).T - - len_datas = len(datas) - - if len_datas > self.max_size: - datas = datas[-self.max_size:] - len_datas = self.max_size - - datas[:, 0] = [np.array(observation) for observation in datas[:, 0]] - datas[:, 3] = [np.array(next_observation) for next_observation in datas[:, 3]] - - idx_max = self.index + len_datas - - if idx_max > self.max_size: - idx_max = self.max_size - idx_max - self.buffer[self.index:self.max_size] = datas[:idx_max] - self.buffer[:idx_max] = datas[idx_max:] - else: - self.buffer[self.index:idx_max] = datas - - self.size = min(self.size + len_datas, self.max_size) - self.index = idx_max + for o, a, r, n, d in zip(observations, actions, rewards, next_observations, dones): + self.append(o, a, r, n, d) def sample(self, batch_size, device): """ diff --git a/blobrl/networks/__init__.py b/blobrl/networks/__init__.py index 3fec55c..6ce6207 100644 --- a/blobrl/networks/__init__.py +++ b/blobrl/networks/__init__.py @@ -3,3 +3,5 @@ from .base_dueling_network import BaseDuelingNetwork from .simple_dueling_network import SimpleDuelingNetwork from .c51_network import C51Network + +from .utils import get_last_layers diff --git a/blobrl/networks/base_dueling_network.py b/blobrl/networks/base_dueling_network.py index eaad7df..7459069 100644 --- a/blobrl/networks/base_dueling_network.py +++ b/blobrl/networks/base_dueling_network.py @@ -5,21 +5,31 @@ class BaseDuelingNetwork(BaseNetwork): @abc.abstractmethod - def __init__(self, observation_shape, action_shape): + def __init__(self, network): """ - :param observation_shape: - :param action_shape: + :param network: network when we add Value head + :type network: BaseNetwork and not BaseDuelingNetwork """ - 
super().__init__(observation_shape=observation_shape, action_shape=action_shape) - self.features = None - self.advantage = None - self.value = None + if not isinstance(network, BaseNetwork): + raise TypeError("network need to be instance of BaseNetwork, not :" + str(type(network))) + if isinstance(network, BaseDuelingNetwork): + raise TypeError("network can't be instance of BaseDuelingNetwork :" + str(type(network))) + + super().__init__(observation_space=network.observation_space, action_space=network.action_space) + self.network = network + self.value_outputs = None def forward(self, observation): x = observation.view(observation.shape[0], -1) - x = self.features(x) - advantage = self.advantage(x) - value = self.value(x) - return value + advantage - advantage.mean() + x = self.network.network(x) + + def map_forward(layers, last_tensor, value_outputs): + if isinstance(layers, list): + return [map_forward(layers, last_tensor, value_outputs) for layers in layers] + advantage = layers(last_tensor) + value = value_outputs(last_tensor) + return value + advantage - advantage.mean() + + return map_forward(self.network.outputs, x, self.value_outputs, ) diff --git a/blobrl/networks/base_network.py b/blobrl/networks/base_network.py index 66f13e9..751e4b6 100644 --- a/blobrl/networks/base_network.py +++ b/blobrl/networks/base_network.py @@ -1,19 +1,36 @@ import abc - +from gym.spaces import Space import torch.nn as nn class BaseNetwork(nn.Module, metaclass=abc.ABCMeta): @abc.abstractmethod - def __init__(self, observation_shape, action_shape): + def __init__(self, observation_space, action_space): """ - :param observation_shape: - :param action_shape: + :param observation_space: + :param action_space: """ super().__init__() - self.observation_space = observation_shape - self.action_space = action_shape + + if not isinstance(observation_space, Space): + raise TypeError("observation_space need to be Space not " + str(type(observation_space))) + if not isinstance(action_space, 
Space): + raise TypeError("action_space need to be Space not " + str(type(action_space))) + + self.observation_space = observation_space + self.action_space = action_space + self.outputs = None + self.network = None + + @abc.abstractmethod + def forward(self, observation): + """ + + :param observation: + :return: + """ + pass @abc.abstractmethod def __str__(self): diff --git a/blobrl/networks/c51_network.py b/blobrl/networks/c51_network.py index eb1856e..a82b4ec 100644 --- a/blobrl/networks/c51_network.py +++ b/blobrl/networks/c51_network.py @@ -1,42 +1,56 @@ import numpy as np import torch import torch.nn as nn +from gym.spaces import flatdim, Discrete, MultiDiscrete from blobrl.networks import BaseNetwork class C51Network(BaseNetwork): - def __init__(self, observation_shape, action_shape): + def __init__(self, observation_space, action_space): """ - :param observation_shape: - :param action_shape: + :param observation_space: + :param action_space: """ - super().__init__(observation_shape=observation_shape, action_shape=action_shape) + if not isinstance(action_space, (Discrete, MultiDiscrete)): + raise TypeError( + "action_space need to be instance of Discrete or MultiDiscrete, not :" + str(type(action_space))) - if not isinstance(observation_shape, (tuple, int)): - raise TypeError("observation_space need to be Space not " + str(type(observation_shape))) - if not isinstance(action_shape, (tuple, int)): - raise TypeError("action_space need to be Space not " + str(type(action_shape))) + super().__init__(observation_space=observation_space, action_space=action_space) self.NUM_ATOMS = 51 self.network = nn.Sequential() - self.network.add_module("C51_Linear_Input", nn.Linear(np.prod(self.observation_space), 64)) + self.network.add_module("C51_Linear_Input", nn.Linear(np.prod(flatdim(self.observation_space)), 64)) self.network.add_module("C51_LeakyReLU_Input", nn.LeakyReLU()) self.network.add_module("C51_Linear_1", nn.Linear(64, 64)) 
self.network.add_module("C51_LeakyReLU_1", nn.LeakyReLU()) self.distributional_list = [] - self.len_distributional = np.prod(self.action_space) + if isinstance(self.action_space, Discrete): + self.len_distributional = self.action_space.n - for i in range(self.len_distributional): - distributional = nn.Sequential() - distributional.add_module("C51_Distributional_" + str(i) + "_Linear", nn.Linear(64, self.NUM_ATOMS)) - distributional.add_module("C51_Distributional_" + str(i) + "_Softmax", nn.Softmax(dim=1)) + for i in range(self.len_distributional): + distributional = nn.Sequential() + distributional.add_module("C51_Distributional_" + str(i) + "_Linear", nn.Linear(64, self.NUM_ATOMS)) + distributional.add_module("C51_Distributional_" + str(i) + "_Softmax", nn.Softmax(dim=1)) - self.add_module("C51_Distributional_" + str(i) + "_Sequential", distributional) - self.distributional_list.append(distributional) + self.add_module("C51_Distributional_" + str(i) + "_Sequential", distributional) + self.distributional_list.append(distributional) + + elif isinstance(self.action_space, MultiDiscrete): + def gen_outputs(nvec): + dis = [] + for nspace in nvec: + if isinstance(nspace, (list, np.ndarray)): + dis.append(gen_outputs(nspace)) + else: + dis.append( + [nn.Sequential(nn.Linear(64, self.NUM_ATOMS), nn.Softmax(dim=1)) for i in range(nspace)]) + return dis + + self.distributional_list = gen_outputs(self.action_space.nvec) def forward(self, observation): """ @@ -46,13 +60,27 @@ def forward(self, observation): """ x = observation.view(observation.shape[0], -1) x = self.network(x) + if isinstance(self.action_space, Discrete): + q = [distributionalLayer(x) for distributionalLayer in self.distributional_list] + q = torch.cat(q) + q = torch.reshape(q, (self.action_space.n, -1, self.NUM_ATOMS)) + q = q.permute(1, 0, 2) + + return q + if isinstance(self.action_space, MultiDiscrete): + + def do_forward(nvec, llayers, x): + if isinstance(llayers[-1], list): + return [do_forward(n, l, x) 
for n, l in zip(nvec, llayers)] + + q = [distributionalLayer(x) for distributionalLayer in llayers] + q = torch.cat(q) + q = torch.reshape(q, (nvec, -1, self.NUM_ATOMS)) + q = q.permute(1, 0, 2) - q = [distributionalLayer(x) for distributionalLayer in self.distributional_list] - q = torch.cat(q) - q = torch.reshape(q, (self.len_distributional, -1, self.NUM_ATOMS)) - q = q.permute(1, 0, 2) + return q - return q + return do_forward(self.action_space.nvec, self.distributional_list, x) def __str__(self): return 'C51Network-' + str(self.observation_space) + "-" + str(self.action_space) diff --git a/blobrl/networks/simple_dueling_network.py b/blobrl/networks/simple_dueling_network.py index 769592c..9ccc201 100644 --- a/blobrl/networks/simple_dueling_network.py +++ b/blobrl/networks/simple_dueling_network.py @@ -1,36 +1,22 @@ -import numpy as np from torch import nn - from blobrl.networks import BaseDuelingNetwork class SimpleDuelingNetwork(BaseDuelingNetwork): - def __init__(self, observation_shape, action_shape): + def __init__(self, network): """ - :param observation_shape: - :param action_shape: + :param observation_space: + :param action_space: """ - super().__init__(observation_shape=observation_shape, action_shape=action_shape) - - self.features = nn.Sequential() - self.features.add_module("NetWorkSimple_Linear_Input", nn.Linear(np.prod(self.observation_space), 64)) - self.features.add_module("NetWorkSimple_LeakyReLU_Input", nn.LeakyReLU()) - self.features.add_module("NetWorkSimple_Linear_1", nn.Linear(64, 64)) - self.features.add_module("NetWorkSimple_LeakyReLU_1", nn.LeakyReLU()) - self.features.add_module("NetWorkSimple_Linear_Output", nn.Linear(64, 64)) - self.advantage = nn.Sequential( - nn.Linear(64, 64), - nn.LeakyReLU(), - nn.Linear(64, np.prod(self.action_space)) - ) + super().__init__(network=network) - self.value = nn.Sequential( + self.value_outputs = nn.Sequential( nn.Linear(64, 64), nn.LeakyReLU(), nn.Linear(64, 1) ) def __str__(self): - return 
'SimpleDuelingNetwork-' + str(self.observation_space) + "-" + str(self.action_space) + return 'SimpleDuelingNetwork-' + str(self.network) diff --git a/blobrl/networks/simple_network.py b/blobrl/networks/simple_network.py index c1c482a..f6155de 100644 --- a/blobrl/networks/simple_network.py +++ b/blobrl/networks/simple_network.py @@ -1,29 +1,26 @@ import numpy as np import torch.nn as nn - +from gym.spaces import flatdim +from .utils import get_last_layers from blobrl.networks import BaseNetwork class SimpleNetwork(BaseNetwork): - def __init__(self, observation_shape, action_shape): + def __init__(self, observation_space, action_space): """ - :param observation_shape: - :param action_shape: + :param observation_space: + :param action_space: """ - super().__init__(observation_shape=observation_shape, action_shape=action_shape) - - if not isinstance(observation_shape, (tuple, int)): - raise TypeError("observation_space need to be Space not " + str(type(observation_shape))) - if not isinstance(action_shape, (tuple, int)): - raise TypeError("action_space need to be Space not " + str(type(action_shape))) + super().__init__(observation_space=observation_space, action_space=action_space) self.network = nn.Sequential() - self.network.add_module("NetWorkSimple_Linear_Input", nn.Linear(np.prod(self.observation_space), 64)) + self.network.add_module("NetWorkSimple_Linear_Input", nn.Linear(np.prod(flatdim(self.observation_space)), 64)) self.network.add_module("NetWorkSimple_LeakyReLU_Input", nn.LeakyReLU()) self.network.add_module("NetWorkSimple_Linear_1", nn.Linear(64, 64)) self.network.add_module("NetWorkSimple_LeakyReLU_1", nn.LeakyReLU()) - self.network.add_module("NetWorkSimple_Linear_Output", nn.Linear(64, np.prod(self.action_space))) + + self.outputs = get_last_layers(self.action_space, last_dim=64) def forward(self, observation): """ @@ -31,8 +28,16 @@ def forward(self, observation): :param observation: :return: """ + x = observation.view(observation.shape[0], -1) - 
return self.network(x) + x = self.network(x) + + def forwards(last_tensor, layers): + if isinstance(layers, list): + return [forwards(last_tensor, layers) for layers in layers] + return layers(last_tensor) + + return forwards(x, self.outputs) def __str__(self): return 'SimpleNetwork-' + str(self.observation_space) + "-" + str(self.action_space) diff --git a/blobrl/networks/utils.py b/blobrl/networks/utils.py new file mode 100644 index 0000000..c939e89 --- /dev/null +++ b/blobrl/networks/utils.py @@ -0,0 +1,38 @@ +from gym.spaces import flatdim +from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict +import torch.nn as nn +import numpy as np + + +def get_last_layers(space, last_dim): + if isinstance(space, Box): + def map_box(ld, n): + if isinstance(n, list): + return [map_box(ld, x) for x in n] + return nn.Linear(ld, 1) + + return map_box(last_dim, np.empty(space.shape).tolist()) + if isinstance(space, Discrete): + return nn.Sequential(*[nn.Linear(last_dim, flatdim(space))]) + if isinstance(space, Tuple): + return [get_last_layers(s, last_dim) for s in space] + if isinstance(space, Dict): + return [get_last_layers(s, last_dim) for s in space.spaces.values()] + if isinstance(space, MultiBinary): + def map_multibinary(ld, n): + if isinstance(n, list): + return [map_multibinary(ld, x) for x in n] + + return nn.Sequential(*[nn.Linear(ld, 1), nn.Sigmoid()]) + + return map_multibinary(last_dim, np.empty(space.n).tolist()) + if isinstance(space, MultiDiscrete): + def map_multidiscrete(ld, n): + if isinstance(n, list): + return [map_multidiscrete(ld, x) for x in n] + + return nn.Sequential(*[nn.Linear(ld, n)]) + + return map_multidiscrete(last_dim, space.nvec.tolist()) + + raise NotImplementedError diff --git a/blobrl/trainer.py b/blobrl/trainer.py index 4722c5b..b1caff0 100644 --- a/blobrl/trainer.py +++ b/blobrl/trainer.py @@ -7,7 +7,7 @@ from IPython import display from blobrl import Logger, Record -from blobrl.agents import AgentInterface, 
AgentRandom, DQN, DoubleDQN, CategoricalDQN, DuelingDQN +from blobrl.agents import AgentInterface, AgentRandom, DQN, DoubleDQN, CategoricalDQN class Trainer: @@ -25,7 +25,7 @@ def __init__(self, environment, agent, log_dir="./runs"): self.agent = agent(observation_space=observation_space, action_space=action_space) elif isinstance(agent, AgentInterface): import warnings - warnings.warn("be sure of agent have good input and output dimension") + warnings.warn("be sure of your agent need to have good input and output dimension") self.agent = agent else: raise TypeError("this type (" + str(type(agent)) + ") is an AgentInterface or instance of AgentInterface") @@ -52,14 +52,10 @@ def do_step(self, observation, learn=True, logger=None, render=True): :param observation: - :param env: - :param agent: :param learn: :param logger: :param render: if show env render :type render: bool - :param on_notebook: if render is on notebook - :type on_notebook: bool :return: """ if render: @@ -75,13 +71,9 @@ def do_step(self, observation, learn=True, logger=None, render=True): def do_episode(self, logger=None, render=True): """ - :param env: - :param agent: :param logger: :param render: if show env render :type render: bool - :param on_notebook: if render is on notebook - :type on_notebook: bool """ self.agent.enable_exploration() observation = self.environment.reset() @@ -96,13 +88,9 @@ def do_episode(self, logger=None, render=True): def evaluate(self, logger=None, render=True): """ - :param env: - :param agent: :param logger: :param render: if show env render :type render: bool - :param on_notebook: if render is on notebook - :type on_notebook: bool """ self.agent.disable_exploration() observation = self.environment.reset() @@ -120,16 +108,19 @@ def train(self, max_episode=1000, nb_evaluation=4, render=True): :param max_episode: :param render: if show env render :type render: bool - :param on_notebook: if render is on notebook - :type on_notebook: bool """ self.environment.reset() for 
i_episode in range(1, max_episode + 1): self.do_episode(logger=self.logger, render=render) - if nb_evaluation != 0: - if i_episode == 1 or i_episode == max_episode or i_episode % (max_episode // (nb_evaluation - 1)) == 0: - self.evaluate(logger=self.logger, render=True) + if nb_evaluation > 0: + if nb_evaluation <= 1: + if i_episode == max_episode: + self.evaluate(logger=self.logger, render=render) + + elif i_episode == 1 or i_episode == max_episode or i_episode % int( + max_episode // (nb_evaluation - 1)) == 0: + self.evaluate(logger=self.logger, render=render) self.close() def render(self): @@ -182,8 +173,6 @@ def arg_to_agent(arg_agent) -> AgentInterface: return DoubleDQN if arg_agent == "categorical_dqn": return CategoricalDQN - if arg_agent == "dueling_dqn": - return DuelingDQN raise ValueError("this agent (" + str(arg_agent) + ") is not implemented") @@ -194,9 +183,6 @@ def arg_to_agent(arg_agent) -> AgentInterface: parser.add_argument('--max_episode', type=int, help='number of episode to train', nargs='?', const=1, default=100) parser.add_argument('--render', type=bool, help='if show render on each step or not', nargs='?', const=1, default=False) - # parser.add_argument('--train', type=bool, help='if train agent or not', nargs='?', const=1, - # default=True) - # parser.add_argument('--file_path', type=str, help='path to file for load trained agent') args = parser.parse_args() trainer = Trainer(environment=args.env, agent=arg_to_agent(args.agent)) diff --git a/examples/example_train_jupyter.ipynb b/examples/example_train_jupyter.ipynb new file mode 100644 index 0000000..b460bdf --- /dev/null +++ b/examples/example_train_jupyter.ipynb @@ -0,0 +1,723 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZTuJfCB_CYzh" + }, + "source": [ + "## OpenAI Gym Available Environment\n", + "\n", + "Gym comes with a diverse suite of environments that range from easy to difficult and involve many different kinds of data. 
View the [full list of environments](https://gym.openai.com/envs) to get the birds-eye view.\n", + "\n", + "- [Classic control](https://gym.openai.com/envs#classic_control) and [toy text](https://gym.openai.com/envs#toy_text): complete small-scale tasks, mostly from the RL literature. They’re here to get you started.\n", + "\n", + "- [Algorithmic](https://gym.openai.com/envs#algorithmic): perform computations such as adding multi-digit numbers and reversing sequences. One might object that these tasks are easy for a computer. The challenge is to learn these algorithms purely from examples. These tasks have the nice property that it’s easy to vary the difficulty by varying the sequence length.\n", + "\n", + "- [Atari](https://gym.openai.com/envs#atari): play classic Atari games. \n", + "\n", + "- [2D and 3D robots](https://gym.openai.com/envs#mujoco): control a robot in simulation. These tasks use the MuJoCo physics engine, which was designed for fast and accurate robot simulation. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Don't forget to set matplotlib to inline" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NEnwOnr6qHbO" + }, + "source": [ + "# CartPole-v1 example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fj9G5c5kqM_y" + }, + "source": [ + "## Initialize environment" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "colab_type": "code", + "id": "5Q7p1iRwdGpE", + "outputId": "fa6bfbf8-a868-402e-a1d7-344ba712b002" + }, + "outputs": [], + "source": [ + "import gym\n", + "env = gym.make('CartPole-v1')\n", + "_ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "observation_space= Box(4,) action_space= Discrete(2)\n" + ] + } + ], + "source": [ + "print(\"observation_space=\",env.observation_space, \"action_space=\",env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mxUtaISFqQ4Q" + }, + "source": [ + "## Initialize agent" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "srJt2AZPp58A" + }, + "outputs": [], + "source": [ + "from blobrl.agents import DQN\n", + "agent = DQN(observation_space=env.observation_space, action_space=env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EX_NctEhqfgu" + }, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fldGbNR2qlDo" + }, + "source": [ + "Create Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + }, + "colab_type": "code", + "id": "CY1LF52LqeyH", + "outputId": "96567768-4a32-4e02-8fc8-c3f7e17897ab" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\Users\\nathan\\Anaconda3\\envs\\RL\\lib\\site-packages\\blobrl\\trainer.py:28: UserWarning: be sure of your agent need to have good input and output dimension\n", + " warnings.warn(\"be sure of your agent need to have good input and output dimension\")\n" + ] + } + ], + "source": [ + "from blobrl import Trainer\n", + "trainer = Trainer(environment=env, agent=agent, log_dir=\"./logs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "cGq2ksiRqkRR" + }, + "source": [ + "Start train" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0BACJeOjqkXO" 
+ }, + "outputs": [], + "source": [ + "trainer.train(max_episode=200, nb_evaluation=0, render=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAGCCAYAAADkJxkCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAM6klEQVR4nO3dW4ycZ2HG8febmfV6ba8dx4HESTkECOEYmqZN6UGtVJSgBgUhqHJRoV6nvYAbkJAAoXDXIC56USH1IFWtqlZVLqBVoWmgVEWQNpRwTCDBuElx3AUf1mxsr3d2Z4YrXIzZHce7+72zz/x+0kr2zmvvI0se/ffbnf2a0WhUAACSdWoPAADYboIHAIgneACAeIIHAIgneACAeIIHAIjXG/O416wDADtFs94DrvAAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAAPEEDwAQT/AAW2q41i/9s4tl0L9QewrARb3aA4CdazQaleXTz13yvuef+3Y59p8PlRt/5R3l8O2/W2kZwKUED/CCnV04UoaD1VJGo/LdT/9J7TkAYwke4AU7+tk/L6vnz9SeAXDFfA8PsC3OLnz3si93AdQieIAX7PAvva00nY0vEC99/4ly7sQz7QwCGEPwAC/Yi173W6XpdmvPALhiggcAiCd4AIB4ggfYNscefagsPfed2jMABA9wdW5794OlNBs/hQz658tosNbSIoD1CR7gqnRndteeAHDFBA8AEE/wANtq0D9fRsNB7RnAlBM8wFXbfc0NY8/8z7/9ZTl/6lgLawDWJ3iAq/a63/tw7QkAV0TwAADxBA8AEE/wANvuxBOfL2sr52rPAKaY4AE2oSkv/c3fH3vq1
NOPlkF/uYU9AD+f4AGuWtM05dpb3lx7BsBYggcAiCd4gFY89Y8f82UtoBrBA2xKp7ervP6+B8aeWz13poxGoxYWAVxO8ACb0jRN6e6aqz0DYEOCBwCIJ3iA1gxWzvmyFlCF4AE2r+mUmb0Hxx771t9/yJ3TgSoED7BpM3Pz5ZZ73lt7BsC6BA8AEE/wAADxBA/QqoWvfto3LgOtEzzAlpjZc6Bcf9tdY8/93+P/3MIagEsJHmBL9Gb3lAMvu632DICfS/AAAPEED9C6Jx/6aO0JwJQRPMCW2Xf9K8vLfvsPxp5bWfphC2sA/p/gAbZM0+mW7szu2jMALiN4AIB4ggdo36iUQX+59gpgiggeYEt1erOlO7tnwzOj4Vr55t99sKVFAIIH2GIHXvqGcuMd99aeAXAJwQMAxBM8AEA8wQNUMRqslhNP/kftGcCUEDzAltt3+JYyf9NrNjwzXOuXha8/3NIiYNoJHmDL7Tn0krL3RS+vPQPgIsEDAMQTPEA1q+cWy9HP/UXtGcAUEDxANaPhoPSfP1l7BjAFBA+wLQ7fcW+59lW/WnsGQClF8ADbpNPtlabTrT0DoJQieIDKRqNhGa6t1p4BhBM8wLbpzs6VptPb8Mz5E8+Wo5/7s5YWAdNK8ADb5iW/dl+Zv+nW2jMABA8AkE/wAADxBA9Q3crSybJ07Nu1ZwDBBA+wra679TfKzN5rNjxzYfF4WTz63+0MAqaS4AG21cFX3FFm5vbXngFMOcEDAMQTPABAPMEDTITTRx4rP/jGI7VnAKEED7DtXn3v+8qufddueGa41i+D/nJLi4BpI3iAbdedmS2lNLVnAFNM8AATYzQcltFwUHsGEEjwAK3o7d479szC1z5TTj71xRbWANNG8ACteO07P1i6u+ZqzwCmlOABAOIJHgAgnuABJsrzx58uK0sna88AwggeoDU33H5PKc3GL09f/N6Xy/LpYy0tAqaF4AFac8Ob7i5N42kHaJ9nHgAgnuABAOIJHmDiPPuFvy3nT/5v7RlAEMEDtOq2dz849sza8lIZDlZbWANMC8EDtKo7O/4WEwBbTfAAE2k0HJTRaFR7BhBC8AAT6el/+nhZWfph7RlACMEDtG7XvkO1JwBTRvAArWqaprz+vgdqzwCmjOABAOIJHmBiLR59vAzXvDwd2DzBA7Su6XTKi9/4lrHnjn/5k2WweqGFRUA6wQO0rul0y+Hb76k9A5giggcAiCd4AIB4ggeYaE996o/dVwvYNMEDVNGd3VNe844PjD23snSiFHeYADZJ8ABVNE2n9Hbvqz0DmBKCBwCIJ3iAiTdYveDO6cCmCB6gnqZTervnxx77xt+8r4UxQDLBA1QzO3+ovPKtf1R7BjAFBA8AEE/wADvC6SOP1Z4A7GCCB6hqZs+Bcs3Lf3HsuWf+/a+2fQuQS/AAVc3OHyqHXv3rtWcA4QQPABBP8AAA8QQPsDOMhuU7n3qw9gpghxI8QHX7f+G15aY73zn23IXF4y2sARIJHqC6Tm9X6c7uqT0DCCZ4gE1bWFgovV5vU29/eP/9Yz/OmTNnNv1xfvbtkUceaeFfCKitV3sAkGEwGGzuzw+HV3Su1yllZXVzH+unuSkpTAdXeICJsLK6Vs6vrF78fX84W1YGc5e87Zo9UB564L6KK4GdyhUeYCJ85r+OlJuu21/uf/svl+XB3vLYqbeWpbXrLjnTa/rlls5fV1oI7GSu8AAT5eza/vL44u9cFjullLI22lUePfW2CquAnU7wABPle2ffVE73D6/7+MzM7nLzzXe2uAhIIHiAifHkMyfK0eOnNzyzd+/Bctdd72lpEZBC8AAT40tPfL987cgPas8AAgkeYEeZaS6UW+e/UnsGsMMIHmCi3LzvW+XgzPpXeXqd1XL97mdbXAQkEDzARNnfWyy3H/x8me+duuyxbtMvd177LxVWATudn8MDTJQ//eRj5cbr5subb10tw1G3lFLKuz7yD6W/OihNMyqf6J4tq
2tX9lOZAX5C8AAT5UfnVsr7P/Gvpdf97MX3Pb/cv/jrMxU2ATvfhsHjHjPAldjq54rl/tqW/n3jeK6DDE3TrP/YRv/R5+bmPAsAY41Go7KyslJ7xlWZmZkp3W639gxgCywvL69bPBsGTylF8ABjLSwslMOH1//pyJPs4YcfLnfffXftGcDWWDd4vEoLAIgneACAeIIHAIgneACAeIIHAIgneACAeIIHAIgneACAeIIHAIgneACAeIIHAIi34d3SAa5Ur7czn046HZ/3wTRw81AAIIWbhwIA00vwAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxBA8AEE/wAADxemMeb1pZAQCwjVzhAQDiCR4AIJ7gAQDiCR4AIJ7gAQDiCR4AIN6PAbD2vOY4stcpAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "trainer.evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MountainCar-v0 example" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "env = gym.make('MountainCar-v0')\n", + "_ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "observation_space= Box(2,) action_space= Discrete(3)\n" + ] + } + ], + "source": [ + "print(\"observation_space=\",env.observation_space, \"action_space=\",env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize agent" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl.agents import DQN\n", + "agent = DQN(observation_space=env.observation_space, action_space=env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl import Trainer\n", + "trainer = Trainer(environment=env, agent=agent, log_dir=\"./logs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start train" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train(max_episode=200, nb_evaluation=0, render=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "trainer.evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Copy-v0 example" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "env = gym.make('Copy-v0')\n", + "_ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "observation_space= Discrete(6) action_space= Tuple(Discrete(2), Discrete(2), Discrete(5))\n" + ] + } + ], + "source": [ + "print(\"observation_space=\",env.observation_space, \"action_space=\",env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl.agents import DQN\n", + "agent = DQN(observation_space=env.observation_space, action_space=env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl import Trainer\n", + "trainer = Trainer(environment=env, agent=agent, log_dir=\"./logs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train(max_episode=200, nb_evaluation=0, render=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [ + "trainer.evaluate()\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FrozenLake-v0 example" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "env = gym.make('FrozenLake-v0')\n", + "_ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "observation_space= Discrete(16) action_space= Discrete(4)\n" + ] + } + ], + "source": [ + "print(\"observation_space=\",env.observation_space, \"action_space=\",env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize agent" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl.agents import DQN\n", + "agent = DQN(observation_space=env.observation_space, action_space=env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\Users\\nathan\\Anaconda3\\envs\\RL\\lib\\site-packages\\blobrl\\trainer.py:28: UserWarning: be sure of your agent need to have good input and output dimension\n", + " warnings.warn(\"be sure of your agent need to have good input and output dimension\")\n" + ] + } + ], + "source": [ + "from blobrl import Trainer\n", + "trainer = Trainer(environment=env, agent=agent, log_dir=\"./logs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start train" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train(max_episode=1000, nb_evaluation=0, 
render=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " (Right)\n", + "SFFF\n", + "\u001b[41mF\u001b[0mHFH\n", + "FFFH\n", + "HFFG\n", + " (Right)\n", + "SFFF\n", + "\u001b[41mF\u001b[0mHFH\n", + "FFFH\n", + "HFFG\n" + ] + } + ], + "source": [ + "trainer.evaluate()\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Assault-v0 example" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "env = gym.make('Assault-v0')\n", + "_ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "observation_space= Box(250, 160, 3) action_space= Discrete(7)\n" + ] + } + ], + "source": [ + "print(\"observation_space=\",env.observation_space, \"action_space=\",env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize agent" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl.agents import DQN\n", + "agent = DQN(observation_space=env.observation_space, action_space=env.action_space)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "from blobrl import Trainer\n", + "trainer = Trainer(environment=env, agent=agent, log_dir=\"./logs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start train" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train(max_episode=200, 
nb_evaluation=0, render=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "trainer.evaluate()\n", + "env.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Copie de Gym_Envs_1_preamble_evn_list.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/results/Analyse_result.ipynb b/results/Analyse_result.ipynb index 53b6994..a34f3d6 100644 --- a/results/Analyse_result.ipynb +++ b/results/Analyse_result.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -17,7 +17,7 @@ "import pandas as pd\n", "from pandas.core.common import flatten\n", "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n" + "import matplotlib.pyplot as plt" ] }, { @@ -75,6 +75,15 @@ " df = pd.concat([df,ite], ignore_index=True)" ] }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "df.drop(columns=[\"max\",\"min\",\"avg\"], inplace=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -84,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -120,9 +129,6 @@ " memories\n", " max_size\n", " step\n", - " max\n", - " min\n", - " avg\n", " sum\n", " \n", " \n", @@ -131,191 +137,161 @@ " 0\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 
0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", + " 2048\n", " 1\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", " 9.0\n", " \n", " \n", " 1\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 166\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 10\n", " 10.0\n", " \n", " \n", " 2\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 332\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 9.0\n", + " 2048\n", + " 20\n", + " 10.0\n", " \n", " \n", " 3\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 498\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 11.0\n", + " 2048\n", + " 30\n", + " 9.0\n", " \n", " \n", " 4\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 500\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 40\n", " 9.0\n", " \n", " \n", " 5\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 1\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 
14.0\n", + " 2048\n", + " 50\n", + " 8.0\n", " \n", " \n", " 6\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 166\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 8.0\n", + " 2048\n", + " 60\n", + " 9.0\n", " \n", " \n", " 7\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 332\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 9.0\n", + " 2048\n", + " 70\n", + " 10.0\n", " \n", " \n", " 8\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 498\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 80\n", " 10.0\n", " \n", " \n", " 9\n", " CategoricalDQN\n", " 1\n", - " 1\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 500\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 10.0\n", + " 2048\n", + " 90\n", + " 9.0\n", " \n", " \n", "\n", @@ -323,43 +299,43 @@ ], "text/plain": [ " algo step_train batch_size gamma \\\n", - "0 CategoricalDQN 1 1 0.99 \n", - "1 CategoricalDQN 1 1 0.99 \n", - "2 CategoricalDQN 1 1 0.99 \n", - "3 CategoricalDQN 1 1 0.99 \n", - "4 CategoricalDQN 1 1 0.99 \n", - "5 CategoricalDQN 1 1 0.99 \n", - "6 CategoricalDQN 1 1 0.99 \n", - "7 CategoricalDQN 1 1 0.99 \n", - "8 CategoricalDQN 1 1 0.99 \n", - "9 CategoricalDQN 1 1 0.99 \n", + "0 CategoricalDQN 1 32 0.95 \n", 
+ "1 CategoricalDQN 1 32 0.95 \n", + "2 CategoricalDQN 1 32 0.95 \n", + "3 CategoricalDQN 1 32 0.95 \n", + "4 CategoricalDQN 1 32 0.95 \n", + "5 CategoricalDQN 1 32 0.95 \n", + "6 CategoricalDQN 1 32 0.95 \n", + "7 CategoricalDQN 1 32 0.95 \n", + "8 CategoricalDQN 1 32 0.95 \n", + "9 CategoricalDQN 1 32 0.95 \n", "\n", " greedy_exploration network optimizer lr \\\n", - "0 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "1 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "2 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "3 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "4 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "5 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "6 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "7 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "8 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "9 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", + "0 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "1 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "2 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "3 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "4 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "5 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "6 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "7 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "8 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", + "9 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 \n", "\n", - " memories max_size step max min avg sum \n", - "0 ExperienceReplay 128 1 1.0 1.0 1.0 9.0 \n", - "1 ExperienceReplay 128 166 1.0 1.0 1.0 10.0 \n", - "2 
ExperienceReplay 128 332 1.0 1.0 1.0 9.0 \n", - "3 ExperienceReplay 128 498 1.0 1.0 1.0 11.0 \n", - "4 ExperienceReplay 128 500 1.0 1.0 1.0 9.0 \n", - "5 ExperienceReplay 16 1 1.0 1.0 1.0 14.0 \n", - "6 ExperienceReplay 16 166 1.0 1.0 1.0 8.0 \n", - "7 ExperienceReplay 16 332 1.0 1.0 1.0 9.0 \n", - "8 ExperienceReplay 16 498 1.0 1.0 1.0 10.0 \n", - "9 ExperienceReplay 16 500 1.0 1.0 1.0 10.0 " + " memories max_size step sum \n", + "0 ExperienceReplay 2048 1 9.0 \n", + "1 ExperienceReplay 2048 10 10.0 \n", + "2 ExperienceReplay 2048 20 10.0 \n", + "3 ExperienceReplay 2048 30 9.0 \n", + "4 ExperienceReplay 2048 40 9.0 \n", + "5 ExperienceReplay 2048 50 8.0 \n", + "6 ExperienceReplay 2048 60 9.0 \n", + "7 ExperienceReplay 2048 70 10.0 \n", + "8 ExperienceReplay 2048 80 10.0 \n", + "9 ExperienceReplay 2048 90 9.0 " ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -377,11 +353,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "for c in [\"step_train\", \"batch_size\", \"gamma\", \"lr\", \"step\", \"max\", \"min\", \"avg\", \"sum\"]:\n", + "for c in [\"step_train\", \"batch_size\", \"gamma\", \"lr\", \"step\", \"sum\"]:\n", " df[c] = df[c].astype(float)\n", "for c in df.columns:\n", " if df[c].dtypes == \"object\":\n", @@ -390,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -426,604 +402,696 @@ " memories\n", " max_size\n", " step\n", - " max\n", - " min\n", - " avg\n", " sum\n", " \n", " \n", " \n", " \n", - " 3436\n", - " DQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 3711\n", - " DQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " 
EpsilonGreedy-0.1\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 16\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", + " 12340\n", + " DoubleDQN\n", " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 4176\n", - " DQN\n", " 32.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.1000\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 20.0\n", " 500.0\n", " \n", " \n", - " 2117\n", + " 12930\n", " DoubleDQN\n", " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 3442\n", - " DQN\n", - " 1.0\n", " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 30.0\n", " 500.0\n", " \n", " \n", - " 3537\n", + " 31127\n", " DQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 30.0\n", " 500.0\n", " \n", " \n", - " 3612\n", + " 31716\n", " DQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 30.0\n", " 
500.0\n", " \n", " \n", - " 3717\n", + " 34103\n", " DQN\n", " 1.0\n", " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 4797\n", - " DQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", + " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 30.0\n", " 500.0\n", " \n", " \n", - " 1953\n", + " 11412\n", " DoubleDQN\n", " 1.0\n", " 32.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", + " EpsilonGreedy-0.1\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 2118\n", + " 24122\n", " DoubleDQN\n", - " 1.0\n", + " 32.0\n", " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.100\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3268\n", + " 27997\n", " DQN\n", " 1.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3538\n", + " 29082\n", " DQN\n", " 1.0\n", " 32.0\n", " 0.99\n", " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 
2048\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3573\n", + " 30601\n", " DQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 3718\n", - " DQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", + " 1.00\n", " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 40.0\n", " 500.0\n", " \n", - " \n", - " 3753\n", - " DQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", + " \n", + "\n", + "" + ], + "text/plain": [ + " algo step_train batch_size gamma \\\n", + "12340 DoubleDQN 1.0 32.0 1.00 \n", + "12930 DoubleDQN 1.0 32.0 1.00 \n", + "31127 DQN 1.0 32.0 1.00 \n", + "31716 DQN 1.0 64.0 0.95 \n", + "34103 DQN 1.0 64.0 0.99 \n", + "11412 DoubleDQN 1.0 32.0 0.99 \n", + "24122 DoubleDQN 32.0 64.0 0.99 \n", + "27997 DQN 1.0 32.0 0.95 \n", + "29082 DQN 1.0 32.0 0.99 \n", + "30601 DQN 1.0 32.0 1.00 \n", + "\n", + " greedy_exploration network \\\n", + "12340 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "12930 EpsilonGreedy-0.1 SimpleNetwork \n", + "31127 EpsilonGreedy-0.6 SimpleNetwork \n", + "31716 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "34103 EpsilonGreedy-0.6 SimpleNetwork \n", + "11412 EpsilonGreedy-0.1 SimpleNetwork \n", + "24122 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork \n", + "27997 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "29082 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "30601 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "\n", + " optimizer lr memories max_size step sum 
\n", + "12340 Adam 0.001 ExperienceReplay 2048 20.0 500.0 \n", + "12930 Adam 0.001 ExperienceReplay 512 30.0 500.0 \n", + "31127 Adam 0.001 ExperienceReplay 2048 30.0 500.0 \n", + "31716 Adam 0.001 ExperienceReplay 512 30.0 500.0 \n", + "34103 Adam 0.001 ExperienceReplay 2048 30.0 500.0 \n", + "11412 Adam 0.001 ExperienceReplay 2048 40.0 500.0 \n", + "24122 Adam 0.100 ExperienceReplay 2048 40.0 500.0 \n", + "27997 Adam 0.001 ExperienceReplay 512 40.0 500.0 \n", + "29082 Adam 0.001 ExperienceReplay 2048 40.0 500.0 \n", + "30601 Adam 0.001 ExperienceReplay 512 40.0 500.0 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Correlation matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "df_corr = df.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "for c in df_corr.columns:\n", + " try:\n", + " df_corr[c] = df_corr[c].cat.codes\n", + " except:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "algo 0.061506\n", + "step_train -0.314323\n", + "batch_size 0.044669\n", + "gamma -0.005792\n", + "greedy_exploration 0.017274\n", + "network 0.076661\n", + " NaN\n", + "optimizer NaN\n", + "lr -0.166335\n", + "memories NaN\n", + "max_size -0.018768\n", + "step 0.161043\n", + "sum 1.000000\n", + "Name: sum, dtype: float64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_corr.corr()[\"sum\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, 
+ "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(10,10)) \n", + "sns.heatmap(df.corr()[abs(df.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Correlation matrix for best result" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "df_corr_best = df_corr[df_corr[\"sum\"] >= 300]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(10,10)) \n", + "sns.heatmap(df_corr_best.corr()[abs(df_corr_best.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Result by algo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### DQN" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### SimpleNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "df_DQN = df[df[\"algo\"] == \"DQN\"].copy()\n", + "df_DQN = df_DQN[df_DQN[\"network\"] == \"SimpleNetwork\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - 
" \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -1032,1517 +1100,714 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " 
\n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " 
\n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", 
+ " \n", + " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", - " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
algostep_trainbatch_sizegammagreedy_explorationnetworkoptimizerlrmemoriesmax_sizestepsum
452331127DQN4.01.032.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-01.00EpsilonGreedy-0.6SimpleNetworkAdam0.00100.001ExperienceReplay16498.01.01.01.0204830.0500.0
466834103DQN4.032.01.064.00.99EpsilonGreedy-0.6SimpleNetworkAdam0.10000.001ExperienceReplay128498.01.01.01.0204830.0500.0
470333050DQN4.01.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.100ExperienceReplay16498.01.01.01.0204840.0500.0
475332648DQN4.01.064.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-00.95EpsilonGreedy-0.6SimpleNetworkAdam0.00100.001ExperienceReplay32498.01.01.01.051250.0500.0
479334508DQN4.01.064.00.99EpsilonGreedy-0.11.00AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay16498.01.01.01.051250.0500.0
479828898DQN4.064.01.032.00.99EpsilonGreedy-0.1AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00100.001ExperienceReplay32498.01.01.01.0204860.0500.0
6173DuelingDQN4.032.033052DQN1.064.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-0SimpleDuelingNetworkAdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00010.100ExperienceReplay16498.01.01.01.0204860.0500.0
6448DuelingDQN4.034943DQN1.064.00.99EpsilonGreedy-0.6SimpleDuelingNetwork1.00AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00010.100ExperienceReplay32498.01.01.01.051260.0500.0
326435687DQN1.01.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-064.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.00100.100ExperienceReplay16500.01.01.01.051260.0500.0
339931906DQN1.01.00.99EpsilonGreedy-0.664.00.95AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00100.001ExperienceReplay16500.01.01.01.051270.0500.0
342932991DQN1.032.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00010.001ExperienceReplay16204870.0500.0
34479DQN1.01.01.064.01.00AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.001ExperienceReplay204870.0500.0
353928559DQN1.032.00.99EpsilonGreedy-0.1AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay32500.01.01.01.051280.0500.0
357428993DQN1.032.00.99EpsilonGreedy-0.6AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00100.100ExperienceReplay128500.01.01.01.051280.0500.0
362431163DQN1.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-032.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.00100.001ExperienceReplay16500.01.01.01.051280.0500.0
366931225DQN1.064.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-032.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.00100.100ExperienceReplay16500.01.01.01.051280.0500.0
371933798DQN1.064.0SimpleNetworkAdam0.00100.100ExperienceReplay32500.01.01.01.0204880.0500.0
375435255DQN1.064.00.99EpsilonGreedy-0.61.00EpsilonGreedy-0.1SimpleNetworkAdam0.00100.001ExperienceReplay128500.01.01.01.051280.0500.0
452427072DQN4.01.032.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-00.95AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay16500.01.01.01.051290.0500.0
461927785DQN4.01.032.00.990.95EpsilonGreedy-0.1SimpleNetworkAdam0.00100.001ExperienceReplay32500.01.01.01.0204890.0500.0
465928188DQN4.01.032.00.990.95EpsilonGreedy-0.6SimpleNetworkAdam0.00100.001ExperienceReplay16500.01.01.01.051290.0500.0
466931164DQN4.01.032.00.991.00EpsilonGreedy-0.6SimpleNetworkAdam0.10000.001ExperienceReplay12851290.0500.0
32714DQN1.01.01.064.00.95EpsilonGreedy-0.6SimpleNetworkAdam0.100ExperienceReplay51290.0500.0
470432993DQN4.01.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay16500.01.01.01.0204890.0500.0
479433055DQN4.01.064.00.99EpsilonGreedy-0.1AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.100ExperienceReplay16500.01.01.01.0204890.0500.0
479933768DQN4.01.064.00.99EpsilonGreedy-0.1SimpleNetworkAdam0.00100.001ExperienceReplay3251290.0500.0
28561DQN1.032.00.99AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.001ExperienceReplay512100.0500.0
30390DQN1.032.01.00AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.001ExperienceReplay2048100.0500.0
31134DQN1.032.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay2048100.0500.0
484433025DQN4.01.064.00.99EpsilonGreedy-0.6AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay32512100.0500.0
33397DQN1.01.01.064.00.99AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.001ExperienceReplay512100.0500.0
375833428DQN1.064.00.99EpsilonGreedy-0.6AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00100.100ExperienceReplay16498.01.01.01.0486.02048100.0500.0
456427074DQN4.01.032.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-00.95AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay128512110.0500.01.01.01.0485.0
348428562DQN1.032.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-0AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay128512110.0500.01.01.01.0483.0
361933026DQN1.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay128512110.0500.01.01.01.0467.0
361733367DQN1.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.00100.001ExperienceReplay128332.01.01.01.0465.02048110.0500.0
361834514DQN1.064.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-01.00AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay128498.01.01.0512110.0500.0
35599DQN1.0450.064.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay2048110.0500.0
5087DuelingDQN27075DQN1.032.00.99AdaptativeEpsilonGreedy-0.8-0.2-50000-0SimpleDuelingNetwork0.95AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00010.001ExperienceReplay128332.01.01.01.0421.0512120.0500.0
567CategoricalDQN32.028563DQN1.032.00.99AdaptativeEpsilonGreedy-0.3-0.1-50000-0C51NetworkAdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00100.001ExperienceReplay32332.01.01.01.0405.0512120.0500.0
355930020DQN1.032.00.99EpsilonGreedy-0.61.00AdaptativeEpsilonGreedy-0.3-0.1-30000-0SimpleNetworkAdam0.00010.001ExperienceReplay1282048120.0500.01.01.01.0398.0
\n", - "
" - ], - "text/plain": [ - " algo step_train batch_size gamma \\\n", - "3436 DQN 1.0 32.0 0.99 \n", - "3711 DQN 1.0 64.0 0.99 \n", - "4176 DQN 32.0 64.0 0.99 \n", - "2117 DoubleDQN 1.0 64.0 0.99 \n", - "3442 DQN 1.0 32.0 0.99 \n", - "3537 DQN 1.0 32.0 0.99 \n", - "3612 DQN 1.0 64.0 0.99 \n", - "3717 DQN 1.0 64.0 0.99 \n", - "4797 DQN 4.0 64.0 0.99 \n", - "1953 DoubleDQN 1.0 32.0 0.99 \n", - "2118 DoubleDQN 1.0 64.0 0.99 \n", - "3268 DQN 1.0 1.0 0.99 \n", - "3538 DQN 1.0 32.0 0.99 \n", - "3573 DQN 1.0 32.0 0.99 \n", - "3718 DQN 1.0 64.0 0.99 \n", - "3753 DQN 1.0 64.0 0.99 \n", - "4523 DQN 4.0 32.0 0.99 \n", - "4668 DQN 4.0 32.0 0.99 \n", - "4703 DQN 4.0 64.0 0.99 \n", - "4753 DQN 4.0 64.0 0.99 \n", - "4793 DQN 4.0 64.0 0.99 \n", - "4798 DQN 4.0 64.0 0.99 \n", - "6173 DuelingDQN 4.0 32.0 0.99 \n", - "6448 DuelingDQN 4.0 64.0 0.99 \n", - "3264 DQN 1.0 1.0 0.99 \n", - "3399 DQN 1.0 1.0 0.99 \n", - "3429 DQN 1.0 32.0 0.99 \n", - "3539 DQN 1.0 32.0 0.99 \n", - "3574 DQN 1.0 32.0 0.99 \n", - "3624 DQN 1.0 64.0 0.99 \n", - "3669 DQN 1.0 64.0 0.99 \n", - "3719 DQN 1.0 64.0 0.99 \n", - "3754 DQN 1.0 64.0 0.99 \n", - "4524 DQN 4.0 32.0 0.99 \n", - "4619 DQN 4.0 32.0 0.99 \n", - "4659 DQN 4.0 32.0 0.99 \n", - "4669 DQN 4.0 32.0 0.99 \n", - "4704 DQN 4.0 64.0 0.99 \n", - "4794 DQN 4.0 64.0 0.99 \n", - "4799 DQN 4.0 64.0 0.99 \n", - "4844 DQN 4.0 64.0 0.99 \n", - "3758 DQN 1.0 64.0 0.99 \n", - "4564 DQN 4.0 32.0 0.99 \n", - "3484 DQN 1.0 32.0 0.99 \n", - "3619 DQN 1.0 64.0 0.99 \n", - "3617 DQN 1.0 64.0 0.99 \n", - "3618 DQN 1.0 64.0 0.99 \n", - "5087 DuelingDQN 1.0 32.0 0.99 \n", - "567 CategoricalDQN 32.0 1.0 0.99 \n", - "3559 DQN 1.0 32.0 0.99 \n", - "\n", - " greedy_exploration network \\\n", - "3436 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3711 EpsilonGreedy-0.1 SimpleNetwork \n", - "4176 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "2117 EpsilonGreedy-0.6 SimpleNetwork \n", - "3442 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork 
\n", - "3537 EpsilonGreedy-0.1 SimpleNetwork \n", - "3612 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3717 EpsilonGreedy-0.1 SimpleNetwork \n", - "4797 EpsilonGreedy-0.1 SimpleNetwork \n", - "1953 EpsilonGreedy-0.6 SimpleNetwork \n", - "2118 EpsilonGreedy-0.6 SimpleNetwork \n", - "3268 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3538 EpsilonGreedy-0.1 SimpleNetwork \n", - "3573 EpsilonGreedy-0.6 SimpleNetwork \n", - "3718 EpsilonGreedy-0.1 SimpleNetwork \n", - "3753 EpsilonGreedy-0.6 SimpleNetwork \n", - "4523 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "4668 EpsilonGreedy-0.6 SimpleNetwork \n", - "4703 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "4753 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork \n", - "4793 EpsilonGreedy-0.1 SimpleNetwork \n", - "4798 EpsilonGreedy-0.1 SimpleNetwork \n", - "6173 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "6448 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "3264 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3399 EpsilonGreedy-0.6 SimpleNetwork \n", - "3429 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3539 EpsilonGreedy-0.1 SimpleNetwork \n", - "3574 EpsilonGreedy-0.6 SimpleNetwork \n", - "3624 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3669 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork \n", - "3719 EpsilonGreedy-0.1 SimpleNetwork \n", - "3754 EpsilonGreedy-0.6 SimpleNetwork \n", - "4524 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "4619 EpsilonGreedy-0.1 SimpleNetwork \n", - "4659 EpsilonGreedy-0.6 SimpleNetwork \n", - "4669 EpsilonGreedy-0.6 SimpleNetwork \n", - "4704 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "4794 EpsilonGreedy-0.1 SimpleNetwork \n", - "4799 EpsilonGreedy-0.1 SimpleNetwork \n", - "4844 EpsilonGreedy-0.6 SimpleNetwork \n", - "3758 EpsilonGreedy-0.6 SimpleNetwork \n", - "4564 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 
SimpleNetwork \n", - "3484 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork \n", - "3619 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3617 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "3618 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork \n", - "5087 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "567 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network \n", - "3559 EpsilonGreedy-0.6 SimpleNetwork \n", - "\n", - " optimizer lr memories max_size step max min avg sum \n", - "3436 Adam 0.0010 ExperienceReplay 128 166.0 1.0 1.0 1.0 500.0 \n", - "3711 Adam 0.0010 ExperienceReplay 16 166.0 1.0 1.0 1.0 500.0 \n", - "4176 Adam 0.1000 ExperienceReplay 16 166.0 1.0 1.0 1.0 500.0 \n", - "2117 Adam 0.0001 ExperienceReplay 128 332.0 1.0 1.0 1.0 500.0 \n", - "3442 Adam 0.0010 ExperienceReplay 16 332.0 1.0 1.0 1.0 500.0 \n", - "3537 Adam 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "3612 Adam 0.0001 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "3717 Adam 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "4797 Adam 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "1953 Adam 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "2118 Adam 0.0001 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "3268 Adam 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3538 Adam 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3573 Adam 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "3718 Adam 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3753 Adam 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "4523 Adam 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "4668 Adam 0.1000 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "4703 Adam 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "4753 Adam 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "4793 Adam 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - 
"4798 Adam 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "6173 Adam 0.0001 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "6448 Adam 0.0001 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3264 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "3399 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "3429 Adam 0.0001 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "3539 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 500.0 \n", - "3574 Adam 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 500.0 \n", - "3624 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "3669 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "3719 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 500.0 \n", - "3754 Adam 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 500.0 \n", - "4524 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "4619 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 500.0 \n", - "4659 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "4669 Adam 0.1000 ExperienceReplay 128 500.0 1.0 1.0 1.0 500.0 \n", - "4704 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "4794 Adam 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 \n", - "4799 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 500.0 \n", - "4844 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 500.0 \n", - "3758 Adam 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 486.0 \n", - "4564 Adam 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 485.0 \n", - "3484 Adam 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 483.0 \n", - "3619 Adam 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 467.0 \n", - "3617 Adam 0.0010 ExperienceReplay 128 332.0 1.0 1.0 1.0 465.0 \n", - "3618 Adam 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 450.0 \n", - "5087 Adam 0.0001 ExperienceReplay 128 332.0 1.0 1.0 1.0 421.0 \n", - "567 Adam 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 405.0 \n", - "3559 Adam 0.0001 ExperienceReplay 128 500.0 1.0 1.0 1.0 
398.0 " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(50)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Correlation matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "df_corr = df.copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "for c in df_corr.columns:\n", - " try:\n", - " df_corr[c] = df_corr[c].cat.codes\n", - " except:\n", - " pass" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "algo 7.373685e-02\n", - "step_train -1.697983e-01\n", - "batch_size 1.182673e-01\n", - "gamma -5.161482e-15\n", - "greedy_exploration 2.540062e-02\n", - "network 1.525560e-01\n", - " NaN\n", - "optimizer NaN\n", - "lr -1.308714e-01\n", - "memories NaN\n", - "max_size -1.339293e-02\n", - "step 1.712250e-01\n", - "max NaN\n", - "min NaN\n", - "avg NaN\n", - "sum 1.000000e+00\n", - "Name: sum, dtype: float64" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_corr.corr()[\"sum\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
algostep_trainbatch_sizegammagreedy_explorationnetworkoptimizerlrmemoriesmax_sizestepmaxminavgsum
algo1.000000e+00-8.992421e-180.000000e+000.000000e+000.000000e+004.045199e-01NaNNaN0.000000e+00NaN0.000000e+008.101552e-20NaNNaNNaN7.373685e-02
step_train-8.992421e-181.000000e+001.563538e-18-3.133980e-178.992421e-181.524091e-15NaNNaN2.026146e-17NaN-4.809914e-201.297656e-20NaNNaNNaN-1.697983e-01
batch_size0.000000e+001.563538e-181.000000e+000.000000e+000.000000e+001.533537e-15NaNNaN-1.642472e-16NaN-5.221288e-200.000000e+00NaNNaNNaN1.182673e-01
gamma0.000000e+00-3.133980e-170.000000e+001.000000e+000.000000e+000.000000e+00NaNNaN-1.533798e-14NaN0.000000e+009.148140e-17NaNNaNNaN-5.161482e-15
greedy_exploration0.000000e+008.992421e-180.000000e+000.000000e+001.000000e+000.000000e+00NaNNaN0.000000e+00NaN0.000000e+008.101552e-20NaNNaNNaN2.540062e-02
network4.045199e-011.524091e-151.533537e-150.000000e+000.000000e+001.000000e+00NaNNaN-1.978668e-16NaN0.000000e+002.660026e-17NaNNaNNaN1.525560e-01
NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
optimizerNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
lr0.000000e+002.026146e-17-1.642472e-16-1.533798e-140.000000e+00-1.978668e-16NaNNaN1.000000e+00NaN1.118946e-193.477634e-17NaNNaNNaN-1.308714e-01
memoriesNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
max_size0.000000e+00-4.809914e-20-5.221288e-200.000000e+000.000000e+000.000000e+00NaNNaN1.118946e-19NaN1.000000e+000.000000e+00NaNNaNNaN-1.339293e-02
step8.101552e-201.297656e-200.000000e+009.148140e-178.101552e-202.660026e-17NaNNaN3.477634e-17NaN0.000000e+001.000000e+00NaNNaNNaN1.712250e-01
maxNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
minNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
avgNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
sum7.373685e-02-1.697983e-011.182673e-01-5.161482e-152.540062e-021.525560e-01NaNNaN-1.308714e-01NaN-1.339293e-021.712250e-01NaNNaNNaN1.000000e+00
\n", - "
" - ], - "text/plain": [ - " algo step_train batch_size gamma \\\n", - "algo 1.000000e+00 -8.992421e-18 0.000000e+00 0.000000e+00 \n", - "step_train -8.992421e-18 1.000000e+00 1.563538e-18 -3.133980e-17 \n", - "batch_size 0.000000e+00 1.563538e-18 1.000000e+00 0.000000e+00 \n", - "gamma 0.000000e+00 -3.133980e-17 0.000000e+00 1.000000e+00 \n", - "greedy_exploration 0.000000e+00 8.992421e-18 0.000000e+00 0.000000e+00 \n", - "network 4.045199e-01 1.524091e-15 1.533537e-15 0.000000e+00 \n", - " NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN \n", - "lr 0.000000e+00 2.026146e-17 -1.642472e-16 -1.533798e-14 \n", - "memories NaN NaN NaN NaN \n", - "max_size 0.000000e+00 -4.809914e-20 -5.221288e-20 0.000000e+00 \n", - "step 8.101552e-20 1.297656e-20 0.000000e+00 9.148140e-17 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum 7.373685e-02 -1.697983e-01 1.182673e-01 -5.161482e-15 \n", - "\n", - " greedy_exploration network optimizer \\\n", - "algo 0.000000e+00 4.045199e-01 NaN NaN \n", - "step_train 8.992421e-18 1.524091e-15 NaN NaN \n", - "batch_size 0.000000e+00 1.533537e-15 NaN NaN \n", - "gamma 0.000000e+00 0.000000e+00 NaN NaN \n", - "greedy_exploration 1.000000e+00 0.000000e+00 NaN NaN \n", - "network 0.000000e+00 1.000000e+00 NaN NaN \n", - " NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN \n", - "lr 0.000000e+00 -1.978668e-16 NaN NaN \n", - "memories NaN NaN NaN NaN \n", - "max_size 0.000000e+00 0.000000e+00 NaN NaN \n", - "step 8.101552e-20 2.660026e-17 NaN NaN \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum 2.540062e-02 1.525560e-01 NaN NaN \n", - "\n", - " lr memories max_size step max \\\n", - "algo 0.000000e+00 NaN 0.000000e+00 8.101552e-20 NaN \n", - "step_train 2.026146e-17 NaN -4.809914e-20 1.297656e-20 NaN \n", - "batch_size -1.642472e-16 NaN -5.221288e-20 0.000000e+00 NaN \n", - "gamma -1.533798e-14 NaN 0.000000e+00 9.148140e-17 NaN \n", - "greedy_exploration 
0.000000e+00 NaN 0.000000e+00 8.101552e-20 NaN \n", - "network -1.978668e-16 NaN 0.000000e+00 2.660026e-17 NaN \n", - " NaN NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN NaN \n", - "lr 1.000000e+00 NaN 1.118946e-19 3.477634e-17 NaN \n", - "memories NaN NaN NaN NaN NaN \n", - "max_size 1.118946e-19 NaN 1.000000e+00 0.000000e+00 NaN \n", - "step 3.477634e-17 NaN 0.000000e+00 1.000000e+00 NaN \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum -1.308714e-01 NaN -1.339293e-02 1.712250e-01 NaN \n", - "\n", - " min avg sum \n", - "algo NaN NaN 7.373685e-02 \n", - "step_train NaN NaN -1.697983e-01 \n", - "batch_size NaN NaN 1.182673e-01 \n", - "gamma NaN NaN -5.161482e-15 \n", - "greedy_exploration NaN NaN 2.540062e-02 \n", - "network NaN NaN 1.525560e-01 \n", - " NaN NaN NaN \n", - "optimizer NaN NaN NaN \n", - "lr NaN NaN -1.308714e-01 \n", - "memories NaN NaN NaN \n", - "max_size NaN NaN -1.339293e-02 \n", - "step NaN NaN 1.712250e-01 \n", - "max NaN NaN NaN \n", - "min NaN NaN NaN \n", - "avg NaN NaN NaN \n", - "sum NaN NaN 1.000000e+00 " - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_corr.corr()" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 66, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df.corr()[df.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Correlation matrix for best result" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "df_corr_best = df_corr[df_corr[\"sum\"] >= 300]" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", 
" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
algostep_trainbatch_sizegammagreedy_explorationnetworkoptimizerlrmemoriesmax_sizestepmaxminavgsum
algo1.000000e+00-2.024974e-01-4.931761e-023.482840e-161.613751e-01-6.653163e-01NaNNaN-1.387198e-01NaN3.924623e-022.674115e-02NaNNaNNaN-3.890492e-01
step_train-2.024974e-011.000000e+004.067992e-024.391100e-17-2.457767e-01-2.033679e-01NaNNaN5.822535e-01NaN2.226417e-01-2.898155e-01NaNNaNNaN-2.090821e-02
batch_size-4.931761e-024.067992e-021.000000e+002.820643e-164.627404e-022.620456e-01NaNNaN3.432420e-02NaN4.325075e-027.685944e-02NaNNaNNaN1.187327e-01
gamma3.482840e-164.391100e-172.820643e-161.000000e+006.223740e-17-3.238646e-16NaNNaN1.441029e-16NaN1.306218e-16-1.823356e-16NaNNaNNaN-3.410748e-16
greedy_exploration1.613751e-01-2.457767e-014.627404e-026.223740e-171.000000e+00-7.484315e-03NaNNaN-2.102527e-02NaN-1.114625e-011.432139e-01NaNNaNNaN-2.831818e-02
network-6.653163e-01-2.033679e-012.620456e-01-3.238646e-16-7.484315e-031.000000e+00NaNNaN1.019551e-01NaN-1.256836e-019.331015e-02NaNNaNNaN2.632569e-01
NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
optimizerNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN30423DQN1.032.01.00AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.001ExperienceReplay512120.0500.0
lr-1.387198e-015.822535e-013.432420e-021.441029e-16-2.102527e-021.019551e-01NaNNaN1.000000e+00NaN-4.741960e-02-1.683165e-01NaNNaNNaN7.717727e-0231136DQN1.032.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay2048120.0500.0
memoriesNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN31167DQN1.032.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay512120.0500.0
max_size3.924623e-022.226417e-014.325075e-021.306218e-16-1.114625e-01-1.256836e-01NaNNaN-4.741960e-02NaN1.000000e+00-8.632279e-02NaNNaNNaN5.969496e-0231911DQN1.064.00.95AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.001ExperienceReplay512120.0500.0
step2.674115e-02-2.898155e-017.685944e-02-1.823356e-161.432139e-019.331015e-02NaNNaN-1.683165e-01NaN-8.632279e-021.000000e+00NaNNaNNaN5.831934e-0234112DQN1.064.00.99EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay2048120.0500.0
maxNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN34856DQN1.064.01.00AdaptativeEpsilonGreedy-0.8-0.2-10000-0SimpleNetworkAdam0.001ExperienceReplay2048120.0500.0
minNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN35259DQN1.064.01.00EpsilonGreedy-0.1SimpleNetworkAdam0.001ExperienceReplay512120.0500.0
avgNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN35693DQN1.064.01.00EpsilonGreedy-0.6SimpleNetworkAdam0.100ExperienceReplay512120.0500.0
sum-3.890492e-01-2.090821e-021.187327e-01-3.410748e-16-2.831818e-022.632569e-01NaNNaN7.717727e-02NaN5.969496e-025.831934e-02NaNNaNNaN1.000000e+0028192DQN1.032.00.95EpsilonGreedy-0.6SimpleNetworkAdam0.001ExperienceReplay512130.0500.0
\n", "
" ], "text/plain": [ - " algo step_train batch_size gamma \\\n", - "algo 1.000000e+00 -2.024974e-01 -4.931761e-02 3.482840e-16 \n", - "step_train -2.024974e-01 1.000000e+00 4.067992e-02 4.391100e-17 \n", - "batch_size -4.931761e-02 4.067992e-02 1.000000e+00 2.820643e-16 \n", - "gamma 3.482840e-16 4.391100e-17 2.820643e-16 1.000000e+00 \n", - "greedy_exploration 1.613751e-01 -2.457767e-01 4.627404e-02 6.223740e-17 \n", - "network -6.653163e-01 -2.033679e-01 2.620456e-01 -3.238646e-16 \n", - " NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN \n", - "lr -1.387198e-01 5.822535e-01 3.432420e-02 1.441029e-16 \n", - "memories NaN NaN NaN NaN \n", - "max_size 3.924623e-02 2.226417e-01 4.325075e-02 1.306218e-16 \n", - "step 2.674115e-02 -2.898155e-01 7.685944e-02 -1.823356e-16 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum -3.890492e-01 -2.090821e-02 1.187327e-01 -3.410748e-16 \n", - "\n", - " greedy_exploration network optimizer \\\n", - "algo 1.613751e-01 -6.653163e-01 NaN NaN \n", - "step_train -2.457767e-01 -2.033679e-01 NaN NaN \n", - "batch_size 4.627404e-02 2.620456e-01 NaN NaN \n", - "gamma 6.223740e-17 -3.238646e-16 NaN NaN \n", - "greedy_exploration 1.000000e+00 -7.484315e-03 NaN NaN \n", - "network -7.484315e-03 1.000000e+00 NaN NaN \n", - " NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN \n", - "lr -2.102527e-02 1.019551e-01 NaN NaN \n", - "memories NaN NaN NaN NaN \n", - "max_size -1.114625e-01 -1.256836e-01 NaN NaN \n", - "step 1.432139e-01 9.331015e-02 NaN NaN \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum -2.831818e-02 2.632569e-01 NaN NaN \n", + " algo step_train batch_size gamma \\\n", + "31127 DQN 1.0 32.0 1.00 \n", + "34103 DQN 1.0 64.0 0.99 \n", + "33050 DQN 1.0 64.0 0.99 \n", + "32648 DQN 1.0 64.0 0.95 \n", + "34508 DQN 1.0 64.0 1.00 \n", + "28898 DQN 1.0 32.0 0.99 \n", + "33052 DQN 1.0 64.0 0.99 \n", + "34943 DQN 1.0 64.0 1.00 \n", + "35687 DQN 1.0 
64.0 1.00 \n", + "31906 DQN 1.0 64.0 0.95 \n", + "32991 DQN 1.0 64.0 0.99 \n", + "34479 DQN 1.0 64.0 1.00 \n", + "28559 DQN 1.0 32.0 0.99 \n", + "28993 DQN 1.0 32.0 0.99 \n", + "31163 DQN 1.0 32.0 1.00 \n", + "31225 DQN 1.0 32.0 1.00 \n", + "33798 DQN 1.0 64.0 0.99 \n", + "35255 DQN 1.0 64.0 1.00 \n", + "27072 DQN 1.0 32.0 0.95 \n", + "27785 DQN 1.0 32.0 0.95 \n", + "28188 DQN 1.0 32.0 0.95 \n", + "31164 DQN 1.0 32.0 1.00 \n", + "32714 DQN 1.0 64.0 0.95 \n", + "32993 DQN 1.0 64.0 0.99 \n", + "33055 DQN 1.0 64.0 0.99 \n", + "33768 DQN 1.0 64.0 0.99 \n", + "28561 DQN 1.0 32.0 0.99 \n", + "30390 DQN 1.0 32.0 1.00 \n", + "31134 DQN 1.0 32.0 1.00 \n", + "33025 DQN 1.0 64.0 0.99 \n", + "33397 DQN 1.0 64.0 0.99 \n", + "33428 DQN 1.0 64.0 0.99 \n", + "27074 DQN 1.0 32.0 0.95 \n", + "28562 DQN 1.0 32.0 0.99 \n", + "33026 DQN 1.0 64.0 0.99 \n", + "33367 DQN 1.0 64.0 0.99 \n", + "34514 DQN 1.0 64.0 1.00 \n", + "35599 DQN 1.0 64.0 1.00 \n", + "27075 DQN 1.0 32.0 0.95 \n", + "28563 DQN 1.0 32.0 0.99 \n", + "30020 DQN 1.0 32.0 1.00 \n", + "30423 DQN 1.0 32.0 1.00 \n", + "31136 DQN 1.0 32.0 1.00 \n", + "31167 DQN 1.0 32.0 1.00 \n", + "31911 DQN 1.0 64.0 0.95 \n", + "34112 DQN 1.0 64.0 0.99 \n", + "34856 DQN 1.0 64.0 1.00 \n", + "35259 DQN 1.0 64.0 1.00 \n", + "35693 DQN 1.0 64.0 1.00 \n", + "28192 DQN 1.0 32.0 0.95 \n", "\n", - " lr memories max_size step max \\\n", - "algo -1.387198e-01 NaN 3.924623e-02 2.674115e-02 NaN \n", - "step_train 5.822535e-01 NaN 2.226417e-01 -2.898155e-01 NaN \n", - "batch_size 3.432420e-02 NaN 4.325075e-02 7.685944e-02 NaN \n", - "gamma 1.441029e-16 NaN 1.306218e-16 -1.823356e-16 NaN \n", - "greedy_exploration -2.102527e-02 NaN -1.114625e-01 1.432139e-01 NaN \n", - "network 1.019551e-01 NaN -1.256836e-01 9.331015e-02 NaN \n", - " NaN NaN NaN NaN NaN \n", - "optimizer NaN NaN NaN NaN NaN \n", - "lr 1.000000e+00 NaN -4.741960e-02 -1.683165e-01 NaN \n", - "memories NaN NaN NaN NaN NaN \n", - "max_size -4.741960e-02 NaN 1.000000e+00 -8.632279e-02 NaN \n", 
- "step -1.683165e-01 NaN -8.632279e-02 1.000000e+00 NaN \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum 7.717727e-02 NaN 5.969496e-02 5.831934e-02 NaN \n", + " greedy_exploration network optimizer \\\n", + "31127 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "34103 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "33050 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "32648 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "34508 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "28898 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "33052 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "34943 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "35687 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "31906 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "32991 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "34479 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "28559 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "28993 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "31163 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "31225 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "33798 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "35255 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "27072 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "27785 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "28188 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "31164 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "32714 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "32993 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "33055 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "33768 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "28561 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "30390 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 
SimpleNetwork Adam \n", + "31134 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "33025 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "33397 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "33428 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "27074 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "28562 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "33026 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "33367 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "34514 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "35599 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "27075 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "28563 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "30020 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "30423 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "31136 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "31167 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "31911 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "34112 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "34856 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "35259 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "35693 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "28192 EpsilonGreedy-0.6 SimpleNetwork Adam \n", "\n", - " min avg sum \n", - "algo NaN NaN -3.890492e-01 \n", - "step_train NaN NaN -2.090821e-02 \n", - "batch_size NaN NaN 1.187327e-01 \n", - "gamma NaN NaN -3.410748e-16 \n", - "greedy_exploration NaN NaN -2.831818e-02 \n", - "network NaN NaN 2.632569e-01 \n", - " NaN NaN NaN \n", - "optimizer NaN NaN NaN \n", - "lr NaN NaN 7.717727e-02 \n", - "memories NaN NaN NaN \n", - "max_size NaN NaN 5.969496e-02 \n", - "step NaN NaN 5.831934e-02 \n", - "max NaN NaN NaN \n", - "min NaN NaN NaN \n", - "avg NaN NaN NaN \n", - "sum NaN NaN 
1.000000e+00 " + " lr memories max_size step sum \n", + "31127 0.001 ExperienceReplay 2048 30.0 500.0 \n", + "34103 0.001 ExperienceReplay 2048 30.0 500.0 \n", + "33050 0.100 ExperienceReplay 2048 40.0 500.0 \n", + "32648 0.001 ExperienceReplay 512 50.0 500.0 \n", + "34508 0.001 ExperienceReplay 512 50.0 500.0 \n", + "28898 0.001 ExperienceReplay 2048 60.0 500.0 \n", + "33052 0.100 ExperienceReplay 2048 60.0 500.0 \n", + "34943 0.100 ExperienceReplay 512 60.0 500.0 \n", + "35687 0.100 ExperienceReplay 512 60.0 500.0 \n", + "31906 0.001 ExperienceReplay 512 70.0 500.0 \n", + "32991 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "34479 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "28559 0.001 ExperienceReplay 512 80.0 500.0 \n", + "28993 0.100 ExperienceReplay 512 80.0 500.0 \n", + "31163 0.001 ExperienceReplay 512 80.0 500.0 \n", + "31225 0.100 ExperienceReplay 512 80.0 500.0 \n", + "33798 0.100 ExperienceReplay 2048 80.0 500.0 \n", + "35255 0.001 ExperienceReplay 512 80.0 500.0 \n", + "27072 0.001 ExperienceReplay 512 90.0 500.0 \n", + "27785 0.001 ExperienceReplay 2048 90.0 500.0 \n", + "28188 0.001 ExperienceReplay 512 90.0 500.0 \n", + "31164 0.001 ExperienceReplay 512 90.0 500.0 \n", + "32714 0.100 ExperienceReplay 512 90.0 500.0 \n", + "32993 0.001 ExperienceReplay 2048 90.0 500.0 \n", + "33055 0.100 ExperienceReplay 2048 90.0 500.0 \n", + "33768 0.001 ExperienceReplay 512 90.0 500.0 \n", + "28561 0.001 ExperienceReplay 512 100.0 500.0 \n", + "30390 0.001 ExperienceReplay 2048 100.0 500.0 \n", + "31134 0.001 ExperienceReplay 2048 100.0 500.0 \n", + "33025 0.001 ExperienceReplay 512 100.0 500.0 \n", + "33397 0.001 ExperienceReplay 512 100.0 500.0 \n", + "33428 0.100 ExperienceReplay 2048 100.0 500.0 \n", + "27074 0.001 ExperienceReplay 512 110.0 500.0 \n", + "28562 0.001 ExperienceReplay 512 110.0 500.0 \n", + "33026 0.001 ExperienceReplay 512 110.0 500.0 \n", + "33367 0.001 ExperienceReplay 2048 110.0 500.0 \n", + "34514 0.001 ExperienceReplay 512 110.0 500.0 
\n", + "35599 0.001 ExperienceReplay 2048 110.0 500.0 \n", + "27075 0.001 ExperienceReplay 512 120.0 500.0 \n", + "28563 0.001 ExperienceReplay 512 120.0 500.0 \n", + "30020 0.001 ExperienceReplay 2048 120.0 500.0 \n", + "30423 0.001 ExperienceReplay 512 120.0 500.0 \n", + "31136 0.001 ExperienceReplay 2048 120.0 500.0 \n", + "31167 0.001 ExperienceReplay 512 120.0 500.0 \n", + "31911 0.001 ExperienceReplay 512 120.0 500.0 \n", + "34112 0.001 ExperienceReplay 2048 120.0 500.0 \n", + "34856 0.001 ExperienceReplay 2048 120.0 500.0 \n", + "35259 0.001 ExperienceReplay 512 120.0 500.0 \n", + "35693 0.100 ExperienceReplay 512 120.0 500.0 \n", + "28192 0.001 ExperienceReplay 512 130.0 500.0 " ] }, - "execution_count": 57, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_corr_best.corr()" + "df_DQN.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(50)" ] }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -2551,13 +1816,13 @@ "" ] }, - "execution_count": 65, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -2570,35 +1835,12 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df_corr_best.corr()[df_corr_best.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Result by algo" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### DQN" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "df_DQN = df[df[\"algo\"] == \"DQN\"].copy()" + "sns.heatmap(df_DQN.corr()[abs(df_DQN.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -2622,186 +1864,783 @@ " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " step\n", + " sum\n", + " \n", + " \n", + " algo\n", " step_train\n", " batch_size\n", " gamma\n", + " greedy_exploration\n", + " network\n", + " optimizer\n", " lr\n", - " step\n", - " max\n", - " min\n", - " avg\n", - " sum\n", + " memories\n", + " max_size\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " step_train\n", - " 1.000000e+00\n", - " 6.254152e-18\n", - " -3.338836e-15\n", - " -1.947613e-17\n", - " 5.190625e-20\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -2.081287e-01\n", + " DQN\n", + " 1.0\n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 17\n", + " 17\n", + " 17\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 14\n", + " 14\n", + " 14\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 10\n", + " 
10\n", + " 10\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 9\n", + " 9\n", + " 9\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 9\n", + " 9\n", + " 9\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 8\n", + " 8\n", + " 8\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 8\n", + " 8\n", + " 8\n", + " \n", + " \n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 8\n", + " 8\n", + " 8\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 7\n", + " 7\n", + " 7\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 6\n", + " 6\n", + " 6\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 0.95\n", + " 
AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 64.0\n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " 
Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " 
\n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " batch_size\n", - " 6.254152e-18\n", - " 1.000000e+00\n", - " -5.096548e-15\n", - " 1.309431e-16\n", - " 0.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.477390e-01\n", + " 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " gamma\n", - " -3.338836e-15\n", - " -5.096548e-15\n", - " 1.000000e+00\n", - " 1.070395e-14\n", - " -1.030754e-16\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 3.327095e-16\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " lr\n", - " -1.947613e-17\n", - " 1.309431e-16\n", - " 1.070395e-14\n", - " 1.000000e+00\n", - " -3.622535e-19\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -1.253560e-01\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " step\n", - " 5.190625e-20\n", - " 0.000000e+00\n", - " -1.030754e-16\n", - " -3.622535e-19\n", - " 1.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 2.112197e-01\n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " max\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " min\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 
EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " avg\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " sum\n", - " -2.081287e-01\n", - " 1.477390e-01\n", - " 3.327095e-16\n", - " -1.253560e-01\n", - " 2.112197e-01\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.000000e+00\n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", "\n", "" ], "text/plain": [ - " step_train batch_size gamma lr \\\n", - "step_train 1.000000e+00 6.254152e-18 -3.338836e-15 -1.947613e-17 \n", - "batch_size 6.254152e-18 1.000000e+00 -5.096548e-15 1.309431e-16 \n", - "gamma -3.338836e-15 -5.096548e-15 1.000000e+00 1.070395e-14 \n", - "lr -1.947613e-17 1.309431e-16 1.070395e-14 1.000000e+00 \n", - "step 5.190625e-20 0.000000e+00 -1.030754e-16 -3.622535e-19 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum -2.081287e-01 1.477390e-01 3.327095e-16 -1.253560e-01 \n", + " \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 0.99 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 17 \n", + " 64.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 10 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 7 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 2048 4 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 
ExperienceReplay 512 4 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 3 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 3 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.1000 ExperienceReplay 2048 1 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 
0.0010 ExperienceReplay 512 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + "\n", + " step \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 17 \n", + " 64.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 10 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 7 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 2048 4 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 3 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 3 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.1000 ExperienceReplay 2048 1 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 
0.0010 ExperienceReplay 2048 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", "\n", - " step max min avg sum \n", - "step_train 5.190625e-20 NaN NaN NaN -2.081287e-01 \n", - "batch_size 0.000000e+00 NaN NaN NaN 1.477390e-01 \n", - "gamma -1.030754e-16 NaN NaN NaN 3.327095e-16 \n", - "lr -3.622535e-19 NaN NaN NaN -1.253560e-01 \n", - "step 1.000000e+00 NaN NaN NaN 2.112197e-01 \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum 2.112197e-01 NaN NaN NaN 1.000000e+00 " + " sum \n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 17 \n", + " 64.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 10 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 9 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 
SimpleNetwork Adam 0.0010 ExperienceReplay 2048 8 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 7 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 5 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 2048 4 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 3 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.1000 ExperienceReplay 512 3 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 
AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.1000 ExperienceReplay 2048 1 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 " ] }, - "execution_count": 17, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_DQN.corr()" + "columns = [\"algo\",\"step_train\",\"batch_size\",\"gamma\",\"greedy_exploration\",\"network\",\"optimizer\",\"lr\",\"memories\",\"max_size\"]\n", + "df_DQN[df_DQN[\"sum\"] >= 500].groupby(by=columns, observed=True).count().sort_values(by=['sum'], ascending=False)" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### DuelingNetwork" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 21, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(500.0, 8.0)" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "max(df_DQN[\"sum\"]), min(df_DQN[\"sum\"])" + "df_DQN = df[df[\"algo\"] == \"DQN\"].copy()\n", + "df_DQN = df_DQN[df_DQN[\"network\"] == \"SimpleDuelingNetwork\"]" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -2837,391 +2676,328 @@ " memories\n", " max_size\n", " step\n", - " max\n", - " min\n", - " avg\n", " sum\n", " \n", " \n", " \n", " \n", - " 3436\n", - " DQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 3711\n", + " 31716\n", " DQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 30.0\n", " 500.0\n", " \n", " \n", - " 4176\n", + " 27997\n", " DQN\n", + " 1.0\n", " 32.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.1000\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3442\n", + " 29082\n", " DQN\n", " 1.0\n", " 32.0\n", " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " 
SimpleNetwork\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3537\n", + " 30601\n", " DQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", + " 1.00\n", " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3612\n", + " 32461\n", " DQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 3717\n", + " 33918\n", " DQN\n", " 1.0\n", " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 40.0\n", " 500.0\n", " \n", " \n", - " 4797\n", + " 27626\n", " DQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", + " 1.0\n", + " 32.0\n", + " 0.95\n", " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 50.0\n", " 500.0\n", " \n", " \n", - " 3268\n", + " 29827\n", " DQN\n", " 1.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 
0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 50.0\n", + " 500.0\n", + " \n", + " \n", + " 31346\n", + " DQN\n", " 1.0\n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " \n", + " Adam\n", + " 0.001\n", + " ExperienceReplay\n", + " 512\n", + " 50.0\n", " 500.0\n", " \n", " \n", - " 3538\n", + " 32803\n", " DQN\n", " 1.0\n", - " 32.0\n", + " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 50.0\n", " 500.0\n", " \n", " \n", - " 3573\n", + " 28340\n", " DQN\n", " 1.0\n", " 32.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 60.0\n", " 500.0\n", " \n", " \n", - " 3718\n", + " 29859\n", " DQN\n", " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 60.0\n", " 500.0\n", " \n", " \n", - " 3753\n", + " 32804\n", " DQN\n", " 1.0\n", " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 60.0\n", " 500.0\n", " \n", " \n", - " 4523\n", + " 30201\n", " DQN\n", - " 4.0\n", + " 
1.0\n", " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 70.0\n", " 500.0\n", " \n", " \n", - " 4668\n", + " 32805\n", " DQN\n", - " 4.0\n", - " 32.0\n", + " 1.0\n", + " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.1000\n", + " 0.001\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 70.0\n", " 500.0\n", " \n", " \n", - " 4703\n", + " 33921\n", " DQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 70.0\n", " 500.0\n", " \n", " \n", - " 4753\n", + " 35409\n", " DQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleNetwork\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 70.0\n", " 500.0\n", " \n", " \n", - " 4793\n", + " 28745\n", " DQN\n", - " 4.0\n", - " 64.0\n", + " 1.0\n", + " 32.0\n", " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 80.0\n", " 500.0\n", " \n", " \n", - " 4798\n", + 
" 31349\n", " DQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 80.0\n", " 500.0\n", " \n", " \n", - " 3264\n", + " 28343\n", " DQN\n", " 1.0\n", - " 1.0\n", + " 32.0\n", " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.001\n", " ExperienceReplay\n", - " 16\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 2048\n", + " 90.0\n", " 500.0\n", " \n", " \n", @@ -3229,74 +3005,74 @@ "" ], "text/plain": [ - " algo step_train batch_size gamma \\\n", - "3436 DQN 1.0 32.0 0.99 \n", - "3711 DQN 1.0 64.0 0.99 \n", - "4176 DQN 32.0 64.0 0.99 \n", - "3442 DQN 1.0 32.0 0.99 \n", - "3537 DQN 1.0 32.0 0.99 \n", - "3612 DQN 1.0 64.0 0.99 \n", - "3717 DQN 1.0 64.0 0.99 \n", - "4797 DQN 4.0 64.0 0.99 \n", - "3268 DQN 1.0 1.0 0.99 \n", - "3538 DQN 1.0 32.0 0.99 \n", - "3573 DQN 1.0 32.0 0.99 \n", - "3718 DQN 1.0 64.0 0.99 \n", - "3753 DQN 1.0 64.0 0.99 \n", - "4523 DQN 4.0 32.0 0.99 \n", - "4668 DQN 4.0 32.0 0.99 \n", - "4703 DQN 4.0 64.0 0.99 \n", - "4753 DQN 4.0 64.0 0.99 \n", - "4793 DQN 4.0 64.0 0.99 \n", - "4798 DQN 4.0 64.0 0.99 \n", - "3264 DQN 1.0 1.0 0.99 \n", + " algo step_train batch_size gamma \\\n", + "31716 DQN 1.0 64.0 0.95 \n", + "27997 DQN 1.0 32.0 0.95 \n", + "29082 DQN 1.0 32.0 0.99 \n", + "30601 DQN 1.0 32.0 1.00 \n", + "32461 DQN 1.0 64.0 0.95 \n", + "33918 DQN 1.0 64.0 0.99 \n", + "27626 DQN 1.0 32.0 0.95 \n", + "29827 DQN 1.0 32.0 1.00 \n", + "31346 DQN 1.0 64.0 0.95 \n", + "32803 DQN 1.0 64.0 0.99 \n", + "28340 DQN 1.0 32.0 0.99 \n", + "29859 DQN 1.0 32.0 1.00 \n", + "32804 DQN 1.0 64.0 0.99 \n", + "30201 DQN 1.0 32.0 1.00 
\n", + "32805 DQN 1.0 64.0 0.99 \n", + "33921 DQN 1.0 64.0 0.99 \n", + "35409 DQN 1.0 64.0 1.00 \n", + "28745 DQN 1.0 32.0 0.99 \n", + "31349 DQN 1.0 64.0 0.95 \n", + "28343 DQN 1.0 32.0 0.99 \n", "\n", - " greedy_exploration network optimizer \\\n", - "3436 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3711 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "4176 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3442 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3537 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "3612 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3717 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "4797 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "3268 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3538 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "3573 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "3718 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "3753 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "4523 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "4668 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "4703 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "4753 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "4793 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "4798 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "3264 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", + " greedy_exploration network \\\n", + "31716 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "27997 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "29082 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "30601 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "32461 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "33918 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "27626 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "29827 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "31346 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "32803 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "28340 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "29859 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "32804 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "30201 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "32805 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "33921 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "35409 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "28745 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "31349 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", + "28343 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork \n", "\n", - " lr memories max_size step max min avg sum \n", - "3436 0.0010 ExperienceReplay 128 166.0 1.0 1.0 1.0 500.0 \n", - "3711 0.0010 ExperienceReplay 16 166.0 1.0 1.0 1.0 500.0 \n", - "4176 0.1000 ExperienceReplay 16 166.0 1.0 1.0 1.0 500.0 \n", - "3442 0.0010 ExperienceReplay 16 332.0 1.0 1.0 1.0 500.0 \n", - "3537 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "3612 0.0001 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "3717 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "4797 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 500.0 \n", - "3268 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3538 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3573 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "3718 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3753 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "4523 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "4668 0.1000 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "4703 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "4753 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - 
"4793 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "4798 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "3264 0.0010 ExperienceReplay 16 500.0 1.0 1.0 1.0 500.0 " + " optimizer lr memories max_size step sum \n", + "31716 Adam 0.001 ExperienceReplay 512 30.0 500.0 \n", + "27997 Adam 0.001 ExperienceReplay 512 40.0 500.0 \n", + "29082 Adam 0.001 ExperienceReplay 2048 40.0 500.0 \n", + "30601 Adam 0.001 ExperienceReplay 512 40.0 500.0 \n", + "32461 Adam 0.001 ExperienceReplay 512 40.0 500.0 \n", + "33918 Adam 0.001 ExperienceReplay 2048 40.0 500.0 \n", + "27626 Adam 0.001 ExperienceReplay 512 50.0 500.0 \n", + "29827 Adam 0.001 ExperienceReplay 2048 50.0 500.0 \n", + "31346 Adam 0.001 ExperienceReplay 512 50.0 500.0 \n", + "32803 Adam 0.001 ExperienceReplay 2048 50.0 500.0 \n", + "28340 Adam 0.001 ExperienceReplay 2048 60.0 500.0 \n", + "29859 Adam 0.001 ExperienceReplay 512 60.0 500.0 \n", + "32804 Adam 0.001 ExperienceReplay 2048 60.0 500.0 \n", + "30201 Adam 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "32805 Adam 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "33921 Adam 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "35409 Adam 0.001 ExperienceReplay 2048 70.0 500.0 \n", + "28745 Adam 0.001 ExperienceReplay 512 80.0 500.0 \n", + "31349 Adam 0.001 ExperienceReplay 512 80.0 500.0 \n", + "28343 Adam 0.001 ExperienceReplay 2048 90.0 500.0 " ] }, - "execution_count": 19, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -3307,7 +3083,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -3316,13 +3092,13 @@ "" ] }, - "execution_count": 67, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjUAAAJDCAYAAADzbuVEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAA4wklEQVR4nO3debhdZXn38e+dEEgIATLIFMYiBgEhDCIIFooKlEEULGJ9nYqNtlKhDojVOlALitgWUKTRMjgrmPcVhIpKFQRRSWQeIpFQEkAwhCEkIdO53z/WTjgk55C9cva0Vr6f69rX2cPae917Jfvs+/yeZ60VmYkkSVLVDet2AZIkSa1gUyNJkmrBpkaSJNWCTY0kSaoFmxpJklQLNjWSJKkWbGokSVJLRcTFEfF4RNw1yOMREedHxKyIuCMi9mnFem1qJElSq10KHPkij/8lsEvjMgX4SitWalMjSZJaKjNvAOa/yCLHAV/Pwq+BzSNi66Gud4OhvkAzrh4xycMWN+noZTO7XYIkDergY6/vdgmVceNVh0Sn1tXp79ljlv/+vRQJy0pTM3NqiZeYCMzpd3tu475Hh1JXR5oaSZJUH40GpkwTs7qBGr4hN2YOP0mSpE6bC2zX7/a2wCNDfVGTGkmSKi5GdGykq1WuBE6JiO8CrwKezswhDT2BTY0kSWqxiPgOcCgwISLmAp8CRgBk5kXANcBRwCxgEfDuVqzXpkaSpIobtkFvJTWZ+da1PJ7A+1u9XufUSJKkWjCpkSSp4mKEGQWY1EiSpJowqZEkqeJ6bU5Nt5jUSJKkWjCpkSSp4ip4nJq2MKmRJEm1YFMjSZJqweEnSZIqzonCBZMaSZJUCyY1kiRVnBOFCyY1kiSpFkxqJEmqOOfUFExqJElSLZjUSJJUcTHcpAZMaiRJUk2Y1EiSVHHDTGoAkxpJklQTJjWSJFVcDDOpAZMaSZJUEyY1kiRVXAw3owCTGkmSVBMmNZIkVZx7PxVMaiRJUi3Y1EiSpFpw+EmSpIpzl+6CSY0kSaoFkxpJkirOicIFkxpJklQLJjWSJFVcmNQAJjWSJKkmTGokSaq4GGZGASY1kiSpJkxqJEmqOI9TUzCpkSRJtWBSI0lSxXmcmsJ619Ts+dWz2OKoQ1n6+BPcsPex3S5HkrQWp07ZmQP3Hc9zS1Zw1nkz+f0fnl1jmTP+4WXsussYAOY8spiz/uM+Fj/Xx5jRG/CxUyexzVYjWbqsj7PPm8nshxZ1+i2oQ9a74ae5l03jt8e8p9tlSJKacMC+49hum4056b2/5Qtf/j0f/rtdBlzu/K/9gXd9YAbv+sAMHvvTEk44ZiIAbz9xe+5/4Fne9YEZfPbf7+PUKS/tZPkdE8Oio5detd41NfNvnM6y+U93uwxJUhNec8B4fvw/fwTg7pkL2GT0Bowfu+Eayy1avGLV9Y02HEZmcX3H7TZmxh1PAvDQ3MVsvcVIxm4+ov2Fqyuabmoi4viIuD8ino6IZyJiQUQ8087iJEnrtwnjN+LxeUtW3X78iSVMGL9mUwPwsVMnceXXD2SHbTfmih89DMCs2Qv58wNfAsDLdxnDlluMZIvxG7W/8A6LYcM6eulVZSo7B3hDZm6WmZtm5pjM3HSwhSNiSkRMj4jpP+57asiFSpLWPwMOdOTAy5593kze+K6b+d+5C3ntwUUj880rHmLMJhtwyXn7csKxE7n/gQWsWDHIC6jyykwUfiwz72124cycCkwFuHrEJP8HSZKacvxR23DsEVsDcO/9C9hiwvPJyhbjN2Le/KWDPrevD6775Z946/Hbcc11j7Fo8QrOPm/mqscv/9qreOSx59pXvLqqTFMzPSK+B/w/YFUWmJnTWl2UJGn9Ne2aR5h2zSMAHLjfOE44ZiI/u+FP7D5pDM8uWs4TT67Z1EzceiQPP1o0KwftP56H5hZ7OG0yejjPLelj+fLk2MO34va7n3rB/Ju66OXJu51UpqnZFFgEHN7vvgQq1dRM/sY
XGX/I/mw4YSyHzb6e+8+8gDmXXNHtsiRJA7h5+nwO3G8c35u6/6pdulf6wqf24HMX/J75Ty7l46ftyuiNhxMRzJr9LOdeeD8AO2w7mk98cBJ9ffDgQwv53Pm/79ZbUQdEZvtHhhx+at7Ry2aufSFJ6pKDj72+2yVUxo1XHdKx+OTu4w7r6Pfs7j/8n56Mhtaa1ETE6Zl5TkRcwADTszLzA22pTJIkqYRmhp9WTg6e3s5CJEnSunFOTWGtTU1mXtX4eVn7y5EkSVo3TU8UjoiXAB8FdgNGrrw/Mw9rQ12SJKlJvXxAvE4qsxW+RTEUtRPwGeBB4JY21CRJklRamV26x2fmf0XEqZl5PXB9RDgNXpKkLnNOTaFMU7Os8fPRiDgaeATYtvUlSZIklVemqflsRGwGfAi4gOJgfP/YlqokSVLTTGoKTTU1ETEc2CUzfwQ8DfxFW6uSJEkqqamJwpm5AnhDm2uRJEnrIIZFRy+9qszw068i4kvA94CFK+/MzN+1vCpJkqSSyjQ1r278PLPffQl4nBpJkrrI49QUyjQ1J2fmA/3viIg/a3E9kiRJ66RMU3MFsM9q910O7Nu6ciRJUlnDhvfuPJdOauYs3bsCuwObRcTx/R7alH6nS5AkSeqmZpKaScAxwObAsf3uXwD8bRtqkiRJKq2Zs3T/EPhhRByYmTcPtlxEfCwzz25pdZIkaa16eTfrTmp6uvSLNTQNfzXEWiRJktZZmYnCa2ObKElSF7hLd6GVWyFb+FqSJEmlmNRIklRxzqkptDKpubyFryVJklRK001NRPxZRFwVEfMi4vGI+GH/Iwpn5lntKVGSJL0YT2hZKJPUfBv4PrAVsA1FMvOddhQlSZJUVpk5NZGZ3+h3+5sRcUqrC5IkSeW491OhzFb4eUScERE7RsQOEXE6cHVEjIuIce0qUJIkVUtEHBkRMyNiVkScMcDjmzWmtNweEXdHxLtbsd4ySc1bGj/fu9r9f0OxO7dn7JYkqQt6aZ5LRAwHvgy8HpgL3BIRV2bmPf0Wez9wT2YeGxEvAWZGxLcyc+lQ1t10U5OZOw1lRZIkab2wPzArMx8AiIjvAscB/ZuaBMZERACbAPOB5UNdcdNNTURsDHwQ2D4zp0TELsCkzPzRUIuQJEnrrtNzaiJiCjCl311TM3Nq4/pEYE6/x+YCr1rtJb4EXAk8AowB3pKZfUOtq8zw0yXADODV/Yq8HLCpkSRpPdJoYKYO8vBAY2Grn3XgCOA24DBgZ+CnEfHLzHxmKHWVae12zsxzgGUAmbkYjyIsSVL3RXT28uLmAtv1u70tRSLT37uBaVmYBcwGdh3qZijT1CyNiFE0uq2I2BlYMtQCJElSrdwC7BIRO0XEhsBJFENN/T0EvBYgIrYEJgEPDHXFZYafPg38GNguIr4FHETRaUmSJAGQmcsbx7G7FhgOXJyZd0fE+xqPXwT8C3BpRNxJMerz0cycN9R1l9n76ScRMQM4oFHAqa0oQJIkDU0v7dINkJnXANesdt9F/a4/Ahze6vWWOffTdZn5RGZenZk/ysx5EXFdqwuSJElaF2tNaiJiJLAxMCEixvL85OBNKc4BJUmSusjTJBSaGX56L3AaRQMzg6KpSWABxX7mkiRJXbfW1i4zz2scTfhfgcmN65dQzFK+uc31SZKktYhh0dFLryqTV705M5+JiIMpzudwKfCVtlQlSZJUUpldulc0fh4NXJSZP4yIT7e+pPXb1SMmdbuESjh62cxulyBJPcM5NYUyTc3DEfGfwOuAz0fERjSZ9PgF1BwbGkm97sarDul2CdKgyjQ1JwJHAudm5lMRsTXwkfaUJUmSmtXL81w6qczB9xYB0/rdfhR4tB1FSZIklVUmqZEkST3IpKbgzCJJklQLJjWSJFWdez8BJjWSJKkmTGokSaq4COfUgEmNJEmqCZsaSZJUCw4/SZJUcZ4moeBWkCRJtWBSI0lSxXnwvYJJjSRJqgWTGkmSqs45NYBJjSR
JqgmTGkmSKs45NQWTGkmSVAsmNZIkVVyEGQWY1EiSpJowqZEkqeqcUwOY1EiSpJowqZEkqeI891PBrSBJkmrBpEaSpIrzODUFkxpJklQLNjWSJKkWHH6SJKnqPPgeYFIjSZJqwqRGkqSKc6JwwaRGkiTVgkmNJElV58H3AJMaSZJUEyY1kiRVXIRzasCkRpIk1YRJjSRJVeecGsCkRpIk1YRJjSRJFedxagomNZIkqRZMaiRJqjrP/QSY1EiSpJowqdGg9vzqWWxx1KEsffwJbtj72G6XI0kajHNqAJMavYi5l03jt8e8p9tlSJLUFJsaDWr+jdNZNv/pbpchSVJTHH6SJKniwonCQImkJiJeFhHXRcRdjdt7RsQnXmT5KRExPSKmT506tRW1SpIkDapMUvNV4CPAfwJk5h0R8W3gswMtnJlTgZXdTA6lSEmS9CKcKAyUm1OzcWb+drX7lreyGEmSpHVVpqmZFxE700hdIuLNwKNtqUo9YfI3vsirf/ldRk/aicNmX892735zt0uSJA0ghg3r6KVXlRl+ej/FcNKuEfEwMBv4P22pSj3htrd/qNslSJLUtKabmsx8AHhdRIwGhmXmgvaVJUmSmhbOqYFyez+tiIjPAYtWNjQR8bu2VSZJklRCmeGnuymaoJ9ExFsycz5gayhJUrf18DyXTiqzFZZn5ukUu3b/MiL2xV21JUlSjyiT1ARAZn4/Iu4GvgNs35aqJElS85xTA5Rralad2TAz746Ig4E3trwiSZKkdbDWpiYiDsvM/wF2iIgdVnv42faUJUmSmtXLx47ppGaSmkOA/wGOHeCxBKa1tCJJkqR1sNamJjM/1fj57vaXI0mSSvMs3UC549ScGhGbRuFrEfG7iDi8ncVJkiQ1q0xr9zeZ+QxwOLAF8G7gc22pSpIkqaTSu3QDRwGXZObtEe5DJklS1w3z6xjKJTUzIuInFE3NtRExBuhrT1mSJEnllGlqTgbOAF6ZmYuADSmGoACIiN1bXJskSWpCxLCOXtZeTxwZETMjYlZEnDHIModGxG0RcXdEXN+K7VDmLN19wO/63X4CeKLfIt8A9mlFUZIkqZoiYjjwZeD1wFzgloi4MjPv6bfM5sCFwJGZ+VBEbNGKdZeZU7M2DuhJktQNvTWnZn9gVmY+ABAR3wWOA+7pt8xfA9My8yGAzHy8FStu5Y7tntxSkqT1QERMiYjp/S5T+j08EZjT7/bcxn39vQwYGxG/iIgZEfGOVtTVyqRGkiR1Q4cPvpeZU4Gpgzw8UGy0evCxAbAv8FpgFHBzRPw6M38/lLpa2dQsbeFrSZKkapoLbNfv9rbAIwMsMy8zFwILI+IGYC+gc01NREwEduj/vMy8ofHzgKEUIkmS1lFvHTbuFmCXiNgJeBg4iWIOTX8/BL4UERtQ7E39KuDfh7rippuaiPg88BaKiT4rGncncMNQi5AkSfWQmcsj4hTgWmA4cHFm3h0R72s8flFm3hsRPwbuoDjm3dcy866hrrtMUvNGYFJmLhnqSiVJUgsN660TWmbmNcA1q9130Wq3vwB8oZXrLbMVHgBGtHLlkiRJrbLWpCYiLqAYZloE3BYR1wGr0prM/ED7ypMkSWvV4b2felUzw0/TGz9nAFe2sRZJkqR1ttamJjMvA4iI0cBzmbmicXs4sFF7y5MkSWvVW0cU7poyedV1FAfIWWkU8LPWliNJkrRuyjQ1IzPz2ZU3Gtc3bn1JkiRJ5ZXZpXthROyTmb8DiIh9gcXtKUuSJDXNicJAuabmNODyiFh5qOOtKY4SKEmS1HVlmpo7gF2BSRQnq7qP1p7lW5IkrYveOk1C15RpSm7OzGWZeVdm3pmZy4Cb21WYJElSGc0cfG8rYCIwKiL25vlTim+KE4UlSeq+HjtNQrc0M/x0BPAuilOH/1u/+xcA/9SGmiRJkkpr9uB7l0XECZn5gw7UJEmSynBODVBionBm/iAijgZ2B0b2u//MdhQmSZJURtNNTUR
cRDGH5i+ArwFvBn7bprokSVKzPE4NUG7vp1dn5juAJzPzM8CBwHbtKUuSJKmcMsepWXn04EURsQ3wBLBT60uSJEmluPcTUK6p+VFEbA6cA8xo3Pe1llckSZK0Dso0NecCfwe8huKge78EvtKOoqS1uXrEpG6XUBlHL5vZ7RIktZt7PwHlmprLKI5Nc37j9luBrwMntrqo9ZVfPs2xoZEkDaRMUzMpM/fqd/vnEXF7qwuSJEklufcTUG7vp1sj4oCVNyLiVcBNrS9JkiSpvGbO/XQnkMAI4B0R8VDj9g7APe0tT5IkqTnNDD8d0/YqJEnSunOiMNDcuZ/+txOFSJIkDUWZicKSJKkXefA9oNxEYUmSpJ5lUiNJUsWlc2oAkxpJklQTJjWSJFWdB98DTGokSVJNmNRIklR1JjWASY0kSaoJkxpJkirOvZ8KJjWSJKkWTGokSao659QAJjWSJKkmTGokSao659QAJjWSJKkmbGokSVItOPwkSVLVDTOjAJMaSZJUEyY1kiRVnAffK5jUSJKkWjCpkSSp6jz4HmBSI0mSasKkRpKkikuTGsCkRpIk1YRJjSRJVefeT4BJjSRJqgmTGkmSKs45NQW3giRJqgWTGkmSqs45NYBJjSRJqgmTGkmSqs45NYBJjSRJqgmbGkmSVAsOP0mSVHHpRGHApEaSJNWESY0kSVXnRGHApEYasj2/ehave/hX/PmtV3W7FElar9nUSEM097Jp/PaY93S7DEnrsSQ6eulVNjXSEM2/cTrL5j/d7TIkab3nnBpJkirOE1oWmt4KEXFARNwSEc9GxNKIWBERz7zI8lMiYnpETJ86dWprqpUkSRpEmaTmS8BJwOXAfsA7gJcOtnBmTgVWdjO5rgVKkqS1MKkBSg4/ZeasiBiemSuASyLiV22qS5IkqZQyTc2iiNgQuC0izgEeBUa3pyypOiZ/44uMP2R/NpwwlsNmX8/9Z17AnEuu6HZZktYjHlG4UKapeTswHDgF+EdgO+CEdhQlVcltb/9Qt0uQJFGiqcnM/21cXQx8pj3lSJKksnpt76eIOBI4jyIM+Vpmfm6Q5V4J/Bp4S2YOOeIus/fTMRFxa0TMj4hnImLBi+39JEmS1j8RMRz4MvCXwG7AWyNit0GW+zxwbavWXWb46T+A44E7M9O9mSRJ6hW9Nadmf2BWZj4AEBHfBY4D7lltuX8AfgC8slUrLpNXzQHusqGRJGn91v9YdI3LlH4PT6ToGVaa27iv//MnAm8CLmplXWWSmtOBayLiemDJyjsz899aWZAkSSqn03NqVjsW3eoGio1WD0T+A/hoZq6IFqZMZZqafwWeBUYCG7asAkmSVCdzKfaQXmlb4JHVltkP+G6joZkAHBURyzPz/w1lxWWamnGZefhQViZJkmrvFmCXiNgJeJjibAR/3X+BzNxp5fWIuBT40VAbGijX1PwsIg7PzJ8MdaWSJKl1csARn+7IzOURcQrFXk3DgYsz8+6IeF/j8ZbOo+kvmp33GxELKI4gvARYRjFmlpm5aRNPd3KxWubqEZO6XUKlHL1sZrdLkNZXHes05t11c0e/ZyfscWDvdFH9lDn43ph2FiJJktZNrx18r1tKndAyIvYEduz/vMyc1uKaJEmSSmu6qYmIi4E9gbuBvsbdCdjUSJLUTb118L2uKZPUHJCZaxzmWJIkqReUaWpujojdMnP1wxxLkqQuylInCKivMk3NZRSNzR8p9oBauffTnm2pTJIkqYQyTc3FwNuBO3l+To0kSeqydE4NUK6peSgzr2xbJZIkSUNQpqm5LyK+DVzFC09o6d5PkiR1kcepKZRpakZRNDP9z//kLt2SJKknlDmi8LvbWYgkSVo3vXTup24qc/C9kcDJwO7AyJX3Z+bftKEuSZKkUsoMwn0D2Ao4Arge2BZY0I6iJElS8zKGdfTSq8pU9tLM/GdgYWZeBhwNvKI9ZUmSJJVTpqlZ1vj5VETsAWxGcXJLSZKkriuz99P
UiBgLfAK4EtgE+Oe2VCVJkprmwfcKZZqazYCVe0B9ufFzeURMzszbWlqVJElSSWWamn2B/SgOvgfFnJpbgPdFxOWZeU6ri5MkSWvnLt2FMk3NeGCfzHwWICI+BVwB/DkwA7CpkSRJXVOmqdkeWNrv9jJgh8xcHBFLBnmOJElqs17ezbqTyjQ13wZ+HRE/bNw+FvhORIwG7ml5ZZIkSSWUOU3Cv0TENcDBQADvy8zpjYff1o7iJEnS2jmnplAmqSEzZ1DMn5EkSeoppZoaSZLUe5xTU3ArSJKkWjCpkSSp4pxTUzCpkSRJtWBSI0lSxTmnpuBWkCRJtWBSI0lSxTmnpmBSI0mSasGkRqq5q0dM6nYJlXD0spndLqESTvzQg90uoTK+/8Udu13CesemRpXjl0/zbGik9UOGw0/g8JMkSaoJkxpJkiou06QGTGokSVJNmNRIklRxaUYBmNRIkqSaMKmRJKniPPhewaRGkiTVgkmNJEkVZ1JTMKmRJEm1YFIjSVLFmdQUTGokSVItmNRIklRxJjUFkxpJklQLJjWSJFWc534qmNRIkqRasKmRJEm14PCTJEkV50ThgkmNJEmqBZMaSZIqzqSmYFIjSZJqwaRGkqSKM6kpmNRIkqRaMKmRJKniPPhewaRGkiTVgkmNJEkV1+ecGsCkRpIk1YRJjSRJFefeTwWTGkmSVAsmNZIkVZx7PxVMaiRJUi2Y1EiSVHHOqSmY1EiSpFqwqZEkSbXg8JMkSRXnROGCSY0kSaoFkxpJkirOicIFkxpJktRSEXFkRMyMiFkRccYAj78tIu5oXH4VEXu1Yr0mNZIkVVwvzamJiOHAl4HXA3OBWyLiysy8p99is4FDMvPJiPhLYCrwqqGu26RGkiS10v7ArMx8IDOXAt8Fjuu/QGb+KjOfbNz8NbBtK1ZsUiOpI/b86llscdShLH38CW7Y+9hul6MKefcbx7H3y0exZGly4XfnMfvhpWss8w9vm8DO227E8hXJH+YsYerlT7CiD449dFNes88mAAwbBttuOYKTPzmHhYv7Ov022qrH3s1EYE6/23N58RTmZOC/W7FikxpJHTH3smn89pj3dLsMVczeu45iqwkb8IGzH2bq5U/wnhPGD7jcjTMWctrnH+bD5z7ChiOCw141BoCrfvEMp//bI5z+b4/wnWue5J4/PFe7hqYbImJKREzvd5nS/+EBnpKDvM5fUDQ1H21FXSY1kjpi/o3TGbXDxG6XoYrZb4+NuWHGQgDuf2gJo0cNY/Mxw3lqwYoXLHfrfYtXXZ/10FLGbz58jdc6aO/R3HTrwvYW3CWdnlOTmVMp5sEMZC6wXb/b2wKPrL5QROwJfA34y8x8ohV1NZXURMSwiLirFSuUJKlZ4zYbzrynlq+6/cTTyxm32ZoNy0rDh8Fr9h3Nbf2aHIANRwSTdx3Fr+9Y1LZatcotwC4RsVNEbAicBFzZf4GI2B6YBrw9M3/fqhU31dRkZh9we6OIpvSPpqZOHayZkyRpcE2PYzS854Tx3PvAEu6bveQF9++7+yhmzl5S26GnJDp6edFaMpcDpwDXAvcC38/MuyPifRHxvsZinwTGAxdGxG0RMb0V26HM8NPWwN0R8VtgVX6XmW8YaOHVoqkX+z8oSdIqRxw0htc25sT8Yc4SJmy+ATMpmpTxm23Ak0+vGPB5bz58MzbdZDhTL318jccOmjyaG2s69NSLMvMa4JrV7ruo3/X3AC2fZFemqflMq1cuSdLqrr1pAdfetACAvV8+iiMPGsNNty5kl+03YtFzfWvMpwE47FWbsNekUZz5lcfI1f6MHjUy2G3nkVzw7XmdKL8reuk4Nd3UdFOTmde3sxBJ9Tb5G19k/CH7s+GEsRw2+3ruP/MC5lxyRbfLUo+79d7F7PPyUZz/sYksXVbs0r3SGe/Zgv/8/hM8+cwK/vaE8fzpyeX86we2BuA3dy7kBz99GoD9XzGa22c+x5KlDhrUXeTqLe3qC0QsYOD
howAyMzdtYj3+T5K64OoRk7pdQmUcvWxmt0uohBM/9GC3S6iM739xx47FJzfes7Cj37MH7za6J6OhtSY1mTmmE4VIkiQNhQffkyRJteDB9yRJqrg+J3kAJjWSJKkmTGokSaq4tR0Qb31hUiNJkmrBpEaSpIrz4HsFkxpJklQLJjWSJFXcWo6ju94wqZEkSbVgUiNJUsX1ufcTYFIjSZJqwqRGkqSKc++ngkmNJEmqBZMaSZIqzr2fCiY1kiSpFkxqJEmqOM/9VDCpkSRJtWBTI0mSasHhJ0mSKq7PicKASY0kSaoJkxpJkirOg+8VTGokSVItmNRIklRxHnyvYFIjSZJqwaRGkqSK6/Pge4BJjSRJqgmTGkmSKs45NQWTGkmSVAsmNZIkVZzHqSmY1EiSpFowqZEkqeI891PBpEaSJNWCSY0kSRXn3k8FkxpJklQLNjWSJKkWHH6SJKni0tMkACY1kiSpJkxqJEmqOHfpLpjUSJKkWjCpkSTg6hGTul1CJfz9rbd1uwQNwF26CzY1Uo0dvWxmt0uoBBsaqR5saiRJqjiTmoJzaiRJUi2Y1EiSVHF96XFqwKRGkiTVhEmNJEkV55yagkmNJEmqBZMaSZIqzqSmYFIjSZJqwaRGkqSK89xPBZMaSZJUCzY1kiSpFhx+kiSp4tKD7wEmNZIkqSZMaiRJqjh36S6Y1EiSpFowqZEkqeLcpbtgUiNJkmrBpEaSpIpzTk3BpEaSJNWCSY0kSRVnUlMwqZEkSbVgUiNJUsW591PBpEaSJNWCSY0kSRXnnJqCSY0kSaoFmxpJkiqur6+zl7WJiCMjYmZEzIqIMwZ4PCLi/Mbjd0TEPq3YDjY1kiSpZSJiOPBl4C+B3YC3RsRuqy32l8AujcsU4CutWLdNjSRJaqX9gVmZ+UBmLgW+Cxy32jLHAV/Pwq+BzSNi66Gu2KZGkqSKy+zsJSKmRMT0fpcp/cqZCMzpd3tu4z5KLlOaez9JkqRSMnMqMHWQh2Ogp6zDMqXZ1EiSVHE9tkv3XGC7fre3BR5Zh2VKc/hJkiS10i3ALhGxU0RsCJwEXLnaMlcC72jsBXUA8HRmPjrUFZvUSJJUcb10moTMXB4RpwDXAsOBizPz7oh4X+Pxi4BrgKOAWcAi4N2tWLdNjSRJaqnMvIaicel/30X9rifw/lav16ZGkqSKy45Pqhlonm/3OadGkiTVgkmNJEkV12N7P3WNSY0kSaoFkxpJkiqumZNMrg9MaiRJUi2Y1EhSj9nzq2exxVGHsvTxJ7hh72O7XU5X3XXrTXz/4nPo6+vj4Ne+iSOP/5sXPP7HubO59MufYs4D93LcX5/C4ce9c9VjixY+wzcuPJOHH5pFRPCO93+anSft1em30BHOqSnY1EhSj5l72TQevPCbTL74890upav6VqzgO189m9M+eRFjx2/J2R99G3u+8hC22W7nVctsPGYzTjr5dG77zc/XeP73Lj6H3fd+Ne/9yLksX7aMpUsXd7J8dYHDT5LUY+bfOJ1l85/udhldN3vWXWyx1Xa8ZKtt2WDECPY7+Ahuv+UXL1hm083GseNL92D4Bi/8G33xome5/57fcdBr3wTABiNGsPHoTTtVesf1ZWcvvarppiYi/iwiroqIeRHxeET8MCL+rJ3FSZLWX0/Nf5yxE7ZadXvsuC156onHm3ruvMfmMmbTsVz2pU/y2Q+/ha9f+BmWPGdSU3dlkppvA98HtgK2AS4HvjPYwhExJSKmR8T0qVMHOzu5JEmDGGiiSDR3JNsVK1bw0AP3ccgRJ/KJc7/HRhuN5Mf/9+IWF9g7Mjt76VVl5tREZn6j3+1vNk5YNaDMnAqs7GZ6eBNIknrR5uO35Ml5f1x1+8n5j7H5uJc09dyx47dk7Pgt2OllrwBgnwNfX+umRoUySc3PI+KMiNgxInaIiNOBqyNiXESMa1eBkqT1044v3Z3HH32IeY89zPJly5h+47Xstd8hTT13s7ETGDthK/748IM
A3Hfnb9h6W2dM1F00exKsiJj9Ig9nZr7Y/xaTGkk96+oRk7pdwgtM/sYXGX/I/mw4YSxLHnuC+8+8gDmXXNHtsgAYfettHV3fnTN+yfcv+QJ9fX0cdNhxHPXmv+X6ay8H4JAj/oqnn5zHWaf/Nc8tXkhEsNHIjfn0edMYtfEmzJl9H1//ypmsWLaMCVtO5J2nnMnoTTo3WfjQPUZ17KyP507r7PTdDx8/rCfPaNl0UzNENjWSelavNTW9rNNNTZXZ1HRe03NqImJj4IPA9pk5JSJ2ASZl5o/aVp0kSVqrXt7NupPKzKm5BFgKvLpxey7w2ZZXJEmStA7KNDU7Z+Y5wDKAzFwM9GT8JEnS+sRdugtlmpqlETGKxvyYiNgZWNKWqiRJkkoqc5yaTwM/BraLiG8BBwHvbkdRkiSpeX1OqgFKNDWZ+ZOImAEcQDHsdGpmzmtbZZIkSSWU2fvpusx8LXD1APdJkqQu6eV5Lp201qYmIkYCGwMTImIsz08O3pTiHFCSJEld10xS817gNIoGZgZFU5PAAuBLbatMkiQ1xaSmsNa9nzLzvMzcCfhXYHLj+iXAA8DNba5PkiSpKWV26X5zZj4TEQcDrwcuBb7SlqokSVLT+jI7eulVZZqaFY2fRwMXZeYPgQ1bX5IkSVJ5ZY5T83BE/CfwOuDzEbER5ZoiSZLUBtnX7Qp6Q5mm5ETgWuDIzHwKGAd8pB1FSZIklVXm4HuLgGn9bj8KPNqOoiRJksoqM/wkSZJ6UPbw5N1Ock6MJEmqBZMaSZIqrs+JwoBJjSRJqgmTGkmSKs45NQWTGkmSVAsmNZIkVVyfQQ1gUiNJkmrCpEaSpIpLoxrApEaSJNWESY0kSRXnzk8FkxpJklQLJjWSJFVcn3NqAJMaSZJUEyY1kiRVnEcULpjUSJKkWrCpkSRJteDwkyRJFZd93a6gN5jUSJKkWjCpkSSp4vqcKAyY1EiSpJowqZEkqeLcpbtgUiNJkmrBpEaSpIrzNAkFkxpJklQLHUlqDj72+k6sphZuvOqQbpegGjnxQw92u4RK+Ptbb+t2CZWxcO/J3S6hOpbN7NiqnFJTMKmRJEm14JwaSZIqLp1TA5jUSJKkmjCpkSSp4jyicMGkRpIk1YJJjSRJFeecmoJJjSRJqgWbGkmSVAsOP0mSVHEOPxVMaiRJUi2Y1EiSVHEGNQWTGkmSVAsmNZIkVZxzagomNZIkqWMiYlxE/DQi7m/8HDvAMttFxM8j4t6IuDsiTm3mtW1qJEmquMzs6GWIzgCuy8xdgOsat1e3HPhQZr4cOAB4f0TstrYXtqmRJEmddBxwWeP6ZcAbV18gMx/NzN81ri8A7gUmru2FnVMjSVLF9XV4Tk1ETAGm9LtramZObfLpW2bmo1A0LxGxxVrWtSOwN/Cbtb2wTY0kSSql0cAM2sRExM+ArQZ46ONl1hMRmwA/AE7LzGfWtrxNjSRJFdeCeS4tlZmvG+yxiHgsIrZupDRbA48PstwIiobmW5k5rZn1OqdGkiR10pXAOxvX3wn8cPUFIiKA/wLuzcx/a/aFbWokSaq47MuOXoboc8DrI+J+4PWN20TENhFxTWOZg4C3A4dFxG2Ny1Fre2GHnyRJUsdk5hPAawe4/xHgqMb1G4Eo+9o2NZIkVZxHFC44/CRJkmrBpkaSJNWCw0+SJFVcX4/t0t0tJjWSJKkWTGokSao4JwoXTGokSVItmNRIklRxvXaahG4xqZEkSbVgUiNJUsX1OacGMKmRJEk1YVIjSVLFufdTwaRGkiTVgkmNJEkV595PBZMaSZJUCyY1kiRVXPb1dbuEnmBSI0mSasGkRpKkivM4NYXaNjWnTtmZA/cdz3NLVnDWeTP5/R+eXWOZM/7hZey6yxgA5jyymLP+4z4WP9fHmNEb8LFTJ7HNViNZuqyPs8+byeyHFnX6LUiV9O43jmPvl49iydLkwu/OY/bDS9dY5h/eNoGdt92I5SuSP8xZwtTLn2BFHxx
76Ka8Zp9NABg2DLbdcgQnf3IOCxfXK1q/69ab+P7F59DX18fBr30TRx7/Ny94/I9zZ3Pplz/FnAfu5bi/PoXDj3vnqscWLXyGb1x4Jg8/NIuI4B3v/zQ7T9qr02+hJ+z51bPY4qhDWfr4E9yw97HdLkc9oJZNzQH7jmO7bTbmpPf+lt0njeHDf7cLUz586xrLnf+1P7Bo8QoATjl5Z044ZiLfvGIObz9xe+5/4Fn+6ay72X7bUXzwfbtw2ifu6PTbkCpn711HsdWEDfjA2Q+zy/Yb8Z4TxvPx8x9dY7kbZyzkgm/NA+DU/zOBw141hp/evICrfvEMV/3iGQD23W0UR//5prVraPpWrOA7Xz2b0z55EWPHb8nZH30be77yELbZbudVy2w8ZjNOOvl0bvvNz9d4/vcuPofd93417/3IuSxftoylSxd3svyeMveyaTx44TeZfPHnu12KekQt59S85oDx/Ph//gjA3TMXsMnoDRg/dsM1llvZ0ABstOEwVu4Rt+N2GzPjjicBeGjuYrbeYiRjNx/R/sKlittvj425YcZCAO5/aAmjRw1j8zHD11ju1vue/yKe9dBSxm++5jIH7T2am25d2L5iu2T2rLvYYqvteMlW27LBiBHsd/AR3H7LL16wzKabjWPHl+7B8A1e+Hfn4kXPcv89v+Og174JgA1GjGDj0Zt2qvSeM//G6Syb/3S3y+gJmdnRS6+qZVMzYfxGPD5vyarbjz+xhAnj12xqAD526iSu/PqB7LDtxlzxo4cBmDV7IX9+4EsAePkuY9hyi5FsMX6j9hcuVdy4zYYz76nlq24/8fRyxm22ZsOy0vBh8Jp9R3PbfS9MGzYcEUzedRS/vqN+w75PzX+csRO2WnV77LgteeqJx5t67rzH5jJm07Fc9qVP8tkPv4WvX/gZljy3/iY10uqabmoiYnhEvCEiPhARH1x5aWdx6yoGunOQxvLs82byxnfdzP/OXchrDy4amW9e8RBjNtmAS87blxOOncj9DyxgxYre7UylXjHQZ+/FPjnvOWE89z6whPtmL3nB/fvuPoqZs5fUbugJgIH+yo0Bf2utYcWKFTz0wH0ccsSJfOLc77HRRiP58f+9uMUFqoqyLzt66VVl5tRcBTwH3Ams9TdNREwBpgDs/IoPsdUO7Z3EdfxR23DsEVsDcO/9C9hiwvPJyhbjN2Le/DUnK67U1wfX/fJPvPX47bjmusdYtHgFZ583c9Xjl3/tVTzy2HPtK16qsCMOGsNrX1VMuP/DnCVM2HwDZlI0KeM324Ann14x4PPefPhmbLrJcKZeumZKcdDk0dxYw6EngM3Hb8mT8/646vaT8x9j83Evaeq5Y8dvydjxW7DTy14BwD4Hvt6mRuqnTFOzbWbu2ezCmTkVmApw8LHXt72tm3bNI0y75hEADtxvHCccM5Gf3fAndp80hmcXLeeJJ9dsaiZuPZKHHy2alYP2H89Dc4uoe5PRw3luSR/LlyfHHr4Vt9/91Avm30h63rU3LeDamxYAsPfLR3HkQWO46daF7LL9Rix6ro+nFqz52TnsVZuw16RRnPmVx9YILkaNDHbbeSQXfHteJ8rvuB1fujuPP/oQ8x57mM3HbcH0G6/l5NPOauq5m42dwNgJW/HHhx9kq4k7ct+dv2Hrbf+szRWrCno5PemkMk3Nf0fE4Zn5k7ZV0yI3T5/PgfuN43tT91+1S/dKX/jUHnzugt8z/8mlfPy0XRm98XAiglmzn+XcC+8HYIdtR/OJD06irw8efGghnzv/9916K1Kl3HrvYvZ5+SjO/9hEli4rdule6Yz3bMF/fv8JnnxmBX97wnj+9ORy/vUDRbr6mzsX8oOfFhM+93/FaG6f+RxLltbzl/Tw4Rtw0nvO4Lx/+Tv6+vo46LDj2Gb7l3L9tZcDcMgRf8XTT87jrNP/mucWLyQiuO5H3+LT501j1MabcNLJH+W/zvsnVixbxoQtJ/LOU87s8jvqnsnf+CLjD9mfDSe
M5bDZ13P/mRcw55Irul2WuiiancUcEW8CvkkxD2cZxfB5ZuZap953IqmpixuvOqTbJahGTvzQg90uoRL+/t1bdruEyli49+Rul1AZRy+b2dxkqRY4/gOzOvo9O+38l3bsvZVRJqn5InAgcGf28v5ckiRpvVSmqbkfuMuGRpKk3uKcmkKZpuZR4BcR8d/Aqv0vM/PfWl6VJElSSWWamtmNy4aNiyRJ6gEmNYWmm5rM/Ew7C5EkSRqKppuaiPg5AxwcNDMPa2lFkiSpFKe7FsoMP3243/WRwAnA8kGWlSRJ6qgyw08zVrvrpoi4vsX1SJKkkvr6anietHVQZvhpXL+bw4D9gK0GWVySJKmjygw/zaCYUxMURxR+EDi5DTVJkiSVVqap+Sjw48x8JiL+GdgHWNSesiRJUrPcpbswrMSyn2g0NAcDrwcuBb7SlqokSZJKKtPUrGj8PBq4KDN/iAfhkySp6zL7OnrpVWWamocj4j+BE4FrImKjks+XJElqmzJzak4EjgTOzcynImJr4CPtKUuSJDXLOTWFMsepWQRM63f7UYqTXEqSJHVdmaRGkiT1IJOagnNiJElSLZjUSJJUcX09vEdSJ5nUSJKkWjCpkSSp4pxTUzCpkSRJtWBSI0lSxWWfc2rApEaSJNWESY0kSRXnnJqCSY0kSaoFmxpJklQLDj9JklRx6cH3AJMaSZJUEyY1kiRVXJ8ThQGTGkmSVBMmNZIkVZwH3yuY1EiSpFowqZEkqeI8+F7BpEaSJNWCSY0kSRXncWoKJjWSJKkWTGokSao459QUTGokSVItmNRIklRxHqemYFIjSZJqITLXz3G4iJiSmVO7XUcVuK2a43ZqntuqOW6n5ridtNL6nNRM6XYBFeK2ao7bqXluq+a4nZrjdhKwfjc1kiSpRmxqJElSLazPTY3jr81zWzXH7dQ8t1Vz3E7NcTsJWI8nCkuSpHpZn5MaSZJUIzY1kiSpFmxqJHVERDzb7Rp6WUScFhEbd7sOqcoq1dS040MfEW+MiN3W4XlviIgzWlnLOtSwY0TcVWL5d0XENk0s86Uh1nVmRLxuKK+h9UNEDO92DT3kNMCmRhqCSjU1tOdD/0ZgwKYmIgY9N1ZmXpmZn2txLe32LuBFm5pWyMxPZubP2r2eVoqIf46I+yLipxHxnYj4cET8bUTcEhG3R8QPVjbUEXFpRHwlIn4eEQ9ExCERcXFE3BsRl/Z7zWcj4vMRMSMifhYR+0fELxrPeUNjmR0j4pcR8bvG5dVd2gQdExGHNrbdt4E7u11PN0TE6Ii4uvF/666I+BTFZ/PnEfHzxjKHR8TNjf8Xl0fEJo37H2z8v/pt4/LSbr6XdhpgO72l8f4nNB7fLyJ+0bj+6Yi4LCJ+0ljm+Ig4JyLujIgfR8SIrr4ZdUTPNjWd+NA3vkDeAHwhIm6LiJ0bXzpnRcT1wKkRcWxE/CYibm18MW3ZeO6qRKPxJXd+RPyq8YX15g5sopU2aHyQ74iIKyJi44j4ZOPL+K6ImBqFNwP7Ad9qvNdREfHKRs23N7bTmMZrbtP4JXB/RJwz2IojYnjjvd/V+MXxj437L42INzd+4dzWuNwZEdl4fOfG689ofKHv2vat9CIiYj/gBGBv4HiK7QQwLTNfmZl7AfcCJ/d72ljgMOAfgauAfwd2B14REZMby4wGfpGZ+wILgM8CrwfeBJzZWOZx4PWZuQ/wFuD8drzHHrQ/8PHMLJ2S1sSRwCOZuVdm7gH8B/AI8BeZ+ReNL+1PAK9r/N+YDnyw3/Ofycz9gS81nltXq2+nH69l+Z2Bo4HjgG8CP8/MVwCLG/er5nq2qaEDH/rM/BVwJfCRzJycmX9oPLR5Zh6SmV8EbgQOyMy9ge8Cpw9S79bAwcAxQCcTnEnA1MzcE3gG+HvgS40v4z2AUcAxmXkFxTZ6W2ZOBlYA3wNObXxpv47igw8wmeIL9hXAWyJiu0HWPRmYmJl7NH5xXNL/wcyc3tiukyl+GZ3
beGgq8A+NL/sPAxcObRMM2cHADzNzcWYuoGhSAPZoNF13Am+jaFpWuiqL4yHcCTyWmXdmZh9wN7BjY5mlPP9L+E7g+sxc1ri+cpkRwFcb67icQVLDGvptZs7udhFddCfwusYfX6/JzKdXe/wAiv8LN0XEbcA7gR36Pf6dfj8PbHexXbS27bS6/+73GRvOCz9/O7avTPWKQYdXesCdwLkR8XngR5n5y4jo/3j/Dz3AhsDN/R7v/6H/95Lr/l6/69sC34uIrRvrGOwX8f9rfKndszLN6ZA5mXlT4/o3gQ8AsyPidIqhunEUX7RXrfa8ScCjmXkLQGY+A9DYltet/OUREfdQ/DKdM8C6HwD+LCIuAK4GfjJQgRFxIrAPcHgjTXs1cHm/f8+NSr7nVotB7r8UeGNm3h4R7wIO7ffYksbPvn7XV95e+blals8fCGrVcpnZF88Pbf4j8BiwF8UfGc+t87uoloXdLqCbMvP3EbEvcBRwdkSs/tkJ4KeZ+dbBXmKQ67UyyHZazvN/kI9c7Sn9P2Orf/56+ftOLdKzSU1m/h7Yl6K5OTsiPrnaIis/9JMbl90ys//wwFA+9P1/4V5AkXy8Angva36IVur/xTbYl2Q7rP7ekiL5eHOj5q8ycM0xwHNX6v9eVjDIL4PMfJLiy/gXwPuBr62xkojdgc8AJ2XmCor/c0/1+3ebnJkvH6SOTrkRODYiRjaarpUx9Rjg0cZY/NvatO7NKJrLPuDtFH9dquaimLC/KDO/SZFg7kMxRLlyCPjXwEErh84bw8ov6/cSb+n3s/8fc7UyyHZ6kOK7AYphY2mVnm1qOvih7/+aA9kMeLhx/Z2l3kRnbB8RK+Pnt1J8QQPMa3xB95/f0/+93kcxd+aVABExJl5kYvRAGkOAwzLzB8A/U/wb9X98M4ohu3dk5p9gVSI0OyL+qrFMRMReZdbbao206krgdmAaxTDd0xTv6TfATym2VztcCLwzIn4NvIz1PMFYj7wC+G1jaOnjFPOtpgL/HRE/b3xe3gV8JyLuoPh913/u2UYR8RvgVIq0r64G2k6fAc6LiF9S/NElrdKzp0mIiCOAL1DEhsuAv6MYO34/xV+2fxERhwGf5/nhi09k5pUR8SDF/I6jKBq3t2bmrEHWcxBFmrGEogH4L+DDmTm98fhxFMNXD1P8YnllZh7aGI7YLzNPiWKPlx815q0QEc9m5iat3B6D1L4jcA1wA8WQzv0Uf+3/E3ASxV80c4D/zcxPR8QJwFkUc2cOBPagSKJGNe57HcU22C8zT2ms40fAuZn5iwHWvxfFdl7ZHH8sM/975fagmCh7AcUwFQCZOTkidgK+QjEPaQTw3cw8ky6KiE0y89ko9nC6AZiSmb/rZk3SQBq/3/bLzHndrkXqNT3b1AyFH3qVFcXuxbtRDNVdlplnd7kkaUD+fpMGZ1MjSZJqoZZNzUAi4uPAX6129+WZ+a/dqKdqGuP3q++l9PbMXC8PniZJ6j3rTVMjSZLqrWf3fpIkSSrDpkaSJNWCTY0kSaoFmxpJklQL/x9GY3OfJz7iYAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] @@ -3335,28 +3111,12 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df_DQN.corr()[df_DQN.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### DoubleDQN" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "df_DoubleDQN = df[df[\"algo\"] == \"DoubleDQN\"].copy()" + "sns.heatmap(df_DQN.corr()[abs(df_DQN.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -3380,186 +3140,910 @@ " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " step\n", + " sum\n", + " \n", + " \n", + " algo\n", " step_train\n", " batch_size\n", " gamma\n", + " greedy_exploration\n", + " network\n", + " optimizer\n", " lr\n", - " step\n", - " max\n", - " min\n", - " avg\n", - " sum\n", + " memories\n", + " max_size\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " step_train\n", - " 1.000000e+00\n", - " 6.254152e-18\n", - " -3.338836e-15\n", - " -1.947613e-17\n", - " 5.190625e-20\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -2.016153e-01\n", + " DQN\n", + " 1.0\n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 21\n", + " 21\n", + " 21\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 20\n", + " 20\n", + " 20\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 17\n", + " 17\n", + " 17\n", + " \n", + " \n", + " 
EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 15\n", + " 15\n", + " 15\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 14\n", + " 14\n", + " 14\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 14\n", + " 14\n", + " 14\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 14\n", + " 14\n", + " 14\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 13\n", + " 13\n", + " 13\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 11\n", + " 11\n", + " 11\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 11\n", + " 11\n", + " 11\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 11\n", + " 11\n", + " 11\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 10\n", + " 10\n", + " 10\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 8\n", + " 8\n", + " 8\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 
0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 7\n", + " 7\n", + " 7\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 6\n", + " 6\n", + " 6\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 6\n", + " 6\n", + " 6\n", + " \n", + " \n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 
0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", 
+ " 2\n", + " 2\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", 
+ " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " batch_size\n", - " 6.254152e-18\n", - " 1.000000e+00\n", - " -5.096548e-15\n", - " 1.309431e-16\n", - " 0.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.556979e-01\n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " gamma\n", - " -3.338836e-15\n", - " -5.096548e-15\n", - " 1.000000e+00\n", - " 1.070395e-14\n", - " -1.030754e-16\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 2.291126e-16\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " lr\n", - " -1.947613e-17\n", - " 1.309431e-16\n", - " 1.070395e-14\n", - " 1.000000e+00\n", - " -3.622535e-19\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -1.461935e-01\n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " step\n", - " 5.190625e-20\n", - " 0.000000e+00\n", - " -1.030754e-16\n", - " -3.622535e-19\n", - " 1.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.829836e-01\n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " max\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 
1\n", + " 1\n", " \n", " \n", - " min\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " avg\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " sum\n", - " -2.016153e-01\n", - " 1.556979e-01\n", - " 2.291126e-16\n", - " -1.461935e-01\n", - " 1.829836e-01\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.000000e+00\n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", "\n", "" ], "text/plain": [ - " step_train batch_size gamma lr \\\n", - "step_train 1.000000e+00 6.254152e-18 -3.338836e-15 -1.947613e-17 \n", - "batch_size 6.254152e-18 1.000000e+00 -5.096548e-15 1.309431e-16 \n", - "gamma -3.338836e-15 -5.096548e-15 1.000000e+00 1.070395e-14 \n", - "lr -1.947613e-17 1.309431e-16 1.070395e-14 1.000000e+00 \n", - "step 5.190625e-20 0.000000e+00 -1.030754e-16 -3.622535e-19 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum -2.016153e-01 1.556979e-01 2.291126e-16 -1.461935e-01 \n", + " \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 21 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 20 \n", 
+ " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 17 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 15 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 14 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 14 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 13 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 11 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 10 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 8 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 7 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 6 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 0.0001 ExperienceReplay 512 5 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 
32.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 
ExperienceReplay 2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", "\n", - " step max min avg sum \n", - "step_train 5.190625e-20 NaN NaN NaN -2.016153e-01 \n", - "batch_size 0.000000e+00 NaN NaN NaN 1.556979e-01 \n", - "gamma -1.030754e-16 NaN NaN NaN 2.291126e-16 \n", - "lr -3.622535e-19 NaN NaN NaN -1.461935e-01 \n", - "step 1.000000e+00 NaN NaN NaN 1.829836e-01 \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum 1.829836e-01 NaN NaN NaN 1.000000e+00 " + " step \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 21 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 20 \n", + " 0.99 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 17 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 15 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 14 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 14 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 13 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 11 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 10 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 8 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 7 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 6 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 0.0001 ExperienceReplay 512 5 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 32.0 0.95 
EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 
2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + "\n", + " sum \n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DQN 1.0 32.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 21 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 20 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 17 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 15 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 14 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 14 \n", + " 32.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 14 \n", + " 64.0 
1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 13 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 11 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 11 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 10 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 8 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 7 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 6 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 0.0001 ExperienceReplay 512 5 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 5 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " 
EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 3 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 3 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 
ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 " ] }, - "execution_count": 21, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_DoubleDQN.corr()" + "columns = [\"algo\",\"step_train\",\"batch_size\",\"gamma\",\"greedy_exploration\",\"network\",\"optimizer\",\"lr\",\"memories\",\"max_size\"]\n", + "df_DQN[df_DQN[\"sum\"] >= 500].groupby(by=columns, observed=True).count().sort_values(by=['sum'], ascending=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### DoubleDQN" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### SimpleNetwork" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(500.0, 8.0)" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "max(df_DoubleDQN[\"sum\"]), min(df_DoubleDQN[\"sum\"])" + "df_DoubleDQN = df[df[\"algo\"] == \"DoubleDQN\"].copy()\n", + "df_DoubleDQN = df_DoubleDQN[df_DoubleDQN[\"network\"] == \"SimpleNetwork\"]" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -3589,483 +4073,230 @@ " gamma\n", " greedy_exploration\n", " network\n", - " \n", - " optimizer\n", - " lr\n", - " 
memories\n", - " max_size\n", - " step\n", - " max\n", - " min\n", - " avg\n", - " sum\n", - " \n", - " \n", - " \n", - " \n", - " 2117\n", - " DoubleDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 1953\n", - " DoubleDQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 2118\n", - " DoubleDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 2089\n", - " DoubleDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 386.0\n", - " \n", - " \n", - " 1938\n", - " DoubleDQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 380.0\n", - " \n", - " \n", - " 2029\n", - " DoubleDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 379.0\n", - " \n", - " \n", - " 2008\n", - " DoubleDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 32\n", 
- " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 369.0\n", - " \n", - " \n", - " 1962\n", - " DoubleDQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 347.0\n", - " \n", - " \n", - " 3132\n", - " DoubleDQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 337.0\n", + " \n", + " optimizer\n", + " lr\n", + " memories\n", + " max_size\n", + " step\n", + " sum\n", " \n", + " \n", + " \n", " \n", - " 3134\n", + " 12930\n", " DoubleDQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 32\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", " 1.0\n", - " 326.0\n", - " \n", - " \n", - " 3044\n", - " DoubleDQN\n", - " 4.0\n", " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", " SimpleNetwork\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 32\n", + " 512\n", + " 30.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 316.0\n", " \n", " \n", - " 1818\n", + " 11412\n", " DoubleDQN\n", " 1.0\n", " 32.0\n", " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " EpsilonGreedy-0.1\n", " SimpleNetwork\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 314.0\n", + " 2048\n", + " 40.0\n", + " 500.0\n", " \n", " \n", - " 3133\n", + " 24122\n", " DoubleDQN\n", - " 4.0\n", + " 32.0\n", " 64.0\n", " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.1000\n", " 
ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 310.0\n", + " 2048\n", + " 40.0\n", + " 500.0\n", " \n", " \n", - " 2138\n", + " 14796\n", " DoubleDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", + " 0.95\n", " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 309.0\n", + " 512\n", + " 90.0\n", + " 500.0\n", " \n", " \n", - " 1939\n", + " 17742\n", " DoubleDQN\n", " 1.0\n", - " 32.0\n", - " 0.99\n", + " 64.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 128\n", + " 2048\n", + " 100.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 294.0\n", " \n", " \n", - " 2028\n", + " 14802\n", " DoubleDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 293.0\n", + " 512\n", + " 150.0\n", + " 500.0\n", " \n", " \n", - " 3223\n", + " 15580\n", " DoubleDQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.1000\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 293.0\n", + " 2048\n", + " 180.0\n", + " 500.0\n", " \n", " \n", - " 1951\n", + " 17719\n", " DoubleDQN\n", " 1.0\n", - " 32.0\n", - " 0.99\n", + " 64.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 287.0\n", + " 512\n", + " 180.0\n", + " 500.0\n", " \n", " \n", - " 1952\n", + " 15922\n", " DoubleDQN\n", " 1.0\n", - " 32.0\n", + " 64.0\n", 
" 0.99\n", - " EpsilonGreedy-0.6\n", + " EpsilonGreedy-0.1\n", " SimpleNetwork\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 283.0\n", + " 512\n", + " 190.0\n", + " 500.0\n", " \n", " \n", - " 1957\n", + " 17720\n", " DoubleDQN\n", " 1.0\n", - " 32.0\n", - " 0.99\n", + " 64.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " SimpleNetwork\n", " \n", " Adam\n", - " 0.0010\n", + " 0.0001\n", " ExperienceReplay\n", - " 16\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 283.0\n", + " 512\n", + " 190.0\n", + " 500.0\n", " \n", " \n", "\n", "" ], "text/plain": [ - " algo step_train batch_size gamma \\\n", - "2117 DoubleDQN 1.0 64.0 0.99 \n", - "1953 DoubleDQN 1.0 32.0 0.99 \n", - "2118 DoubleDQN 1.0 64.0 0.99 \n", - "2089 DoubleDQN 1.0 64.0 0.99 \n", - "1938 DoubleDQN 1.0 32.0 0.99 \n", - "2029 DoubleDQN 1.0 64.0 0.99 \n", - "2008 DoubleDQN 1.0 64.0 0.99 \n", - "1962 DoubleDQN 1.0 32.0 0.99 \n", - "3132 DoubleDQN 4.0 64.0 0.99 \n", - "3134 DoubleDQN 4.0 64.0 0.99 \n", - "3044 DoubleDQN 4.0 32.0 0.99 \n", - "1818 DoubleDQN 1.0 32.0 0.99 \n", - "3133 DoubleDQN 4.0 64.0 0.99 \n", - "2138 DoubleDQN 1.0 64.0 0.99 \n", - "1939 DoubleDQN 1.0 32.0 0.99 \n", - "2028 DoubleDQN 1.0 64.0 0.99 \n", - "3223 DoubleDQN 4.0 64.0 0.99 \n", - "1951 DoubleDQN 1.0 32.0 0.99 \n", - "1952 DoubleDQN 1.0 32.0 0.99 \n", - "1957 DoubleDQN 1.0 32.0 0.99 \n", + " algo step_train batch_size gamma \\\n", + "12930 DoubleDQN 1.0 32.0 1.00 \n", + "11412 DoubleDQN 1.0 32.0 0.99 \n", + "24122 DoubleDQN 32.0 64.0 0.99 \n", + "14796 DoubleDQN 1.0 64.0 0.95 \n", + "17742 DoubleDQN 1.0 64.0 1.00 \n", + "14802 DoubleDQN 1.0 64.0 0.95 \n", + "15580 DoubleDQN 1.0 64.0 0.99 \n", + "17719 DoubleDQN 1.0 64.0 1.00 \n", + "15922 DoubleDQN 1.0 64.0 0.99 \n", + "17720 DoubleDQN 1.0 64.0 1.00 \n", "\n", - " greedy_exploration network optimizer \\\n", - "2117 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1953 EpsilonGreedy-0.6 SimpleNetwork 
Adam \n", - "2118 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "2089 EpsilonGreedy-0.1 SimpleNetwork Adam \n", - "1938 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "2029 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "2008 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "1962 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "3132 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "3134 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "3044 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1818 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleNetwork Adam \n", - "3133 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "2138 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1939 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "2028 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleNetwork Adam \n", - "3223 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1951 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1952 EpsilonGreedy-0.6 SimpleNetwork Adam \n", - "1957 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + " greedy_exploration network optimizer \\\n", + "12930 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "11412 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "24122 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam \n", + "14796 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "17742 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "14802 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "15580 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam \n", + "17719 EpsilonGreedy-0.6 SimpleNetwork Adam \n", + "15922 EpsilonGreedy-0.1 SimpleNetwork Adam \n", + "17720 EpsilonGreedy-0.6 SimpleNetwork Adam \n", "\n", - " lr memories max_size step max min avg sum \n", - "2117 0.0001 ExperienceReplay 128 332.0 1.0 1.0 1.0 500.0 \n", - "1953 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "2118 0.0001 ExperienceReplay 128 498.0 1.0 1.0 1.0 500.0 \n", - "2089 0.0010 ExperienceReplay 128 500.0 1.0 1.0 1.0 386.0 \n", - "1938 0.0001 
ExperienceReplay 128 498.0 1.0 1.0 1.0 380.0 \n", - "2029 0.0001 ExperienceReplay 128 500.0 1.0 1.0 1.0 379.0 \n", - "2008 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 369.0 \n", - "1962 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 347.0 \n", - "3132 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 337.0 \n", - "3134 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 326.0 \n", - "3044 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 316.0 \n", - "1818 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 314.0 \n", - "3133 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 310.0 \n", - "2138 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 309.0 \n", - "1939 0.0001 ExperienceReplay 128 500.0 1.0 1.0 1.0 294.0 \n", - "2028 0.0001 ExperienceReplay 128 498.0 1.0 1.0 1.0 293.0 \n", - "3223 0.0010 ExperienceReplay 32 498.0 1.0 1.0 1.0 293.0 \n", - "1951 0.0010 ExperienceReplay 128 166.0 1.0 1.0 1.0 287.0 \n", - "1952 0.0010 ExperienceReplay 128 332.0 1.0 1.0 1.0 283.0 \n", - "1957 0.0010 ExperienceReplay 16 332.0 1.0 1.0 1.0 283.0 " + " lr memories max_size step sum \n", + "12930 0.0010 ExperienceReplay 512 30.0 500.0 \n", + "11412 0.0010 ExperienceReplay 2048 40.0 500.0 \n", + "24122 0.1000 ExperienceReplay 2048 40.0 500.0 \n", + "14796 0.0010 ExperienceReplay 512 90.0 500.0 \n", + "17742 0.0010 ExperienceReplay 2048 100.0 500.0 \n", + "14802 0.0010 ExperienceReplay 512 150.0 500.0 \n", + "15580 0.1000 ExperienceReplay 2048 180.0 500.0 \n", + "17719 0.0001 ExperienceReplay 512 180.0 500.0 \n", + "15922 0.0010 ExperienceReplay 512 190.0 500.0 \n", + "17720 0.0001 ExperienceReplay 512 190.0 500.0 " ] }, - "execution_count": 23, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_DoubleDQN.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(20)" + "df_DoubleDQN.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(10)" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ 
-4074,13 +4305,13 @@ "" ] }, - "execution_count": 68, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -4093,28 +4324,12 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df_DoubleDQN.corr()[df_DoubleDQN.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### DuelingDQN" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "df_DuelingDQN = df[df[\"algo\"] == \"DuelingDQN\"].copy()" + "sns.heatmap(df_DQN.corr()[abs(df_DQN.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -4138,186 +4353,373 @@ " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " step\n", + " sum\n", + " \n", + " \n", + " algo\n", " step_train\n", " batch_size\n", " gamma\n", + " greedy_exploration\n", + " network\n", + " optimizer\n", " lr\n", - " step\n", - " max\n", - " min\n", - " avg\n", - " sum\n", + " memories\n", + " max_size\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " step_train\n", - " 1.000000e+00\n", - " 6.254152e-18\n", - " -3.338836e-15\n", - " -1.947613e-17\n", - " 5.190625e-20\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -2.254267e-01\n", + " DoubleDQN\n", + " 1.0\n", + " 64.0\n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 
2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " batch_size\n", - " 6.254152e-18\n", - " 1.000000e+00\n", - " -5.096548e-15\n", - " 1.309431e-16\n", - " 0.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.375869e-01\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " gamma\n", - " -3.338836e-15\n", - " -5.096548e-15\n", - " 1.000000e+00\n", - " 1.070395e-14\n", - " -1.030754e-16\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -4.766649e-16\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " lr\n", - " -1.947613e-17\n", - " 1.309431e-16\n", - " 1.070395e-14\n", - " 1.000000e+00\n", - " -3.622535e-19\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -2.261230e-01\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " step\n", - " 5.190625e-20\n", - " 0.000000e+00\n", - " -1.030754e-16\n", - " -3.622535e-19\n", - " 1.000000e+00\n", - " 
NaN\n", - " NaN\n", - " NaN\n", - " 2.421773e-01\n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " max\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " min\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " avg\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " sum\n", - " -2.254267e-01\n", - " 1.375869e-01\n", - " -4.766649e-16\n", - " -2.261230e-01\n", - " 2.421773e-01\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.000000e+00\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.1\n", + " SimpleNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleNetwork\n", + " Adam\n", + 
" 0.1000\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", "\n", "" ], "text/plain": [ - " step_train batch_size gamma lr \\\n", - "step_train 1.000000e+00 6.254152e-18 -3.338836e-15 -1.947613e-17 \n", - "batch_size 6.254152e-18 1.000000e+00 -5.096548e-15 1.309431e-16 \n", - "gamma -3.338836e-15 -5.096548e-15 1.000000e+00 1.070395e-14 \n", - "lr -1.947613e-17 1.309431e-16 1.070395e-14 1.000000e+00 \n", - "step 5.190625e-20 0.000000e+00 -1.030754e-16 -3.622535e-19 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum -2.254267e-01 1.375869e-01 -4.766649e-16 -2.261230e-01 \n", + " \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.0001 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.0010 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 \n", + " 32.0 0.95 
EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 \n", + "\n", + " step \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.0001 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.0010 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 
SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 \n", "\n", - " step max min avg sum \n", - "step_train 5.190625e-20 NaN NaN NaN -2.254267e-01 \n", - "batch_size 0.000000e+00 NaN NaN NaN 1.375869e-01 \n", - "gamma -1.030754e-16 NaN NaN NaN -4.766649e-16 \n", - "lr -3.622535e-19 NaN NaN NaN -2.261230e-01 \n", - "step 1.000000e+00 NaN NaN NaN 2.421773e-01 \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum 2.421773e-01 NaN NaN NaN 1.000000e+00 " + " sum \n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 4 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 3 \n", + " 0.0001 ExperienceReplay 512 2 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 2 \n", + " 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 0.0010 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 0.95 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 \n", + " 32.0 0.95 EpsilonGreedy-0.6 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 0.99 EpsilonGreedy-0.1 SimpleNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 32.0 64.0 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleNetwork Adam 0.1000 ExperienceReplay 2048 1 " ] }, - "execution_count": 25, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_DuelingDQN.corr()" + "columns = [\"algo\",\"step_train\",\"batch_size\",\"gamma\",\"greedy_exploration\",\"network\",\"optimizer\",\"lr\",\"memories\",\"max_size\"]\n", + "df_DoubleDQN[df_DoubleDQN[\"sum\"] >= 500].groupby(by=columns, observed=True).count().sort_values(by=['sum'], ascending=False)" ] }, { - "cell_type": "code", - "execution_count": 26, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(500.0, 8.0)" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "max(df_DuelingDQN[\"sum\"]), min(df_DuelingDQN[\"sum\"])" + "#### DuelingNetwork" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "df_DoubleDQN = df[df[\"algo\"] == \"DoubleDQN\"].copy()\n", + "df_DoubleDQN = df_DoubleDQN[df_DoubleDQN[\"network\"] == \"SimpleDuelingNetwork\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -4353,392 +4755,169 @@ " memories\n", " max_size\n", " step\n", - " max\n", - " min\n", - " avg\n", " 
sum\n", " \n", " \n", " \n", " \n", - " 6173\n", - " DuelingDQN\n", - " 4.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 6448\n", - " DuelingDQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 500.0\n", - " \n", - " \n", - " 5087\n", - " DuelingDQN\n", + " 12340\n", + " DoubleDQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 421.0\n", - " \n", - " \n", - " 6354\n", - " DuelingDQN\n", - " 4.0\n", - " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 16\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 386.0\n", - " \n", - " \n", - " 4932\n", - " DuelingDQN\n", - " 1.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 379.0\n", - " \n", - " \n", - " 4918\n", - " DuelingDQN\n", - " 1.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 376.0\n", - " \n", - " \n", - " 5274\n", - " DuelingDQN\n", - " 1.0\n", - " 64.0\n", - " 0.99\n", - " 
AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 16\n", + " 2048\n", + " 20.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 372.0\n", " \n", " \n", - " 6449\n", - " DuelingDQN\n", - " 4.0\n", + " 17645\n", + " DoubleDQN\n", + " 1.0\n", " 64.0\n", - " 0.99\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.1000\n", " ExperienceReplay\n", - " 32\n", + " 512\n", + " 60.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 337.0\n", - " \n", - " \n", - " 5017\n", - " DuelingDQN\n", - " 1.0\n", - " 1.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 16\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 335.0\n", " \n", " \n", - " 5358\n", - " DuelingDQN\n", + " 13775\n", + " DoubleDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0001\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 330.0\n", - " \n", - " \n", - " 5187\n", - " DuelingDQN\n", - " 1.0\n", - " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 329.0\n", + " 2048\n", + " 110.0\n", + " 500.0\n", " \n", " \n", - " 5088\n", - " DuelingDQN\n", + " 9747\n", + " DoubleDQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 328.0\n", + " 2048\n", + " 130.0\n", + " 500.0\n", " \n", " \n", - " 5094\n", - " 
DuelingDQN\n", + " 12320\n", + " DoubleDQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", + " 512\n", + " 130.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 317.0\n", " \n", " \n", - " 5177\n", - " DuelingDQN\n", + " 13777\n", + " DoubleDQN\n", " 1.0\n", - " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 307.0\n", + " 2048\n", + " 130.0\n", + " 500.0\n", " \n", " \n", - " 5179\n", - " DuelingDQN\n", + " 12321\n", + " DoubleDQN\n", " 1.0\n", " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 128\n", + " 512\n", + " 140.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 299.0\n", - " \n", - " \n", - " 6138\n", - " DuelingDQN\n", - " 4.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 291.0\n", " \n", " \n", - " 6188\n", - " DuelingDQN\n", - " 4.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", - " SimpleDuelingNetwork\n", - " \n", - " Adam\n", - " 0.0010\n", - " ExperienceReplay\n", - " 16\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", + " 14522\n", + " DoubleDQN\n", " 1.0\n", - " 278.0\n", - " \n", - " \n", - " 6374\n", - " DuelingDQN\n", - " 4.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", " SimpleDuelingNetwork\n", " \n", 
" Adam\n", - " 0.0010\n", + " 0.0001\n", " ExperienceReplay\n", - " 32\n", + " 2048\n", + " 140.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 273.0\n", " \n", " \n", - " 5364\n", - " DuelingDQN\n", + " 16785\n", + " DoubleDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 16\n", + " 512\n", + " 140.0\n", " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 261.0\n", " \n", " \n", - " 5367\n", - " DuelingDQN\n", + " 10803\n", + " DoubleDQN\n", " 1.0\n", - " 64.0\n", + " 32.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " SimpleDuelingNetwork\n", " \n", " Adam\n", " 0.0001\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 260.0\n", + " 2048\n", + " 150.0\n", + " 500.0\n", " \n", " \n", "\n", @@ -4746,84 +4925,54 @@ ], "text/plain": [ " algo step_train batch_size gamma \\\n", - "6173 DuelingDQN 4.0 32.0 0.99 \n", - "6448 DuelingDQN 4.0 64.0 0.99 \n", - "5087 DuelingDQN 1.0 32.0 0.99 \n", - "6354 DuelingDQN 4.0 64.0 0.99 \n", - "4932 DuelingDQN 1.0 1.0 0.99 \n", - "4918 DuelingDQN 1.0 1.0 0.99 \n", - "5274 DuelingDQN 1.0 64.0 0.99 \n", - "6449 DuelingDQN 4.0 64.0 0.99 \n", - "5017 DuelingDQN 1.0 1.0 0.99 \n", - "5358 DuelingDQN 1.0 64.0 0.99 \n", - "5187 DuelingDQN 1.0 32.0 0.99 \n", - "5088 DuelingDQN 1.0 32.0 0.99 \n", - "5094 DuelingDQN 1.0 32.0 0.99 \n", - "5177 DuelingDQN 1.0 32.0 0.99 \n", - "5179 DuelingDQN 1.0 32.0 0.99 \n", - "6138 DuelingDQN 4.0 32.0 0.99 \n", - "6188 DuelingDQN 4.0 32.0 0.99 \n", - "6374 DuelingDQN 4.0 64.0 0.99 \n", - "5364 DuelingDQN 1.0 64.0 0.99 \n", - "5367 DuelingDQN 1.0 64.0 0.99 \n", + "12340 DoubleDQN 1.0 32.0 1.00 \n", + "17645 DoubleDQN 1.0 64.0 1.00 \n", + "13775 DoubleDQN 1.0 64.0 0.95 \n", + "9747 DoubleDQN 1.0 32.0 0.95 \n", + "12320 DoubleDQN 1.0 32.0 1.00 \n", + "13777 
DoubleDQN 1.0 64.0 0.95 \n", + "12321 DoubleDQN 1.0 32.0 1.00 \n", + "14522 DoubleDQN 1.0 64.0 0.95 \n", + "16785 DoubleDQN 1.0 64.0 1.00 \n", + "10803 DoubleDQN 1.0 32.0 0.99 \n", "\n", - " greedy_exploration network \\\n", - "6173 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "6448 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5087 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "6354 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "4932 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "4918 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "5274 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "6449 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5017 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5358 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5187 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5088 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "5094 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "5177 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5179 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "6138 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 SimpleDuelingNetwork \n", - "6188 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "6374 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 SimpleDuelingNetwork \n", - "5364 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", - "5367 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + " greedy_exploration network \\\n", + "12340 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "17645 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "13775 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "9747 EpsilonGreedy-0.1 SimpleDuelingNetwork \n", + "12320 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "13777 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "12321 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 
SimpleDuelingNetwork \n", + "14522 EpsilonGreedy-0.6 SimpleDuelingNetwork \n", + "16785 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", + "10803 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork \n", "\n", - " optimizer lr memories max_size step max min avg sum \n", - "6173 Adam 0.0001 ExperienceReplay 16 498.0 1.0 1.0 1.0 500.0 \n", - "6448 Adam 0.0001 ExperienceReplay 32 498.0 1.0 1.0 1.0 500.0 \n", - "5087 Adam 0.0001 ExperienceReplay 128 332.0 1.0 1.0 1.0 421.0 \n", - "6354 Adam 0.0001 ExperienceReplay 16 500.0 1.0 1.0 1.0 386.0 \n", - "4932 Adam 0.0010 ExperienceReplay 32 332.0 1.0 1.0 1.0 379.0 \n", - "4918 Adam 0.0001 ExperienceReplay 32 498.0 1.0 1.0 1.0 376.0 \n", - "5274 Adam 0.0001 ExperienceReplay 16 500.0 1.0 1.0 1.0 372.0 \n", - "6449 Adam 0.0001 ExperienceReplay 32 500.0 1.0 1.0 1.0 337.0 \n", - "5017 Adam 0.0010 ExperienceReplay 16 332.0 1.0 1.0 1.0 335.0 \n", - "5358 Adam 0.0001 ExperienceReplay 128 498.0 1.0 1.0 1.0 330.0 \n", - "5187 Adam 0.0001 ExperienceReplay 32 332.0 1.0 1.0 1.0 329.0 \n", - "5088 Adam 0.0001 ExperienceReplay 128 498.0 1.0 1.0 1.0 328.0 \n", - "5094 Adam 0.0001 ExperienceReplay 16 500.0 1.0 1.0 1.0 317.0 \n", - "5177 Adam 0.0001 ExperienceReplay 128 332.0 1.0 1.0 1.0 307.0 \n", - "5179 Adam 0.0001 ExperienceReplay 128 500.0 1.0 1.0 1.0 299.0 \n", - "6138 Adam 0.0010 ExperienceReplay 128 498.0 1.0 1.0 1.0 291.0 \n", - "6188 Adam 0.0010 ExperienceReplay 16 498.0 1.0 1.0 1.0 278.0 \n", - "6374 Adam 0.0010 ExperienceReplay 32 500.0 1.0 1.0 1.0 273.0 \n", - "5364 Adam 0.0001 ExperienceReplay 16 500.0 1.0 1.0 1.0 261.0 \n", - "5367 Adam 0.0001 ExperienceReplay 32 332.0 1.0 1.0 1.0 260.0 " + " optimizer lr memories max_size step sum \n", + "12340 Adam 0.0010 ExperienceReplay 2048 20.0 500.0 \n", + "17645 Adam 0.1000 ExperienceReplay 512 60.0 500.0 \n", + "13775 Adam 0.0001 ExperienceReplay 2048 110.0 500.0 \n", + "9747 Adam 0.0010 ExperienceReplay 2048 130.0 500.0 \n", + "12320 Adam 0.0001 
ExperienceReplay 512 130.0 500.0 \n", + "13777 Adam 0.0001 ExperienceReplay 2048 130.0 500.0 \n", + "12321 Adam 0.0001 ExperienceReplay 512 140.0 500.0 \n", + "14522 Adam 0.0001 ExperienceReplay 2048 140.0 500.0 \n", + "16785 Adam 0.0001 ExperienceReplay 512 140.0 500.0 \n", + "10803 Adam 0.0001 ExperienceReplay 2048 150.0 500.0 " ] }, - "execution_count": 27, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_DuelingDQN.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(20)" + "df_DoubleDQN.sort_values(by =[\"sum\",\"step\"], ascending = [False, True]).head(10)" ] }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -4832,13 +4981,13 @@ "" ] }, - "execution_count": 69, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -4851,28 +5000,12 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df_DuelingDQN.corr()[df_DuelingDQN.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### CategoricalDQN" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "df_CategoricalDQN = df[df[\"algo\"] == \"CategoricalDQN\"].copy()" + "sns.heatmap(df_DoubleDQN.corr()[abs(df_DoubleDQN.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -4896,186 +5029,700 @@ " \n", " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " step\n", + " sum\n", + " \n", + " \n", + " algo\n", " step_train\n", " batch_size\n", " gamma\n", + " greedy_exploration\n", + " network\n", + " optimizer\n", " lr\n", - " step\n", - " max\n", - " min\n", - " avg\n", - " sum\n", + " memories\n", + " max_size\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " step_train\n", - " 1.000000e+00\n", - " 6.254152e-18\n", - " -3.338836e-15\n", - " -1.947613e-17\n", - " 5.190625e-20\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 6.569585e-02\n", + " DoubleDQN\n", + " 1.0\n", + " 64.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 6\n", + " 6\n", + " 6\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 6\n", + " 6\n", + " 6\n", + " \n", + " \n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 5\n", + " 5\n", + " 5\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", 
+ " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 4\n", + " 4\n", + " 4\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 3\n", + " 3\n", + " 3\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 32.0\n", + " 0.95\n", + " 
AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 64.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.1000\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 2048\n", + " 2\n", + " 2\n", + " 2\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " batch_size\n", - " 6.254152e-18\n", - " 1.000000e+00\n", - " -5.096548e-15\n", - " 1.309431e-16\n", - " 0.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -4.414052e-02\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " gamma\n", - " -3.338836e-15\n", - " -5.096548e-15\n", - " 1.000000e+00\n", - " 1.070395e-14\n", - " -1.030754e-16\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 4.203546e-16\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " lr\n", - " -1.947613e-17\n", - " 1.309431e-16\n", - " 1.070395e-14\n", - " 1.000000e+00\n", - " -3.622535e-19\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -2.343891e-02\n", + " 32.0\n", + " 0.95\n", + " 
AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " step\n", - " 5.190625e-20\n", - " 0.000000e+00\n", - " -1.030754e-16\n", - " -3.622535e-19\n", - " 1.000000e+00\n", - " NaN\n", - " NaN\n", - " NaN\n", - " -5.528580e-02\n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " max\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 32.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " min\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " avg\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", - " NaN\n", + " 0.99\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", - " sum\n", - " 6.569585e-02\n", - " -4.414052e-02\n", - " 4.203546e-16\n", - " -2.343891e-02\n", - " -5.528580e-02\n", - " NaN\n", - " NaN\n", - " NaN\n", - " 1.000000e+00\n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " 
ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 0.99\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.1\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 32.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 512\n", + " 1\n", 
+ " 1\n", + " 1\n", + " \n", + " \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " 512\n", + " 1\n", + " 1\n", + " 1\n", + " \n", + " \n", + " EpsilonGreedy-0.6\n", + " SimpleDuelingNetwork\n", + " Adam\n", + " 0.0001\n", + " ExperienceReplay\n", + " 2048\n", + " 1\n", + " 1\n", + " 1\n", " \n", " \n", "\n", "" ], "text/plain": [ - " step_train batch_size gamma lr \\\n", - "step_train 1.000000e+00 6.254152e-18 -3.338836e-15 -1.947613e-17 \n", - "batch_size 6.254152e-18 1.000000e+00 -5.096548e-15 1.309431e-16 \n", - "gamma -3.338836e-15 -5.096548e-15 1.000000e+00 1.070395e-14 \n", - "lr -1.947613e-17 1.309431e-16 1.070395e-14 1.000000e+00 \n", - "step 5.190625e-20 0.000000e+00 -1.030754e-16 -3.622535e-19 \n", - "max NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN \n", - "sum 6.569585e-02 -4.414052e-02 4.203546e-16 -2.343891e-02 \n", + " \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 32.0 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 6 \n", + " 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 5 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 
ExperienceReplay 2048 3 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 2048 2 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 
0.0010 ExperienceReplay 512 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + "\n", + " step \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 32.0 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 6 \n", + " 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 5 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " EpsilonGreedy-0.6 
SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 2048 2 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork 
Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", "\n", - " step max min avg sum \n", - "step_train 5.190625e-20 NaN NaN NaN 6.569585e-02 \n", - "batch_size 0.000000e+00 NaN NaN NaN -4.414052e-02 \n", - "gamma -1.030754e-16 NaN NaN NaN 4.203546e-16 \n", - "lr -3.622535e-19 NaN NaN NaN -2.343891e-02 \n", - "step 1.000000e+00 NaN NaN NaN -5.528580e-02 \n", - "max NaN NaN NaN NaN NaN \n", - "min NaN NaN NaN NaN NaN \n", - "avg NaN NaN NaN NaN NaN \n", - "sum -5.528580e-02 NaN NaN NaN 1.000000e+00 " + " sum \n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "DoubleDQN 1.0 64.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 6 \n", + " 32.0 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 
ExperienceReplay 2048 6 \n", + " 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 5 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 4 \n", + " 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 4 \n", + " AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 4 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 3 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 2 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 2 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 2 \n", + " 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.1000 ExperienceReplay 512 2 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 2 \n", + " 2048 2 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 
1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 32.0 0.95 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 64.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 0.99 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 \n", + " 32.0 0.95 EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 0.99 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 1 \n", + " 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 2048 1 \n", + " 64.0 0.95 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 0.0001 ExperienceReplay 512 1 \n", + " 32.0 0.99 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " EpsilonGreedy-0.1 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " 32.0 1.00 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 512 1 \n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 2048 1 \n", + " 512 1 \n", + " EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0001 ExperienceReplay 2048 1 " ] }, - "execution_count": 29, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_CategoricalDQN.corr()" + "columns = 
[\"algo\",\"step_train\",\"batch_size\",\"gamma\",\"greedy_exploration\",\"network\",\"optimizer\",\"lr\",\"memories\",\"max_size\"]\n", + "df_DoubleDQN[df_DoubleDQN[\"sum\"] >= 500].groupby(by=columns, observed=True).count().sort_values(by=['sum'], ascending=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CategoricalDQN" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 33, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(405.0, 8.0)" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "max(df_CategoricalDQN[\"sum\"]), min(df_CategoricalDQN[\"sum\"])" + "df_CategoricalDQN = df[df[\"algo\"] == \"CategoricalDQN\"].copy()" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -5111,392 +5758,329 @@ " memories\n", " max_size\n", " step\n", - " max\n", - " min\n", - " avg\n", " sum\n", " \n", " \n", " \n", " \n", - " 567\n", + " 2672\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 64.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 405.0\n", - " \n", - " \n", - " 1110\n", - " CategoricalDQN\n", - " 4.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", - " C51Network\n", - " \n", - " Adam\n", - " 0.1000\n", - " ExperienceReplay\n", - " 128\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 108.0\n", + " 2048\n", + " 60.0\n", + " 500.0\n", " \n", " \n", - " 300\n", + " 8419\n", " CategoricalDQN\n", - " 1.0\n", " 32.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", - " 0.1000\n", + " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 
1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 105.0\n", + " 512\n", + " 180.0\n", + " 500.0\n", " \n", " \n", - " 1505\n", + " 3979\n", " CategoricalDQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 74.0\n", + " 2048\n", + " 110.0\n", + " 403.0\n", " \n", " \n", - " 685\n", + " 5525\n", " CategoricalDQN\n", " 32.0\n", - " 1.0\n", + " 32.0\n", " 0.99\n", - " EpsilonGreedy-0.6\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.1000\n", " ExperienceReplay\n", - " 32\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 71.0\n", + " 2048\n", + " 70.0\n", + " 388.0\n", " \n", " \n", - " 689\n", + " 463\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", - " EpsilonGreedy-0.6\n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 67.0\n", + " 2048\n", + " 290.0\n", + " 355.0\n", " \n", " \n", - " 543\n", + " 2163\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 59.0\n", + " 512\n", + " 240.0\n", + " 328.0\n", " \n", " \n", - " 435\n", + " 2659\n", " CategoricalDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", - " 0.1000\n", + " 0.0001\n", " ExperienceReplay\n", - " 128\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 46.0\n", + 
" 512\n", + " 240.0\n", + " 327.0\n", " \n", " \n", - " 1465\n", + " 2488\n", " CategoricalDQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 44.0\n", + " 2048\n", + " 80.0\n", + " 326.0\n", " \n", " \n", - " 544\n", + " 1975\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 128\n", - " 500.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 44.0\n", + " 512\n", + " 220.0\n", + " 316.0\n", " \n", " \n", - " 761\n", + " 2493\n", " CategoricalDQN\n", - " 32.0\n", - " 32.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 1.0\n", + " 64.0\n", + " 0.95\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", - " 0.1000\n", + " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 42.0\n", + " 2048\n", + " 130.0\n", + " 312.0\n", " \n", " \n", - " 861\n", + " 4366\n", " CategoricalDQN\n", - " 32.0\n", - " 32.0\n", - " 0.99\n", + " 1.0\n", + " 64.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 41.0\n", + " 2048\n", + " 260.0\n", + " 305.0\n", " \n", " \n", - " 55\n", + " 461\n", " CategoricalDQN\n", " 1.0\n", - " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 32.0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 38.0\n", + " 
2048\n", + " 270.0\n", + " 294.0\n", " \n", " \n", - " 1136\n", + " 3974\n", " CategoricalDQN\n", - " 4.0\n", " 1.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.8-0.2-50000-0\n", + " 64.0\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 38.0\n", + " 2048\n", + " 60.0\n", + " 292.0\n", " \n", " \n", - " 1036\n", + " 8906\n", " CategoricalDQN\n", " 32.0\n", " 64.0\n", - " 0.99\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.1000\n", " ExperienceReplay\n", - " 128\n", - " 166.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 37.0\n", + " 512\n", + " 90.0\n", + " 292.0\n", " \n", " \n", - " 702\n", + " 2158\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", + " 32.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " C51Network\n", " \n", " Adam\n", " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 37.0\n", + " 512\n", + " 190.0\n", + " 290.0\n", " \n", " \n", - " 688\n", + " 2159\n", " CategoricalDQN\n", - " 32.0\n", " 1.0\n", - " 0.99\n", + " 32.0\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 32\n", - " 498.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 36.0\n", + " 512\n", + " 200.0\n", + " 282.0\n", " \n", " \n", - " 485\n", + " 4023\n", " CategoricalDQN\n", " 1.0\n", " 64.0\n", - " 0.99\n", - " EpsilonGreedy-0.1\n", + " 1.00\n", + " AdaptativeEpsilonGreedy-0.8-0.2-10000-0\n", " C51Network\n", " \n", " Adam\n", - " 0.1000\n", + " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 34.0\n", + " 512\n", + " 240.0\n", + " 278.0\n", " \n", " \n", - " 862\n", + " 2133\n", " CategoricalDQN\n", + " 1.0\n", " 32.0\n", - " 32.0\n", - " 0.99\n", + " 1.00\n", " EpsilonGreedy-0.6\n", " 
C51Network\n", " \n", " Adam\n", - " 0.0001\n", + " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 332.0\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", - " 34.0\n", + " 2048\n", + " 250.0\n", + " 276.0\n", " \n", " \n", - " 1475\n", + " 2726\n", " CategoricalDQN\n", - " 4.0\n", + " 1.0\n", " 64.0\n", - " 0.99\n", - " AdaptativeEpsilonGreedy-0.3-0.1-50000-0\n", + " 0.95\n", + " EpsilonGreedy-0.1\n", " C51Network\n", " \n", " Adam\n", - " 0.1000\n", + " 0.0010\n", " ExperienceReplay\n", - " 16\n", - " 1.0\n", - " 1.0\n", - " 1.0\n", + " 512\n", + " 290.0\n", + " 266.0\n", + " \n", + " \n", + " 2156\n", + " CategoricalDQN\n", " 1.0\n", - " 33.0\n", + " 32.0\n", + " 1.00\n", + " EpsilonGreedy-0.6\n", + " C51Network\n", + " \n", + " Adam\n", + " 0.0010\n", + " ExperienceReplay\n", + " 512\n", + " 170.0\n", + " 261.0\n", " \n", " \n", "\n", @@ -5504,73 +6088,73 @@ ], "text/plain": [ " algo step_train batch_size gamma \\\n", - "567 CategoricalDQN 32.0 1.0 0.99 \n", - "1110 CategoricalDQN 4.0 1.0 0.99 \n", - "300 CategoricalDQN 1.0 32.0 0.99 \n", - "1505 CategoricalDQN 4.0 64.0 0.99 \n", - "685 CategoricalDQN 32.0 1.0 0.99 \n", - "689 CategoricalDQN 32.0 1.0 0.99 \n", - "543 CategoricalDQN 32.0 1.0 0.99 \n", - "435 CategoricalDQN 1.0 64.0 0.99 \n", - "1465 CategoricalDQN 4.0 64.0 0.99 \n", - "544 CategoricalDQN 32.0 1.0 0.99 \n", - "761 CategoricalDQN 32.0 32.0 0.99 \n", - "861 CategoricalDQN 32.0 32.0 0.99 \n", - "55 CategoricalDQN 1.0 1.0 0.99 \n", - "1136 CategoricalDQN 4.0 1.0 0.99 \n", - "1036 CategoricalDQN 32.0 64.0 0.99 \n", - "702 CategoricalDQN 32.0 1.0 0.99 \n", - "688 CategoricalDQN 32.0 1.0 0.99 \n", - "485 CategoricalDQN 1.0 64.0 0.99 \n", - "862 CategoricalDQN 32.0 32.0 0.99 \n", - "1475 CategoricalDQN 4.0 64.0 0.99 \n", + "2672 CategoricalDQN 1.0 64.0 0.95 \n", + "8419 CategoricalDQN 32.0 64.0 1.00 \n", + "3979 CategoricalDQN 1.0 64.0 1.00 \n", + "5525 CategoricalDQN 32.0 32.0 0.99 \n", + "463 CategoricalDQN 1.0 32.0 0.95 \n", + "2163 CategoricalDQN 1.0 32.0 
1.00 \n", + "2659 CategoricalDQN 1.0 64.0 0.95 \n", + "2488 CategoricalDQN 1.0 64.0 0.95 \n", + "1975 CategoricalDQN 1.0 32.0 1.00 \n", + "2493 CategoricalDQN 1.0 64.0 0.95 \n", + "4366 CategoricalDQN 1.0 64.0 1.00 \n", + "461 CategoricalDQN 1.0 32.0 0.95 \n", + "3974 CategoricalDQN 1.0 64.0 1.00 \n", + "8906 CategoricalDQN 32.0 64.0 1.00 \n", + "2158 CategoricalDQN 1.0 32.0 1.00 \n", + "2159 CategoricalDQN 1.0 32.0 1.00 \n", + "4023 CategoricalDQN 1.0 64.0 1.00 \n", + "2133 CategoricalDQN 1.0 32.0 1.00 \n", + "2726 CategoricalDQN 1.0 64.0 0.95 \n", + "2156 CategoricalDQN 1.0 32.0 1.00 \n", "\n", " greedy_exploration network optimizer lr \\\n", - "567 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0010 \n", - "1110 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.1000 \n", - "300 EpsilonGreedy-0.1 C51Network Adam 0.1000 \n", - "1505 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 C51Network Adam 0.0010 \n", - "685 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "689 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "543 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "435 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 C51Network Adam 0.1000 \n", - "1465 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0010 \n", - "544 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.0001 \n", - "761 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.1000 \n", - "861 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "55 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 C51Network Adam 0.0001 \n", - "1136 AdaptativeEpsilonGreedy-0.8-0.2-50000-0 C51Network Adam 0.0001 \n", - "1036 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "702 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", - "688 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "485 EpsilonGreedy-0.1 C51Network Adam 0.1000 \n", - "862 EpsilonGreedy-0.6 C51Network Adam 0.0001 \n", - "1475 AdaptativeEpsilonGreedy-0.3-0.1-50000-0 C51Network Adam 0.1000 \n", + "2672 EpsilonGreedy-0.1 C51Network 
Adam 0.0010 \n", + "8419 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0001 \n", + "3979 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0010 \n", + "5525 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.1000 \n", + "463 EpsilonGreedy-0.1 C51Network Adam 0.0010 \n", + "2163 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", + "2659 EpsilonGreedy-0.1 C51Network Adam 0.0001 \n", + "2488 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0010 \n", + "1975 EpsilonGreedy-0.1 C51Network Adam 0.0010 \n", + "2493 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0010 \n", + "4366 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", + "461 EpsilonGreedy-0.1 C51Network Adam 0.0010 \n", + "3974 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0010 \n", + "8906 EpsilonGreedy-0.6 C51Network Adam 0.1000 \n", + "2158 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", + "2159 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", + "4023 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0010 \n", + "2133 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", + "2726 EpsilonGreedy-0.1 C51Network Adam 0.0010 \n", + "2156 EpsilonGreedy-0.6 C51Network Adam 0.0010 \n", "\n", - " memories max_size step max min avg sum \n", - "567 ExperienceReplay 32 332.0 1.0 1.0 1.0 405.0 \n", - "1110 ExperienceReplay 128 1.0 1.0 1.0 1.0 108.0 \n", - "300 ExperienceReplay 128 1.0 1.0 1.0 1.0 105.0 \n", - "1505 ExperienceReplay 16 1.0 1.0 1.0 1.0 74.0 \n", - "685 ExperienceReplay 32 1.0 1.0 1.0 1.0 71.0 \n", - "689 ExperienceReplay 32 500.0 1.0 1.0 1.0 67.0 \n", - "543 ExperienceReplay 128 498.0 1.0 1.0 1.0 59.0 \n", - "435 ExperienceReplay 128 1.0 1.0 1.0 1.0 46.0 \n", - "1465 ExperienceReplay 32 1.0 1.0 1.0 1.0 44.0 \n", - "544 ExperienceReplay 128 500.0 1.0 1.0 1.0 44.0 \n", - "761 ExperienceReplay 32 166.0 1.0 1.0 1.0 42.0 \n", - "861 ExperienceReplay 16 166.0 1.0 1.0 1.0 41.0 \n", - "55 ExperienceReplay 32 1.0 1.0 1.0 1.0 38.0 \n", - "1136 ExperienceReplay 32 166.0 1.0 
1.0 1.0 38.0 \n", - "1036 ExperienceReplay 128 166.0 1.0 1.0 1.0 37.0 \n", - "702 ExperienceReplay 32 332.0 1.0 1.0 1.0 37.0 \n", - "688 ExperienceReplay 32 498.0 1.0 1.0 1.0 36.0 \n", - "485 ExperienceReplay 16 1.0 1.0 1.0 1.0 34.0 \n", - "862 ExperienceReplay 16 332.0 1.0 1.0 1.0 34.0 \n", - "1475 ExperienceReplay 16 1.0 1.0 1.0 1.0 33.0 " + " memories max_size step sum \n", + "2672 ExperienceReplay 2048 60.0 500.0 \n", + "8419 ExperienceReplay 512 180.0 500.0 \n", + "3979 ExperienceReplay 2048 110.0 403.0 \n", + "5525 ExperienceReplay 2048 70.0 388.0 \n", + "463 ExperienceReplay 2048 290.0 355.0 \n", + "2163 ExperienceReplay 512 240.0 328.0 \n", + "2659 ExperienceReplay 512 240.0 327.0 \n", + "2488 ExperienceReplay 2048 80.0 326.0 \n", + "1975 ExperienceReplay 512 220.0 316.0 \n", + "2493 ExperienceReplay 2048 130.0 312.0 \n", + "4366 ExperienceReplay 2048 260.0 305.0 \n", + "461 ExperienceReplay 2048 270.0 294.0 \n", + "3974 ExperienceReplay 2048 60.0 292.0 \n", + "8906 ExperienceReplay 512 90.0 292.0 \n", + "2158 ExperienceReplay 512 190.0 290.0 \n", + "2159 ExperienceReplay 512 200.0 282.0 \n", + "4023 ExperienceReplay 512 240.0 278.0 \n", + "2133 ExperienceReplay 2048 250.0 276.0 \n", + "2726 ExperienceReplay 512 290.0 266.0 \n", + "2156 ExperienceReplay 512 170.0 261.0 " ] }, - "execution_count": 31, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -5581,7 +6165,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -5590,13 +6174,13 @@ "" ] }, - "execution_count": 70, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -5609,7 +6193,241 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10,10)) \n", - "sns.heatmap(df_CategoricalDQN.corr()[df_CategoricalDQN.corr() > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" + "sns.heatmap(df_CategoricalDQN.corr()[abs(df_CategoricalDQN.corr()) > 0.05], annot = True, fmt='.2g',cmap= 'coolwarm', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
stepsum
algostep_trainbatch_sizegammagreedy_explorationnetworkoptimizerlrmemoriesmax_size
CategoricalDQN1.064.00.95EpsilonGreedy-0.1C51NetworkAdam0.0010ExperienceReplay20481.01.01.0
32.064.01.00AdaptativeEpsilonGreedy-0.8-0.2-10000-0C51NetworkAdam0.0001ExperienceReplay5121.01.01.0
1.064.00.95AdaptativeEpsilonGreedy-0.3-0.1-30000-0C51NetworkAdam0.0001ExperienceReplay2048NaNNaNNaN
512NaNNaNNaN
0.0010ExperienceReplay2048NaNNaNNaN
.......................................
DoubleDQN32.064.01.00EpsilonGreedy-0.6SimpleDuelingNetworkAdam0.0010ExperienceReplay512NaNNaNNaN
SimpleNetworkAdam0.0001ExperienceReplay2048NaNNaNNaN
512NaNNaNNaN
0.0010ExperienceReplay2048NaNNaNNaN
512NaNNaNNaN
\n", + "

576 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "CategoricalDQN 1.0 64.0 0.95 EpsilonGreedy-0.1 C51Network Adam 0.0010 ExperienceReplay 2048 1.0 \n", + " 32.0 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0001 ExperienceReplay 512 1.0 \n", + " 1.0 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + "... ... \n", + "DoubleDQN 32.0 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 NaN \n", + " SimpleNetwork Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + "\n", + " step \\\n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "CategoricalDQN 1.0 64.0 0.95 EpsilonGreedy-0.1 C51Network Adam 0.0010 ExperienceReplay 2048 1.0 \n", + " 32.0 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0001 ExperienceReplay 512 1.0 \n", + " 1.0 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + "... ... 
\n", + "DoubleDQN 32.0 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 NaN \n", + " SimpleNetwork Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + "\n", + " sum \n", + "algo step_train batch_size gamma greedy_exploration network optimizer lr memories max_size \n", + "CategoricalDQN 1.0 64.0 0.95 EpsilonGreedy-0.1 C51Network Adam 0.0010 ExperienceReplay 2048 1.0 \n", + " 32.0 64.0 1.00 AdaptativeEpsilonGreedy-0.8-0.2-10000-0 C51Network Adam 0.0001 ExperienceReplay 512 1.0 \n", + " 1.0 64.0 0.95 AdaptativeEpsilonGreedy-0.3-0.1-30000-0 C51Network Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + "... ... \n", + "DoubleDQN 32.0 64.0 1.00 EpsilonGreedy-0.6 SimpleDuelingNetwork Adam 0.0010 ExperienceReplay 512 NaN \n", + " SimpleNetwork Adam 0.0001 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + " 0.0010 ExperienceReplay 2048 NaN \n", + " 512 NaN \n", + "\n", + "[576 rows x 3 columns]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "columns = [\"algo\",\"step_train\",\"batch_size\",\"gamma\",\"greedy_exploration\",\"network\",\"optimizer\",\"lr\",\"memories\",\"max_size\"]\n", + "df_CategoricalDQN[df_CategoricalDQN[\"sum\"] >= 500].groupby(by=columns).count().sort_values(by=['sum'], ascending=False)" ] }, { diff --git a/results/README.md b/results/README.md index b711526..997a44c 100644 --- a/results/README.md +++ b/results/README.md @@ -6,7 +6,7 @@ There are one subdirectory by environment used. 
## CartPole -Example with agent: DQN, network: SimpleNetwork, Algo: Adam, Memorie: ExperienceReplay(max_size=32), Step train: 1, Batch size: 64, Gamma: 0.99, Exploration: EpsilonGreedy(0.1), Learning rate: 0.001 +Example with agent: DQN, network: SimpleNetwork, Algo: Adam, Memorie: ExperienceReplay(max_size=2048), Step train: 1, Batch size: 32, Gamma: 0.95, Exploration: AdaptativeEpsilonGreedy(0.8, 0.2, 10000, 0), Learning rate: 0.0001 ![CartPoleExemple.gif](./ressources/cartpole.gif) @@ -23,85 +23,125 @@ We test to train all this agent with this parameters. We train agent with 300 max_step. * Agent - * Algo : [DQN, DoubleDQN, DuelingDQN, CategoricalDQN] + * Algo : [DQN, DoubleDQN, CategoricalDQN] - * Step train : [1, 4, 32] + * Step train : [1, 32] - * Batch size : [1, 32, 64] + * Batch size : [32, 64] - * Gamma : [0.99] + * Gamma : [1.0, 0.99, 0.95] - * Exploration : [EpsilonGreedy(0.1), - EpsilonGreedy(0.6), - AdaptativeEpsilonGreedy(0.3, 0.1, 50000, 0), - AdaptativeEpsilonGreedy(0.8, 0.2, 50000,² 0)] + * Exploration : [EpsilonGreedy(0.1), EpsilonGreedy(0.6), AdaptativeEpsilonGreedy(0.3, 0.1, 30000, 0), + AdaptativeEpsilonGreedy(0.8, 0.2, 10000, 0)] + * Network -For _DQN, DoubleDQN_ : SimpleNetwork - -For _DuelingDQN_ : SimpleDuelingNetwork +For _DQN, DoubleDQN_ : SimpleNetwork, SimpleDuelingNetwork For _CategoricalDQN_ : C51Network + * Optimizer * Algo : Adam - * Learning rate : [0.1, 0.001, 0.001] + * Learning rate : [0.1, 0.001, 0.0001] * Memories * Algo : [ExperienceReplay] + * max_size: [512, 2048] ### Result analysis -We stop training at 500 step. We can find differant result if we add more training step. +We stop training at 300 step. We can find different result if we add more training step. #### Agent performance -DQN, DoubleDQN and DuelingDQN can reach in 500(max) steps. CategoricalDQN reach ~400 steps. +DQN, DoubleDQN, CategoricalDQN can reach 500 rewards in 300 (max) steps. * DQN -500 step is read with after 166 episode with this parameters. 
+500 step is reach with after 30 episode with this parameters. + +with SimpleNetwork: + +| algo | step_train | batch_size | gamma | greedy_exploration | network | optimizer | lr | memories | max_size | step | max | min | avg | sum | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +|DQN|1.0|32.0|1.00|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.001|ExperienceReplay|2048|30.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.001|ExperienceReplay|2048|30.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.100|ExperienceReplay|2048|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.95|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.001|ExperienceReplay|512|50.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|1.00|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|512|50.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.99|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|2048|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.100|ExperienceReplay|2048|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|1.00|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleNetwork|Adam|0.100|ExperienceReplay|512|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|1.00|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.100|ExperienceReplay|512|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.95|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|512|70.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|2048|70.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|1.00|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|2048|70.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.001|ExperienceReplay|512|80.0|1.0|1.0|1.0|500.0| 
+|DQN|1.0|32.0|0.99|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleNetwork|Adam|0.100|ExperienceReplay|512|80.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|1.00|EpsilonGreedy-0.6|SimpleNetwork| Adam|0.001|ExperienceReplay|512|80.0|1.0|1.0|1.0|500.0 + +with SimpleDuelingNetwork: | algo | step_train | batch_size | gamma | greedy_exploration | network | optimizer | lr | memories | max_size | step | max | min | avg | sum | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DQN | 1.0 | 32.0 | 0.99 | AdaptativeEpsilonGreedy-0.3-0.1-50000-0 | SimpleNetwork | Adam | 0.0010 | ExperienceReplay | 128 | 166.0 | 1.0 | 1.0 | 1.0 | 500.0 | -| DQN | 1.0 | 64.0 | 0.99 | EpsilonGreedy-0.1 | SimpleNetwork | Adam | 0.0010 | ExperienceReplay | 16 | 166.0 | 1.0 | 1.0 | 1.0 | 500.0 | -| DQN | 32.0 | 64.0 | 0.99 | AdaptativeEpsilonGreedy-0.3-0.1-50000-0 | SimpleNetwork | Adam | 0.1000 | ExperienceReplay | 16 | 166.0 | 1.0 | 1.0 | 1.0 | 500.0 | +|DQN|1.0|64.0|0.95|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|30.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.95|EpsilonGreedy-0.6|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.99|EpsilonGreedy-0.1|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|1.00|EpsilonGreedy-0.1|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.95|EpsilonGreedy-0.6|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|EpsilonGreedy-0.6|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|40.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.95|EpsilonGreedy-0.1|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|50.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|1.00|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|50.0|1.0|1.0|1.0|500.0| 
+|DQN|1.0|64.0|0.95|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|50.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|50.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|32.0|1.00|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|512|60.0|1.0|1.0|1.0|500.0| +|DQN|1.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleDuelingNetwork|Adam|0.001|ExperienceReplay|2048|60.0|1.0|1.0|1.0|500.0| * DoubleDQN -Same for doubleDQN but less training reach 500 step and need more episode to reach that. +500 step is reach with after 20 episode with this parameters with SimpleDuelingNetwork. + +with SimpleNetwork: | algo | step_train | batch_size | gamma | greedy_exploration | network | optimizer | lr | memories | max_size | step | max | min | avg | sum | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -|DoubleDQN|1.0|64.0|0.99|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.0001|ExperienceReplay|128|332.0|1.0|1.0|1.0|500.0| - -* DuelingDQN +|DoubleDQN|1.0|32.0|1.00|EpsilonGreedy-0.1|SimpleNetwork|Adam|0.0010|ExperienceReplay|512|30.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|32.0|0.99|EpsilonGreedy-0.1|SimpleNetwork|Adam|0.0010|ExperienceReplay|2048|40.0|1.0|1.0|1.0|500.0| +|DoubleDQN|32.0|64.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-30000-0|SimpleNetwork|Adam|0.1000|ExperienceReplay|2048|40.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|0.95|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.0010|ExperienceReplay|512|90.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|1.00|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.0010|ExperienceReplay|2048|100.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|0.95|EpsilonGreedy-0.6|SimpleNetwork|Adam|0.0010|ExperienceReplay|512|150.0|1.0|1.0|1.0|500.0| -Only two 
train reach 500 for duelingDQN +with SimpleDuelingNetwork: | algo | step_train | batch_size | gamma | greedy_exploration | network | optimizer | lr | memories | max_size | step | max | min | avg | sum | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -|DuelingDQN|4.0|32.0|0.99|AdaptativeEpsilonGreedy-0.8-0.2-50000-0|SimpleDuelingNetwork|Adam|0.0001|ExperienceReplay|16 |498.0 |1.0|1.0|1.0|500.0| -|DuelingDQN|4.0|64.0|0.99|EpsilonGreedy-0.6|SimpleDuelingNetwork|Adam|0.0001|ExperienceReplay|32|498.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|32.0|1.00|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleDuelingNetwork|Adam|0.0010|ExperienceReplay|2048|20.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|1.00|EpsilonGreedy-0.6|SimpleDuelingNetwork|Adam|0.1000|ExperienceReplay|512|60.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|0.95|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleDuelingNetwork|Adam|0.0001|ExperienceReplay|2048|110.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|32.0|0.95|EpsilonGreedy-0.1|SimpleDuelingNetwork|Adam|0.0010|ExperienceReplay|2048|130.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|32.0|1.00|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleDuelingNetwork|Adam|0.0001|ExperienceReplay|512|130.0|1.0|1.0|1.0|500.0| +|DoubleDQN|1.0|64.0|0.95|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|SimpleDuelingNetwork|Adam|0.0001|ExperienceReplay|2048|130.0|1.0|1.0|1.0|500.0| * CategoricalDQN -The best training reach 405 steps in one episode. +500 steps is reached after 60 episodes with these parameters, for only one training run. It's less stable than the other agents. 
| algo | step_train | batch_size | gamma | greedy_exploration | network | optimizer | lr | memories | max_size | step | max | min | avg | sum | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -|CategoricalDQN|32.0|1.0|0.99|AdaptativeEpsilonGreedy-0.3-0.1-50000-0|C51Network|Adam|0.0010|ExperienceReplay|32|332.0|1.0|1.0|1.0|405.0| +|CategoricalDQN|1.0|64.0|0.95|EpsilonGreedy-0.1|C51Network|Adam|0.0010|ExperienceReplay|2048|60.0|1.0|1.0|1.0|500.0| +|CategoricalDQN|32.0|64.0|1.00|AdaptativeEpsilonGreedy-0.8-0.2-10000-0|C51Network|Adam|0.0001|ExperienceReplay|512|180.0|1.0|1.0|1.0|500.0| -#### Parameters importance -There are two principal parameters *batch size* and number of *step* trained. When is bigger, much we have best result. +#### Parameters importance -LR and the number of times the model is trained is canceled. When one is large the other is smaller for equivalent results. +There are three principal parameters: *step*, the number of *step* trained, and the learning rate. 
### Reproduce this result ```batch -python result.py --env "CartPole-v1" --max_episode 500 +python result.py --env "CartPole-v1" --max_episode 300 ``` ## Env2 diff --git a/results/ressources/CartPoleEvaluation.png b/results/ressources/CartPoleEvaluation.png index d4482a3..91e7ee1 100644 Binary files a/results/ressources/CartPoleEvaluation.png and b/results/ressources/CartPoleEvaluation.png differ diff --git a/results/ressources/CartPoleTrainning.png b/results/ressources/CartPoleTrainning.png index 69be2c1..dccc516 100644 Binary files a/results/ressources/CartPoleTrainning.png and b/results/ressources/CartPoleTrainning.png differ diff --git a/results/result.py b/results/result.py index f6ea8e9..2887969 100644 --- a/results/result.py +++ b/results/result.py @@ -1,28 +1,29 @@ from argparse import ArgumentParser +from os import path +import shutil import gym import torch -from gym.spaces import flatdim from torch import optim from blobrl import Trainer -from blobrl.agents import DQN, DoubleDQN, DuelingDQN, CategoricalDQN +from blobrl.agents import DQN, DoubleDQN, CategoricalDQN from blobrl.explorations import EpsilonGreedy, AdaptativeEpsilonGreedy from blobrl.memories import ExperienceReplay from blobrl.networks import SimpleNetwork, SimpleDuelingNetwork, C51Network memory = [ExperienceReplay] -step_train = [1, 4, 32] -batch_size = [1, 32, 64] -gamma = [0.99] +step_train = [1, 32] +batch_size = [32, 64] +gamma = [1.0, 0.99, 0.95] loss = [torch.nn.MSELoss()] optimizer = [optim.Adam] lr = [0.1, 0.001, 0.0001] greedy_exploration = [EpsilonGreedy(0.1), EpsilonGreedy(0.6), - AdaptativeEpsilonGreedy(0.3, 0.1, 50000, 0), - AdaptativeEpsilonGreedy(0.8, 0.2, 50000, 0)] + AdaptativeEpsilonGreedy(0.3, 0.1, 30000, 0), + AdaptativeEpsilonGreedy(0.8, 0.2, 10000, 0)] arg_all = [{"agent": {"class": [DQN, DoubleDQN], "param": {"step_train": step_train, @@ -31,26 +32,13 @@ "greedy_exploration": greedy_exploration } }, - "neural_network": {"class": [SimpleNetwork], - "param": {}}, + 
"network": {"class": [SimpleNetwork], + "param": {}}, "optimizer": {"class": optimizer, "param": {"lr": lr}}, "memory": {"class": memory, - "param": {"max_size": [16, 32, 128]}}, - }, - {"agent": {"class": [DuelingDQN], - "param": {"step_train": step_train, - "batch_size": batch_size, - "gamma": gamma, - "greedy_exploration": greedy_exploration - } - }, - "neural_network": {"class": [SimpleDuelingNetwork], - "param": {}}, - "optimizer": {"class": optimizer, - "param": {"lr": lr}}, - "memory": {"class": memory, - "param": {"max_size": [16, 32, 128]}}, + "param": {"max_size": [512, 2048]}}, + "dueling": True }, {"agent": {"class": [CategoricalDQN], "param": {"step_train": step_train, @@ -59,12 +47,13 @@ "greedy_exploration": greedy_exploration } }, - "neural_network": {"class": [C51Network], - "param": {}}, + "network": {"class": [C51Network], + "param": {}}, "optimizer": {"class": optimizer, "param": {"lr": lr}}, "memory": {"class": memory, - "param": {"max_size": [16, 32, 128]}}, + "param": {"max_size": [512, 2048]}}, + "dueling": False }] @@ -93,7 +82,7 @@ def dict_mzip(x): if __name__ == '__main__': parser = ArgumentParser() parser.add_argument('--env', type=str, help='name of environment', nargs='?', const=1, default="CartPole-v1") - parser.add_argument('--max_episode', type=int, help='number of episode for train', nargs='?', const=1, default=500) + parser.add_argument('--max_episode', type=int, help='number of episode for train', nargs='?', const=1, default=300) parser.add_argument('--render', type=bool, help='if show render on each step or not', nargs='?', default=False) args = parser.parse_args() @@ -107,14 +96,14 @@ def dict_mzip(x): for arg in arg_all: arg_agent = arg["agent"] - arg_neural_network = arg["neural_network"] + arg_network = arg["network"] arg_optimizer = arg["optimizer"] arg_memory = arg["memory"] for class_agent in arg_agent["class"]: print("###" + class_agent.__name__) - for class_neural_network in arg_neural_network["class"]: - print(" ##" + 
class_neural_network.__name__) + for class_network in arg_network["class"]: + print(" ##" + class_network.__name__) for class_optimizer in arg_optimizer["class"]: print(" #" + class_optimizer.__name__) @@ -122,36 +111,85 @@ def dict_mzip(x): for param_agent in dict_mzip(arg_agent["param"]): - for param_network in dict_mzip(arg_neural_network["param"]): + for param_network in dict_mzip(arg_network["param"]): for param_optimizer in dict_mzip(arg_optimizer["param"]): for param_memory in dict_mzip(arg_memory["param"]): - neural_network = class_neural_network( - observation_shape=flatdim(env.observation_space), - action_shape=flatdim(env.action_space), - **param_network) - optimizer = class_optimizer(neural_network.parameters(), **param_optimizer) - - memory = class_memory(**param_memory) - - agent = class_agent(observation_space=env.observation_space, - action_space=env.action_space, - neural_network=neural_network, - optimizer=optimizer, memory=memory, device=device, - **param_agent) log_dir = args.env + "/" + class_agent.__name__ + "/" + "_".join( [str(x) for x in list( - param_agent.values())]) + "_" + class_neural_network.__name__ + "_" + "_".join( + param_agent.values())]) + "_" + class_network.__name__ + "_" + "_".join( [str(x) for x in list( param_network.values())]) + "_" + class_optimizer.__name__ + "_" + "_".join( [str(x) for x in list( param_optimizer.values())]) + "_" + class_memory.__name__ + "_" + "_".join( [str(x) for x in list(param_memory.values())]) - trainer = Trainer(environment=env, agent=agent, - log_dir=log_dir) - trainer.train(max_episode=args.max_episode, render=args.render) - - agent.save(file_name="save.p", dire_name=log_dir) + try: + if not path.exists(log_dir): + network = class_network( + observation_space=env.observation_space, + action_space=env.action_space, + **param_network) + optimizer = class_optimizer(network.parameters(), **param_optimizer) + + memory = class_memory(**param_memory) + + agent = 
class_agent(observation_space=env.observation_space, + action_space=env.action_space, + network=network, + optimizer=optimizer, memory=memory, device=device, + **param_agent) + + trainer = Trainer(environment=env, agent=agent, + log_dir=log_dir) + trainer.train(max_episode=args.max_episode, render=False, + nb_evaluation=int(args.max_episode / 10)) + + agent.save(file_name="save.p", dire_name=log_dir) + except KeyboardInterrupt: + del trainer + shutil.rmtree(log_dir) + break + + if arg["dueling"]: + log_dir = args.env + "/" + class_agent.__name__ + "/" + "_".join( + [str(x) for x in list( + param_agent.values())]) + "_" + SimpleDuelingNetwork.__name__ + "_" + "_".join( + [str(x) for x in list( + param_network.values())]) + "_" + class_optimizer.__name__ + "_" + "_".join( + [str(x) for x in list( + param_optimizer.values())]) + "_" + class_memory.__name__ + "_" + "_".join( + [str(x) for x in list(param_memory.values())]) + + try: + if not path.exists(log_dir): + base_network = class_network( + observation_space=env.observation_space, + action_space=env.action_space, + **param_network) + + network = SimpleDuelingNetwork(base_network) + + optimizer = class_optimizer(network.parameters(), **param_optimizer) + + memory = class_memory(**param_memory) + + agent = class_agent(observation_space=env.observation_space, + action_space=env.action_space, + network=network, + optimizer=optimizer, memory=memory, + device=device, + **param_agent) + + trainer = Trainer(environment=env, agent=agent, + log_dir=log_dir) + trainer.train(max_episode=args.max_episode, render=False, + nb_evaluation=int(args.max_episode / 10)) + + agent.save(file_name="save.p", dire_name=log_dir) + except KeyboardInterrupt: + del trainer + shutil.rmtree(log_dir) + break diff --git a/setup.py b/setup.py index 49233db..ba43734 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ setuptools.setup( author="french ai team", name='blobrl', + version='0.1.0', license="Apache-2.0", description='Reinforcement 
learning with pytorch ', long_description=README, diff --git a/tests/agents/__init__.py b/tests/agents/__init__.py index e69de29..ebacbcf 100644 --- a/tests/agents/__init__.py +++ b/tests/agents/__init__.py @@ -0,0 +1,4 @@ +from .test_agent_interface import TestAgentInterface +from .test_dqn import TestDQN +from .test_categorical_dqn import TestCategorical_DQN +from .test_double_dqn import TestDouble_DQN diff --git a/tests/agents/test_agent_constant.py b/tests/agents/test_agent_constant.py index f0c14b1..afd727e 100644 --- a/tests/agents/test_agent_constant.py +++ b/tests/agents/test_agent_constant.py @@ -1,86 +1,68 @@ import os - -import numpy as np import pytest -from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple from blobrl.agents import AgentConstant +from tests.agents import TestAgentInterface -def test_agent_don_t_work_with_no_space(): - test_list = [1, 100, "100", "somethink", [], dict(), 0.0, 1245.215, None] - for action_space in test_list: - with pytest.raises(TypeError): - AgentConstant(observation_space=Discrete(1), action_space=action_space) - - for observation_space in test_list: - with pytest.raises(TypeError): - AgentConstant(observation_space=observation_space, action_space=Discrete(1)) - - -base_list = {"box": Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32), "discrete": Discrete(3), - "multibinary": MultiBinary(10), "multidiscrete": MultiDiscrete(10)} -dict_list = Dict(base_list) -tuple_list = Tuple(list(base_list.values())) - -test_list = [*base_list.values(), dict_list, tuple_list] - - -def test_agent_work_with_space(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) - assert agent.observation_space == space - assert agent.action_space == space - -def test_agent_get_action(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) - agent.get_action(None) +class TestAgentConstant(TestAgentInterface): + __test__ = True + agent = 
AgentConstant -def test_agent_learn(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) - agent.learn(None, None, None, None, None) + def test_init(self): + for o, a in self.list_work: + self.agent(o, a) + for o, a in self.list_fail: + with pytest.raises(TypeError): + self.agent(o, a) -def test_agent_episode_finished(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) - agent.episode_finished() + def test_get_action(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.get_action(None) + def test_learn(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.learn(None, None, None, None, None) -def test_agent_save_load(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) + def test_episode_finished(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.episode_finished() - agent.save(file_name="deed.pt") - agent_l = AgentConstant.load(file_name="deed.pt") + def test_agent_save_load(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) - assert agent.observation_space == agent_l.observation_space - assert agent.action_space == agent_l.action_space - os.remove("deed.pt") + agent.save(file_name="deed.pt") + agent_l = self.agent.load(file_name="deed.pt") - agent = AgentConstant(observation_space=space, action_space=space) - agent.save(file_name="deed.pt", dire_name="./remove/") + assert agent.observation_space == agent_l.observation_space + assert agent.action_space == agent_l.action_space + os.remove("deed.pt") - os.remove("./remove/deed.pt") - os.rmdir("./remove/") + agent = self.agent(observation_space=o, action_space=a) + agent.save(file_name="deed.pt", dire_name="./remove/") - with pytest.raises(TypeError): - agent.save(file_name=14548) - with 
pytest.raises(TypeError): - agent.save(file_name="deed.pt", dire_name=14484) + os.remove("./remove/deed.pt") + os.rmdir("./remove/") - with pytest.raises(FileNotFoundError): - AgentConstant.load(file_name="deed.pt") - with pytest.raises(FileNotFoundError): - AgentConstant.load(file_name="deed.pt", dire_name="/Dede/") + with pytest.raises(TypeError): + agent.save(file_name=14548) + with pytest.raises(TypeError): + agent.save(file_name="deed.pt", dire_name=14484) + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt") + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt", dire_name="/Dede/") -def test__str__(): - for space in test_list: - agent = AgentConstant(observation_space=space, action_space=space) + def test__str__(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) - assert 'AgentConstant-' + str(space) + "-" + str(space) + "-" + str(agent.action) == agent.__str__() + assert 'AgentConstant-' + str(o) + "-" + str(a) + "-" + str(agent.action) == agent.__str__() diff --git a/tests/agents/test_agent_interface.py b/tests/agents/test_agent_interface.py index ba5944c..f734fb6 100644 --- a/tests/agents/test_agent_interface.py +++ b/tests/agents/test_agent_interface.py @@ -1,17 +1,13 @@ import pytest import torch +from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Dict, Tuple, Box from blobrl.agents import AgentInterface -def test_can_t_instantiate_agent_interface(): - with pytest.raises(TypeError): - AgentInterface() - - class MOCKAgentInterface(AgentInterface): - def __init__(self, device): - super().__init__(device) + def __init__(self, observation_space, action_space, device): + super().__init__(observation_space, action_space, device) def get_action(self, observation): pass @@ -39,14 +35,68 @@ def __str__(self): return "" -def test_device(): - device = torch.device("cpu") - assert device == MOCKAgentInterface(device).device +class TestAgentInterface: + __test__ = 
True + + agent = MOCKAgentInterface + + list_work = [ + [Discrete(3), Discrete(1)], + [Discrete(3), Discrete(3)], + [Discrete(10), Discrete(50)], + [MultiDiscrete([3]), MultiDiscrete([1])], + [MultiDiscrete([3, 3]), MultiDiscrete([3, 3])], + [MultiDiscrete([4, 4, 4]), MultiDiscrete([50, 4, 4])], + [MultiDiscrete([[100, 3], [3, 5]]), MultiDiscrete([[100, 3], [3, 5]])], + [MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]), + MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]])], + [MultiBinary(1), MultiBinary(1)], + [MultiBinary(3), MultiBinary(3)], + # [MultiBinary([3, 2]), MultiBinary([3, 2])], # Don't work yet because gym don't implemented this + [Box(low=0, high=10, shape=[1]), Box(low=0, high=10, shape=[1])], + [Box(low=0, high=10, shape=[2, 2]), Box(low=0, high=10, shape=[2, 2])], + [Box(low=0, high=10, shape=[2, 2, 2]), Box(low=0, high=10, shape=[2, 2, 2])], + + [Tuple([Discrete(1), MultiDiscrete([1, 1])]), Tuple([Discrete(1), MultiDiscrete([1, 1])])], + [Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])}), + Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])})], + + ] + list_fail = [ + [None, None], + ["dedrfe", "qdzq"], + [1215.4154, 157.48], + ["zdzd", (Discrete(1))], + [Discrete(1), "zdzd"], + ["zdzd", (1, 4, 7)], + [(1, 4, 7), "zdzd"], + [152, 485] + ] + + def test_init(self): + for o, a in self.list_work: + with pytest.raises(TypeError): + self.agent(o, a, "cpu") + + for o, a in self.list_fail: + with pytest.raises(TypeError): + self.agent(o, a, "cpu") + + def test_device(self): + for o, a in self.list_work: + device = torch.device("cpu") + assert device == self.agent(o, a, device).device + + device = None + assert torch.device("cpu") == self.agent(o, a, device).device + + for device in ["dzeqdzqd", 1512, object(), 151.515]: + with pytest.raises(TypeError): + self.agent(o, a, device) + + if torch.cuda.is_available(): + self.agent(o, a, torch.device("cuda")) + + def test__str__(self): - device = None - assert 
torch.device("cpu") == MOCKAgentInterface(device).device - - devices = ["dzeqdzqd", 1512, object(), 151.515] - for device in devices: - with pytest.raises(TypeError): - MOCKAgentInterface(device) + pass diff --git a/tests/agents/test_agent_random.py b/tests/agents/test_agent_random.py index b8772b4..4aada4d 100644 --- a/tests/agents/test_agent_random.py +++ b/tests/agents/test_agent_random.py @@ -1,86 +1,68 @@ import os - -import numpy as np import pytest -from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple from blobrl.agents import AgentRandom - -def test_agent_don_t_work_with_no_space(): - test_list = [1, 100, "100", "somethink", [], dict(), 0.0, 1245.215, None] - for action_space in test_list: - with pytest.raises(TypeError): - AgentRandom(observation_space=Discrete(1), action_space=action_space) - - for observation_space in test_list: - with pytest.raises(TypeError): - AgentRandom(observation_space=observation_space, action_space=Discrete(1)) - - -base_list = {"box": Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32), "discrete": Discrete(3), - "multibinary": MultiBinary(10), "multidiscrete": MultiDiscrete(10)} -dict_list = Dict(base_list) -tuple_list = Tuple(list(base_list.values())) - -test_list = [*base_list.values(), dict_list, tuple_list] - - -def test_agent_work_with_space(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) - assert agent.observation_space == space - assert agent.action_space == space +from tests.agents import TestAgentInterface -def test_agent_get_action(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) - agent.get_action(None) +class TestAgentRandom(TestAgentInterface): + __test__ = True + agent = AgentRandom -def test_agent_learn(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) - agent.learn(None, None, None, None, None) + def test_init(self): + for o, a in 
self.list_work: + self.agent(o, a) + for o, a in self.list_fail: + with pytest.raises(TypeError): + self.agent(o, a) -def test_agent_episode_finished(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) - agent.episode_finished() + def test_get_action(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.get_action(None) + def test_learn(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.learn(None, None, None, None, None) -def test_agent_save_load(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) + def test_episode_finished(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.episode_finished() - agent.save(file_name="deed.pt") - agent_l = AgentRandom.load(file_name="deed.pt") + def test_agent_save_load(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) - assert agent.observation_space == agent_l.observation_space - assert agent.action_space == agent_l.action_space - os.remove("deed.pt") + agent.save(file_name="deed.pt") + agent_l = self.agent.load(file_name="deed.pt") - agent = AgentRandom(observation_space=space, action_space=space) - agent.save(file_name="deed.pt", dire_name="./remove/") + assert agent.observation_space == agent_l.observation_space + assert agent.action_space == agent_l.action_space + os.remove("deed.pt") - os.remove("./remove/deed.pt") - os.rmdir("./remove/") + agent = self.agent(observation_space=o, action_space=a) + agent.save(file_name="deed.pt", dire_name="./remove/") - with pytest.raises(TypeError): - agent.save(file_name=14548) - with pytest.raises(TypeError): - agent.save(file_name="deed.pt", dire_name=14484) + os.remove("./remove/deed.pt") + os.rmdir("./remove/") - with pytest.raises(FileNotFoundError): - AgentRandom.load(file_name="deed.pt") - with 
pytest.raises(FileNotFoundError): - AgentRandom.load(file_name="deed.pt", dire_name="/Dede/") + with pytest.raises(TypeError): + agent.save(file_name=14548) + with pytest.raises(TypeError): + agent.save(file_name="deed.pt", dire_name=14484) + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt") + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt", dire_name="/Dede/") -def test__str__(): - for space in test_list: - agent = AgentRandom(observation_space=space, action_space=space) + def test__str__(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) - assert 'AgentRandom-' + str(space) + "-" + str(space) == agent.__str__() + assert 'AgentRandom-' + str(o) + "-" + str(a) == agent.__str__() diff --git a/tests/agents/test_categorical_dqn.py b/tests/agents/test_categorical_dqn.py index 3c27367..27eda9d 100644 --- a/tests/agents/test_categorical_dqn.py +++ b/tests/agents/test_categorical_dqn.py @@ -1,135 +1,19 @@ -import pytest -import torch -import torch.optim as optim -from gym.spaces import Discrete, Box - from blobrl.agents import CategoricalDQN -from blobrl.explorations import Greedy, EpsilonGreedy -from blobrl.memories import ExperienceReplay from blobrl.networks import C51Network +from tests.agents import TestDQN -def test_categorical_dqn_agent_instantiation(): - CategoricalDQN(Discrete(4), Discrete(4)) - - -def test_categorical_dqn_agent_instantiation_error_action_space(): - with pytest.raises(TypeError): - CategoricalDQN(None, Discrete(1)) - - -def test_categorical_dqn_agent_instantiation_error_observation_space(): - with pytest.raises(TypeError): - CategoricalDQN(Discrete(1), None) - - -def test_categorical_dqn_agent_instantiation_error_neural_network(): - with pytest.raises(TypeError): - CategoricalDQN(Discrete(4), Discrete(4), neural_network=154) - - -def test_categorical_dqn_agent_instantiation_error_memory(): - with pytest.raises(TypeError): - 
CategoricalDQN(Discrete(4), Discrete(4), None) - - -def test_categorical_dqn_agent_instantiation_error_loss(): - with pytest.raises(TypeError): - CategoricalDQN(Discrete(4), Discrete(4), loss="LOSS_ERROR") - - -def test_categorical_dqn_agent_instantiation_error_optimizer(): - with pytest.raises(TypeError): - CategoricalDQN(Discrete(4), Discrete(4), optimizer="OPTIMIZER_ERROR") - - -def test_categorical_dqn_agent_instantiation_error_greedy_exploration(): - with pytest.raises(TypeError): - CategoricalDQN(Discrete(4), Discrete(4), greedy_exploration="GREEDY_EXPLORATION_ERROR") - - -def test_categorical_dqn_agent_instantiation_custom_optimizer(): - c51 = C51Network((1), (1)) - - CategoricalDQN(Discrete(4), Discrete(4), neural_network=c51, optimizer=optim.RMSprop(c51.parameters())) - - with pytest.raises(TypeError): - CategoricalDQN(Discrete(4), Discrete(4), neural_network=None, optimizer=optim.RMSprop(c51.parameters())) - - -def test_categorical_dqn_agent_getaction(): - agent = CategoricalDQN(Discrete(4), Box(0, 3, (3,)), greedy_exploration=Greedy()) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_categorical_dqn_agent_getaction_non_greedy(): - agent = CategoricalDQN(Discrete(4), Box(0, 3, (3,)), greedy_exploration=EpsilonGreedy(1.)) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_categorical_dqn_agent_learn(): - memory = ExperienceReplay(max_size=5) - - agent = CategoricalDQN(Discrete(4), Box(1, 10, (4,)), memory) - - obs = [1, 2, 5, 0] - action = 0 - reward = 0 - next_obs = [5, 9, 4, 0] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) - - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) - - -def test_categorical_dqn_agent_episode_finished(): - agent = CategoricalDQN(Discrete(4), 
Discrete(4)) - agent.episode_finished() - - -def test__str__(): - agent = CategoricalDQN(Discrete(4), Box(1, 10, (4,))) - - assert 'CategoricalDQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( - agent.neural_network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( - agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( - agent.optimizer) + "-" + str(agent.greedy_exploration) + "-" + str(agent.num_atoms) + "-" + str( - agent.r_min) + "-" + str(agent.r_max) + "-" + str(agent.delta_z) + "-" + str(agent.z) == agent.__str__() - - -def test_device_gpu(): - if torch.cuda.is_available(): - memory = ExperienceReplay(max_size=5) - - agent = CategoricalDQN(Discrete(4), Box(1, 10, (4,)), memory, device=torch.device("cuda")) - - obs = [1, 2, 5, 0] - action = 0 - reward = 0 - next_obs = [5, 9, 4, 0] - done = False - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] +class TestCategorical_DQN(TestDQN): + agent = CategoricalDQN + network = C51Network - memory.extend(obs_s, actions, rewards, next_obs_s, dones) + def test__str__(self): + for o, a in self.list_work: + agent = self.agent(o, a) - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) + assert 'CategoricalDQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( + agent.network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( + agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( + agent.optimizer) + "-" + str(agent.greedy_exploration) + "-" + str(agent.num_atoms) + "-" + str( + agent.r_min) + "-" + str(agent.r_max) + "-" + str(agent.delta_z) + "-" + str(agent.z) == agent.__str__() diff --git a/tests/agents/test_double_dqn.py b/tests/agents/test_double_dqn.py index 6865087..3c50f70 
100644 --- a/tests/agents/test_double_dqn.py +++ b/tests/agents/test_double_dqn.py @@ -1,261 +1,27 @@ -import os - -import numpy as np -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Dict, Tuple - from blobrl.agents import DoubleDQN -from blobrl.explorations import Greedy, EpsilonGreedy +from tests.agents import TestDQN +from gym.spaces import flatten from blobrl.memories import ExperienceReplay -from blobrl.networks import BaseNetwork - - -class Network(BaseNetwork): - def __str__(self): - return 'Network' - - def __init__(self, observation_shape=None, action_shape=None): - super().__init__(observation_shape, action_shape) - self.dense = nn.Linear(3, 4) - - def forward(self, x): - x = self.dense(x) - x = F.relu(x) - return x - - -def test_double_dqn_agent_instantiation(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory) - DoubleDQN(Discrete(4), Discrete(3)) - - -def test_double_dqn_agent_instantiation_error_action_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN("ACTION_SPACE_ERROR", Discrete(3), neural_network=network, memory=memory) - - -def test_double_dqn_agent_instantiation_error_observation_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN(Discrete(1), "OBSERVATION_SPACE_ERROR", neural_network=network, memory=memory) - - -def test_double_dqn_agent_instantiation_error_neural_network(): - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN(Discrete(4), Discrete(3), neural_network="NEURAL_NETWORK_ERROR", memory=memory) - - -def test_double_dqn_agent_instantiation_error_memory(): - network = Network() - - with pytest.raises(TypeError): - DoubleDQN(Discrete(4), Discrete(3), 
neural_network=network, memory="MEMORY_ERROR") - - -def test_double_dqn_agent_instantiation_error_loss(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, loss="LOSS_ERROR") - - -def test_double_dqn_agent_instantiation_error_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, optimizer="OPTIMIZER_ERROR") - - -def test_double_dqn_agent_instantiation_error_greedy_exploration(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - greedy_exploration="GREEDY_EXPLORATION_ERROR") - - -def test_double_dqn_agent_instantiation_custom_loss(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, loss=nn.MSELoss()) - - -def test_double_dqn_agent_instantiation_custom_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - optimizer=optim.RMSprop(network.parameters())) - - -def test_double_dqn_agent_getaction(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, greedy_exploration=Greedy()) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_double_dqn_agent_getaction_non_greedy(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - greedy_exploration=EpsilonGreedy(1.)) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_double_dqn_agent_learn(): - network = Network() - memory = 
ExperienceReplay(max_size=5) - - agent = DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, step_copy=2) - - obs = [1, 2, 5] - action = 0 - reward = 0 - next_obs = [5, 9, 4] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) - - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) - - -def test_double_dqn_agent_episode_finished(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory) - agent.episode_finished() - - -base_list = {"box": Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32), "discrete": Discrete(3), - "multibinary": MultiBinary(10), "multidiscrete": MultiDiscrete(10)} -dict_list = Dict(base_list) -tuple_list = Tuple(list(base_list.values())) - -test_list = [*base_list.values(), dict_list, tuple_list] - - -def test_agent_save_load(): - for space in test_list: - agent = DoubleDQN(observation_space=space, action_space=Discrete(2)) - - agent.save(file_name="deed.pt") - agent_l = DoubleDQN.load(file_name="deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - os.remove("deed.pt") - - network = Network() - - agent = DoubleDQN(observation_space=space, action_space=Discrete(2), memory=ExperienceReplay(), - neural_network=network, step_train=3, batch_size=12, gamma=0.50, loss=None, - optimizer=torch.optim.Adam(network.parameters()), step_copy=300, - greedy_exploration=EpsilonGreedy(0.2)) - - agent.save(file_name="deed.pt") - agent_l = DoubleDQN.load(file_name="deed.pt") - - os.remove("deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - assert isinstance(agent.neural_network, 
type(agent_l.neural_network)) - for a, b in zip(agent.neural_network.state_dict(), agent_l.neural_network.state_dict()): - assert a == b - assert agent.step_train == agent_l.step_train - assert agent.batch_size == agent_l.batch_size - assert agent.gamma == agent_l.gamma - assert isinstance(agent.loss, type(agent_l.loss)) - for a, b in zip(agent.loss.parameters(), agent_l.loss.parameters()): - assert a == b - assert isinstance(agent.optimizer, type(agent_l.optimizer)) - for a, b in zip(agent.optimizer.state_dict(), agent_l.optimizer.state_dict()): - assert a == b - assert isinstance(agent.greedy_exploration, type(agent_l.greedy_exploration)) - assert agent.step_copy == agent_l.step_copy - - agent = DoubleDQN(observation_space=space, action_space=Discrete(2)) - agent.save(file_name="deed.pt", dire_name="./remove/") - - os.remove("./remove/deed.pt") - os.rmdir("./remove/") - - with pytest.raises(TypeError): - agent.save(file_name=14548) - with pytest.raises(TypeError): - agent.save(file_name="deed.pt", dire_name=14484) - - with pytest.raises(FileNotFoundError): - DoubleDQN.load(file_name="deed.pt") - with pytest.raises(FileNotFoundError): - DoubleDQN.load(file_name="deed.pt", dire_name="/Dede/") - - -def test__str__(): - agent = DoubleDQN(Discrete(4), Box(1, 10, (4,))) - - assert 'DoubleDQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( - agent.neural_network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( - agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( - agent.optimizer) + "-" + str(agent.greedy_exploration) + "-" + str(agent.step_copy) == agent.__str__() -def test_device_gpu(): - if torch.cuda.is_available(): - network = Network() - memory = ExperienceReplay(max_size=5) +class TestDouble_DQN(TestDQN): + agent = DoubleDQN - agent = DoubleDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, step_copy=2, - device=torch.device("cuda")) + def 
test_learn(self): + for o, a in self.list_work: + network = self.network(o, a) + memory = ExperienceReplay(max_size=5) - obs = [1, 2, 5] - action = 0 - reward = 0 - next_obs = [5, 9, 4] - done = False + agent = self.agent(observation_space=o, action_space=a, memory=memory, step_copy=10, network=network) - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] + for i in range(20): + agent.learn(o.sample(), a.sample(), 0, o.sample(), False) - memory.extend(obs_s, actions, rewards, next_obs_s, dones) + def test__str__(self): + for o, a in self.list_work: + agent = self.agent(o, a) - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) + assert 'DoubleDQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( + agent.network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( + agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( + agent.optimizer) + "-" + str(agent.greedy_exploration) + "-" + str(agent.step_copy) == agent.__str__() diff --git a/tests/agents/test_dqn.py b/tests/agents/test_dqn.py index 4bd457e..7443708 100644 --- a/tests/agents/test_dqn.py +++ b/tests/agents/test_dqn.py @@ -1,276 +1,207 @@ import os - -import numpy as np import pytest import torch -import torch.nn as nn -import torch.nn.functional as F +from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Dict, Tuple, flatten import torch.optim as optim -from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Dict, Tuple from blobrl.agents import DQN from blobrl.explorations import Greedy, EpsilonGreedy from blobrl.memories import ExperienceReplay -from blobrl.networks import BaseNetwork - - -class Network(BaseNetwork): - - def __str__(self): - return 'Network' - - def __init__(self, observation_shape=None, action_shape=None): - 
super().__init__(observation_shape, action_shape) - self.dense = nn.Linear(4, 4) - - def forward(self, x): - x = self.dense(x) - x = F.relu(x) - return x - - -def test_dqn_agent_instantiation(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DQN(Discrete(4), Discrete(4), memory, neural_network=network) - - -def test_dqn_agent_instantiation_error_action_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(None, Discrete(1), memory, neural_network=network) - - -def test_dqn_agent_instantiation_error_observation_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(Discrete(1), None, memory, neural_network=network) - - -def test_dqn_agent_instantiation_error_neural_network(): - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), memory, neural_network=154) - - DQN(Discrete(4), Discrete(4), memory, neural_network=None) - - -def test_dqn_agent_instantiation_error_memory(): - network = Network() - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), None, neural_network=network) - - -def test_dqn_agent_instantiation_error_loss(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), memory, neural_network=network, loss="LOSS_ERROR") - - -def test_dqn_agent_instantiation_error_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), memory, neural_network=network, optimizer="OPTIMIZER_ERROR") - - -def test_dqn_agent_instantiation_error_greedy_exploration(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), memory, neural_network=network, greedy_exploration="GREEDY_EXPLORATION_ERROR") - - -def test_dqn_agent_instantiation_custom_loss(): - 
network = Network() - memory = ExperienceReplay(max_size=5) - - DQN(Discrete(4), Discrete(4), memory, neural_network=network, loss=nn.MSELoss()) - - -def test_dqn_agent_instantiation_custom_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DQN(Discrete(4), Discrete(4), memory, neural_network=network, optimizer=optim.RMSprop(network.parameters())) - - with pytest.raises(TypeError): - DQN(Discrete(4), Discrete(4), memory, neural_network=None, optimizer=optim.RMSprop(network.parameters())) - - -def test_dqn_agent_getaction(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DQN(Discrete(4), Discrete(4), memory, neural_network=network, greedy_exploration=Greedy()) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_dqn_agent_getaction_non_greedy(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DQN(Discrete(4), Discrete(4), memory, neural_network=network, greedy_exploration=EpsilonGreedy(1.)) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_dqn_agent_learn(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DQN(Discrete(4), Discrete(4), memory, neural_network=network) - - obs = [1, 2, 5, 0] - action = 0 - reward = 0 - next_obs = [5, 9, 4, 0] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) - - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) - - -def test_dqn_agent_episode_finished(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DQN(Discrete(4), Discrete(4), memory, neural_network=network) - agent.episode_finished() - - -base_list = {"box": Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32), "discrete": Discrete(3), - "multibinary": MultiBinary(10), 
"multidiscrete": MultiDiscrete(10)} -dict_list = Dict(base_list) -tuple_list = Tuple(list(base_list.values())) - -test_list = [*base_list.values(), dict_list, tuple_list] - - -def test_agent_save_load(): - for space in test_list: - agent = DQN(observation_space=space, action_space=Discrete(2)) - - agent.save(file_name="deed.pt") - agent_l = DQN.load(file_name="deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - os.remove("deed.pt") - - network = Network() - - agent = DQN(observation_space=space, action_space=Discrete(2), memory=ExperienceReplay(), neural_network=network, - step_train=3, batch_size=12, gamma=0.50, loss=None, optimizer=torch.optim.Adam(network.parameters()), - greedy_exploration=EpsilonGreedy(0.2)) - - agent.save(file_name="deed.pt") - agent_l = DQN.load(file_name="deed.pt") - - os.remove("deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - assert isinstance(agent.neural_network, type(agent_l.neural_network)) - for a, b in zip(agent.neural_network.state_dict(), agent_l.neural_network.state_dict()): - assert a == b - assert agent.step_train == agent_l.step_train - assert agent.batch_size == agent_l.batch_size - assert agent.gamma == agent_l.gamma - assert isinstance(agent.loss, type(agent_l.loss)) - for a, b in zip(agent.loss.parameters(), agent_l.loss.parameters()): - assert a == b - assert isinstance(agent.optimizer, type(agent_l.optimizer)) - for a, b in zip(agent.optimizer.state_dict(), agent_l.optimizer.state_dict()): - assert a == b - assert isinstance(agent.greedy_exploration, type(agent_l.greedy_exploration)) - - agent = DQN(observation_space=space, action_space=Discrete(2)) - agent.save(file_name="deed.pt", dire_name="./remove/") - - os.remove("./remove/deed.pt") - os.rmdir("./remove/") - - with pytest.raises(TypeError): - agent.save(file_name=14548) - with pytest.raises(TypeError): - 
agent.save(file_name="deed.pt", dire_name=14484) - - with pytest.raises(FileNotFoundError): - DQN.load(file_name="deed.pt") - with pytest.raises(FileNotFoundError): - DQN.load(file_name="deed.pt", dire_name="/Dede/") - - -def test__str__(): - agent = DQN(Discrete(4), Box(1, 10, (4,))) - - assert 'DQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( - agent.neural_network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( - agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( - agent.optimizer) + "-" + str(agent.greedy_exploration) == agent.__str__() - - -def test_enable_train(): - agent = DQN(Discrete(4), Box(1, 10, (4,))) - - agent.trainable = False - - agent.enable_exploration() - assert agent.trainable is True - - -def test_disable_train(): - agent = DQN(Discrete(4), Box(1, 10, (4,))) - - agent.disable_exploration() - assert agent.trainable is False - - -def test_device_gpu(): - if torch.cuda.is_available(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DQN(Discrete(4), Discrete(4), memory, neural_network=network, device=torch.device("cuda")) - - obs = [1, 2, 5, 0] - action = 0 - reward = 0 - next_obs = [5, 9, 4, 0] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) +from blobrl.networks import SimpleNetwork + +from tests.agents import TestAgentInterface + + +class TestDQN(TestAgentInterface): + __test__ = True + + agent = DQN + network = SimpleNetwork + + list_work = [ + [Discrete(3), Discrete(1)], + [Discrete(3), Discrete(3)], + [Discrete(5), Discrete(3)], + [Discrete(10), Discrete(50)], + [Discrete(15), Discrete(50)], + [MultiDiscrete([3]), MultiDiscrete([1])], + [MultiDiscrete([3]), Discrete(50)], + [MultiDiscrete([3, 3]), MultiDiscrete([3, 3])], + 
[MultiDiscrete([4, 4, 4]), MultiDiscrete([50, 4, 4])], + [MultiDiscrete([4, 4, 4]), Discrete(3)], + [MultiDiscrete([[100, 3], [3, 5]]), MultiDiscrete([[100, 3], [3, 5]])], + [MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]), + MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]])] + ] + list_fail = [ + [None, None], + ["dedrfe", "qdzq"], + [1215.4154, 157.48], + ["zdzd", (Discrete(1))], + [Discrete(1), "zdzd"], + ["zdzd", (1, 4, 7)], + [(1, 4, 7), "zdzd"], + [152, 485], + [MultiBinary(1), MultiBinary(1)], + [MultiBinary(3), MultiBinary(3)], + # [MultiBinary([3, 2]), MultiBinary([3, 2])], # Don't work yet because gym don't implemented this + [Box(low=0, high=10, shape=[1]), Box(low=0, high=10, shape=[1])], + [Box(low=0, high=10, shape=[2, 2]), Box(low=0, high=10, shape=[2, 2])], + [Box(low=0, high=10, shape=[2, 2, 2]), Box(low=0, high=10, shape=[2, 2, 2])], + + [Tuple([Discrete(1), MultiDiscrete([1, 1])]), Tuple([Discrete(1), MultiDiscrete([1, 1])])], + [Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])}), + Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])})], + ] + + def test_init(self): + for o, a in self.list_work: + self.agent(o, a) + for n in [object(), "dada", 154, 12.1]: + with pytest.raises(TypeError): + self.agent(o, a, optimizer=n) + with pytest.raises(TypeError): + self.agent(o, a, network=n) + with pytest.raises(TypeError): + self.agent(o, a, memory=n) + with pytest.raises(TypeError): + self.agent(o, a, loss=n) + with pytest.raises(TypeError): + self.agent(o, a, greedy_exploration=n) + with pytest.raises(TypeError): + net = self.network(o, a) + self.agent(o, a, optimizer=optim.Adam(net.parameters()), network=None) + + for o, a in self.list_fail: + with pytest.raises(TypeError): + self.agent(o, a) + + def test_get_action(self): + for o, a in self.list_work: + assert 1 + for ge in [Greedy(), EpsilonGreedy(1.)]: + agent = self.agent(o, a, greedy_exploration=ge) + + for i in range(20): + act = agent.get_action(o.sample()) + if 
isinstance(a, Discrete): + assert act in range(a.n) + + def test_learn(self): + for o, a in self.list_work: + network = self.network(o, a) + memory = ExperienceReplay(max_size=5) + + agent = self.agent(observation_space=o, action_space=a, memory=memory, network=network) + + for i in range(20): + agent.learn(o.sample(), a.sample(), 0, o.sample(), False) + + def test_episode_finished(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + agent.episode_finished() + + def test_agent_save_load(self): + for o, a in self.list_work: + agent = self.agent(observation_space=o, action_space=a) + + agent.save(file_name="deed.pt") + agent_l = self.agent.load(file_name="deed.pt") + + assert agent.observation_space == agent_l.observation_space + assert agent.action_space == agent_l.action_space + os.remove("deed.pt") + + agent = self.agent(observation_space=o, action_space=a) + agent.save(file_name="deed.pt", dire_name="./remove/") + + os.remove("./remove/deed.pt") + os.rmdir("./remove/") + + with pytest.raises(TypeError): + agent.save(file_name=14548) + with pytest.raises(TypeError): + agent.save(file_name="deed.pt", dire_name=14484) + + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt") + with pytest.raises(FileNotFoundError): + self.agent.load(file_name="deed.pt", dire_name="/Dede/") + + network = self.network(o, a) + + agent = self.agent(observation_space=o, action_space=a, memory=ExperienceReplay(), + network=network, + step_train=3, batch_size=12, gamma=0.50, + optimizer=torch.optim.Adam(network.parameters()), + greedy_exploration=EpsilonGreedy(0.2)) + + agent.save(file_name="deed.pt") + agent_l = self.agent.load(file_name="deed.pt") + + os.remove("deed.pt") + + assert agent.observation_space == agent_l.observation_space + assert a == agent_l.action_space + assert isinstance(agent.network, type(agent_l.network)) + for a, b in zip(agent.network.state_dict(), agent_l.network.state_dict()): + assert a == b + 
assert agent.step_train == agent_l.step_train + assert agent.batch_size == agent_l.batch_size + assert agent.gamma == agent_l.gamma + assert isinstance(agent.loss, type(agent_l.loss)) + for a, b in zip(agent.loss.parameters(), agent_l.loss.parameters()): + assert a == b + assert isinstance(agent.optimizer, type(agent_l.optimizer)) + for a, b in zip(agent.optimizer.state_dict(), agent_l.optimizer.state_dict()): + assert a == b + assert isinstance(agent.greedy_exploration, type(agent_l.greedy_exploration)) + + def test_device(self): + for o, a in self.list_work: + device = torch.device("cpu") + assert device == self.agent(o, a, device=device).device + + device = None + assert torch.device("cpu") == self.agent(o, a, device=device).device + + for device in ["dzeqdzqd", 1512, object(), 151.515]: + with pytest.raises(TypeError): + self.agent(o, a, device=device) + + if torch.cuda.is_available(): + self.agent(o, a, device=torch.device("cuda")) + + def test_dqn_agent_episode_finished(self): + for o, a in self.list_work: + network = self.network(o, a) + memory = ExperienceReplay(max_size=5) + + agent = self.agent(o, a, memory, network=network) + agent.episode_finished() + + def test_enable_train(self): + for o, a in self.list_work: + agent = self.agent(o, a) + + agent.with_exploration = False - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) + agent.enable_exploration() + assert agent.with_exploration is True + + def test_disable_train(self): + for o, a in self.list_work: + agent = self.agent(o, a) + + agent.disable_exploration() + assert agent.with_exploration is False + + def test__str__(self): + for o, a in self.list_work: + agent = self.agent(o, a) + + assert 'DQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( + agent.network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( + agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + 
"-" + str( + agent.optimizer) + "-" + str(agent.greedy_exploration) == agent.__str__() diff --git a/tests/agents/test_dueling_dqn.py b/tests/agents/test_dueling_dqn.py deleted file mode 100644 index dd7861f..0000000 --- a/tests/agents/test_dueling_dqn.py +++ /dev/null @@ -1,261 +0,0 @@ -import os - -import numpy as np -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Dict, Tuple - -from blobrl.agents import DuelingDQN -from blobrl.explorations import Greedy, EpsilonGreedy -from blobrl.memories import ExperienceReplay -from blobrl.networks import BaseDuelingNetwork - - -class Network(BaseDuelingNetwork): - def __str__(self): - return 'Network' - - def __init__(self, observation_shape=None, action_shape=None): - super().__init__(observation_shape, action_shape) - self.dense = nn.Linear(3, 4) - - def forward(self, x): - x = self.dense(x) - x = F.relu(x) - return x - - -def test_dueling_dqn_agent_instantiation(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory) - DuelingDQN(Discrete(4), Discrete(3)) - - -def test_dueling_dqn_agent_instantiation_error_action_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN("ACTION_SPACE_ERROR", Discrete(3), neural_network=network, memory=memory) - - -def test_dueling_dqn_agent_instantiation_error_observation_space(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN(Discrete(1), "OBSERVATION_SPACE_ERROR", neural_network=network, memory=memory) - - -def test_dueling_dqn_agent_instantiation_error_neural_network(): - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN(Discrete(4), Discrete(3), neural_network="NEURAL_NETWORK_ERROR", memory=memory) - - -def 
test_dueling_dqn_agent_instantiation_error_memory(): - network = Network() - - with pytest.raises(TypeError): - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory="MEMORY_ERROR") - - -def test_dueling_dqn_agent_instantiation_error_loss(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, loss="LOSS_ERROR") - - -def test_dueling_dqn_agent_instantiation_error_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, optimizer="OPTIMIZER_ERROR") - - -def test_dueling_dqn_agent_instantiation_error_greedy_exploration(): - network = Network() - memory = ExperienceReplay(max_size=5) - - with pytest.raises(TypeError): - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - greedy_exploration="GREEDY_EXPLORATION_ERROR") - - -def test_dueling_dqn_agent_instantiation_custom_loss(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, loss=nn.MSELoss()) - - -def test_dueling_dqn_agent_instantiation_custom_optimizer(): - network = Network() - memory = ExperienceReplay(max_size=5) - - DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - optimizer=optim.RMSprop(network.parameters())) - - -def test_dueling_dqn_agent_getaction(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, greedy_exploration=Greedy()) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_dueling_dqn_agent_getaction_non_greedy(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, - 
greedy_exploration=EpsilonGreedy(1.)) - - observation = [0, 1, 2] - - agent.get_action(observation) - - -def test_dueling_dqn_agent_learn(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, step_copy=2) - - obs = [1, 2, 5] - action = 0 - reward = 0 - next_obs = [5, 9, 4] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) - - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) - - -def test_dueling_dqn_agent_episode_finished(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent = DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory) - agent.episode_finished() - - -base_list = {"box": Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32), "discrete": Discrete(3), - "multibinary": MultiBinary(10), "multidiscrete": MultiDiscrete(10)} -dict_list = Dict(base_list) -tuple_list = Tuple(list(base_list.values())) - -test_list = [*base_list.values(), dict_list, tuple_list] - - -def test_agent_save_load(): - for space in test_list: - agent = DuelingDQN(observation_space=space, action_space=Discrete(2)) - - agent.save(file_name="deed.pt") - agent_l = DuelingDQN.load(file_name="deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - os.remove("deed.pt") - - network = Network() - - agent = DuelingDQN(observation_space=space, action_space=Discrete(2), memory=ExperienceReplay(), - neural_network=network, step_train=3, batch_size=12, gamma=0.50, loss=None, - optimizer=torch.optim.Adam(network.parameters()), step_copy=300, - greedy_exploration=EpsilonGreedy(0.2)) - - agent.save(file_name="deed.pt") - agent_l = DuelingDQN.load(file_name="deed.pt") - - 
os.remove("deed.pt") - - assert agent.observation_space == agent_l.observation_space - assert Discrete(2) == agent_l.action_space - assert isinstance(agent.neural_network, type(agent_l.neural_network)) - for a, b in zip(agent.neural_network.state_dict(), agent_l.neural_network.state_dict()): - assert a == b - assert agent.step_train == agent_l.step_train - assert agent.batch_size == agent_l.batch_size - assert agent.gamma == agent_l.gamma - assert isinstance(agent.loss, type(agent_l.loss)) - for a, b in zip(agent.loss.parameters(), agent_l.loss.parameters()): - assert a == b - assert isinstance(agent.optimizer, type(agent_l.optimizer)) - for a, b in zip(agent.optimizer.state_dict(), agent_l.optimizer.state_dict()): - assert a == b - assert isinstance(agent.greedy_exploration, type(agent_l.greedy_exploration)) - assert agent.step_copy == agent_l.step_copy - - agent = DuelingDQN(observation_space=space, action_space=Discrete(2)) - agent.save(file_name="deed.pt", dire_name="./remove/") - - os.remove("./remove/deed.pt") - os.rmdir("./remove/") - - with pytest.raises(TypeError): - agent.save(file_name=14548) - with pytest.raises(TypeError): - agent.save(file_name="deed.pt", dire_name=14484) - - with pytest.raises(FileNotFoundError): - DuelingDQN.load(file_name="deed.pt") - with pytest.raises(FileNotFoundError): - DuelingDQN.load(file_name="deed.pt", dire_name="/Dede/") - - -def test__str__(): - agent = DuelingDQN(Discrete(4), Box(1, 10, (4,))) - - assert 'DuelingDQN-' + str(agent.observation_space) + "-" + str(agent.action_space) + "-" + str( - agent.neural_network) + "-" + str(agent.memory) + "-" + str(agent.step_train) + "-" + str( - agent.step) + "-" + str(agent.batch_size) + "-" + str(agent.gamma) + "-" + str(agent.loss) + "-" + str( - agent.optimizer) + "-" + str(agent.greedy_exploration) + "-" + str(agent.step_copy) == agent.__str__() - - -def test_device_gpu(): - if torch.cuda.is_available(): - network = Network() - memory = ExperienceReplay(max_size=5) - - agent 
= DuelingDQN(Discrete(4), Discrete(3), neural_network=network, memory=memory, step_copy=2, - device=torch.device("cuda")) - - obs = [1, 2, 5] - action = 0 - reward = 0 - next_obs = [5, 9, 4] - done = False - - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] - - memory.extend(obs_s, actions, rewards, next_obs_s, dones) - - agent.learn(obs, action, reward, next_obs, done) - agent.learn(obs, action, reward, next_obs, done) diff --git a/tests/memories/test_experience_replay.py b/tests/memories/test_experience_replay.py index c1c7e26..f148a66 100644 --- a/tests/memories/test_experience_replay.py +++ b/tests/memories/test_experience_replay.py @@ -8,20 +8,24 @@ def test_experience_replay(): mem = ExperienceReplay(max_size) - obs = [1, 2, 5] - action = 0 - reward = 0 - next_obs = [5, 9, 4] - done = False + for i in range(10): + obs = [1, 2, 5] + action = 0 + reward = 0 + next_obs = [5, 9, 4] + done = False - mem.append(obs, action, reward, next_obs, done) + mem.append(obs, action, reward, next_obs, done) - obs_s = [obs, obs, obs] - actions = [1, 2, 3] - rewards = [-2.2, 5, 4] - next_obs_s = [next_obs, next_obs, next_obs] - dones = [False, True, False] + mem.sample(2, device=torch.device("cpu")) + + for i in range(10): + obs_s = [obs, obs, obs] + actions = [1, 2, 3] + rewards = [-2.2, 5, 4] + next_obs_s = [next_obs, next_obs, next_obs] + dones = [False, True, False] - mem.extend(obs_s, actions, rewards, next_obs_s, dones) + mem.extend(obs_s, actions, rewards, next_obs_s, dones) mem.sample(2, device=torch.device("cpu")) diff --git a/tests/networks/__init__.py b/tests/networks/__init__.py index e69de29..5273a8a 100644 --- a/tests/networks/__init__.py +++ b/tests/networks/__init__.py @@ -0,0 +1,4 @@ +from .test_base_network import TestBaseNetwork +from .test_simple_network import TestSimpleNetwork +from .test_base_dueling_network import TestBaseDuelingNetwork +from 
.test_simple_dueling_network import TestSimpleDuelingNetwork diff --git a/tests/networks/test_base_dueling_network.py b/tests/networks/test_base_dueling_network.py index c086aa8..48ac2e3 100644 --- a/tests/networks/test_base_dueling_network.py +++ b/tests/networks/test_base_dueling_network.py @@ -1,8 +1,8 @@ -import pytest +from blobrl.networks import BaseDuelingNetwork +from tests.networks import TestBaseNetwork -from blobrl.networks import BaseNetwork +class TestBaseDuelingNetwork(TestBaseNetwork): + __test__ = True -def test_init_base_network_fail(): - with pytest.raises(TypeError): - BaseNetwork(None, None) + network = BaseDuelingNetwork diff --git a/tests/networks/test_base_network.py b/tests/networks/test_base_network.py index c086aa8..b7dd217 100644 --- a/tests/networks/test_base_network.py +++ b/tests/networks/test_base_network.py @@ -1,8 +1,51 @@ import pytest - from blobrl.networks import BaseNetwork +from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box, Tuple, Dict + + +class TestBaseNetwork: + __test__ = True + + network = BaseNetwork + + list_work = [ + [Discrete(3), Discrete(1)], + [Discrete(3), Discrete(3)], + [Discrete(10), Discrete(50)], + [MultiDiscrete([3]), MultiDiscrete([1])], + [MultiDiscrete([3, 3]), MultiDiscrete([3, 3])], + [MultiDiscrete([4, 4, 4]), MultiDiscrete([50, 4, 4])], + [MultiDiscrete([[100, 3], [3, 5]]), MultiDiscrete([[100, 3], [3, 5]])], + [MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]), + MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]])], + [MultiBinary(1), MultiBinary(1)], + [MultiBinary(3), MultiBinary(3)], + # [MultiBinary([3, 2]), MultiBinary([3, 2])], # Don't work yet because gym don't implemented this + [Box(low=0, high=10, shape=[1]), Box(low=0, high=10, shape=[1])], + [Box(low=0, high=10, shape=[2, 2]), Box(low=0, high=10, shape=[2, 2])], + [Box(low=0, high=10, shape=[2, 2, 2]), Box(low=0, high=10, shape=[2, 2, 2])], + + [Tuple([Discrete(1), MultiDiscrete([1, 1])]), Tuple([Discrete(1), 
MultiDiscrete([1, 1])])], + [Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])}), + Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])})] + ] + + list_fail = [ + [None, None], + ["dedrfe", "qdzq"], + [1215.4154, 157.48], + ["zdzd", (Discrete(1))], + [Discrete(1), "zdzd"], + ["zdzd", (1, 4, 7)], + [(1, 4, 7), "zdzd"], + [152, 485] + ] + def test_init(self): + for ob, ac in self.list_fail: + with pytest.raises(TypeError): + self.network(observation_space=ob, action_space=ac) -def test_init_base_network_fail(): - with pytest.raises(TypeError): - BaseNetwork(None, None) + for ob, ac in self.list_work: + with pytest.raises(TypeError): + self.network(observation_space=ob, action_space=ac) diff --git a/tests/networks/test_c51_network.py b/tests/networks/test_c51_network.py index 8d8fd93..f63c7e6 100644 --- a/tests/networks/test_c51_network.py +++ b/tests/networks/test_c51_network.py @@ -1,34 +1,66 @@ import pytest import torch -from gym.spaces import Discrete +from gym.spaces import flatdim +from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box, Tuple, Dict from blobrl.networks import C51Network +from tests.networks import TestBaseNetwork -def test_c51_init(): - list_fail = [[None, None], - ["dedrfe", "qdzq"], - [1215.4154, 157.48], - ["zdzd", (Discrete(1))], - [Discrete(1), "zdzd"], - ["zdzd", (1, 4, 7)], - [(1, 4, 7), "zdzd"], - [Discrete(1), Discrete(1)]] - - for ob, ac in list_fail: - with pytest.raises(TypeError): - C51Network(observation_shape=ob, action_shape=ac) - - list_work = [[(454), (4)], - [(454, 54), (5, 2)], - [(454, 54, 45), (4, 5, 3)]] - for ob, ac in list_work: - C51Network(observation_shape=ob, action_shape=ac) - - -def test_c51_forward(): - list_work = [[(454, 54), (4, 2)], - [(454, 54, 45), (5, 1, 2)]] - for ob, ac in list_work: - simple_network = C51Network(observation_shape=ob, action_shape=ac) - simple_network.forward(torch.rand((2, *ob))) +class TestC51Network(TestBaseNetwork): + __test__ = True + + network = 
C51Network + + list_work = [ + [Discrete(3), Discrete(1)], + [Discrete(3), Discrete(3)], + [Discrete(10), Discrete(50)], + [MultiDiscrete([3]), MultiDiscrete([1])], + [MultiDiscrete([3, 3]), MultiDiscrete([3, 3])], + [MultiDiscrete([4, 4, 4]), MultiDiscrete([50, 4, 4])], + [MultiDiscrete([[100, 3], [3, 5]]), MultiDiscrete([[100, 3], [3, 5]])], + [MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]), + MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]])] + + ] + + list_fail = [ + [None, None], + ["dedrfe", "qdzq"], + [1215.4154, 157.48], + ["zdzd", (Discrete(1))], + [Discrete(1), "zdzd"], + ["zdzd", (1, 4, 7)], + [(1, 4, 7), "zdzd"], + [152, 485], + [MultiBinary(1), MultiBinary(1)], + [MultiBinary(3), MultiBinary(3)], + # [MultiBinary([3, 2]), MultiBinary([3, 2])], # Don't work yet because gym don't implemented this + [Box(low=0, high=10, shape=[1]), Box(low=0, high=10, shape=[1])], + [Box(low=0, high=10, shape=[2, 2]), Box(low=0, high=10, shape=[2, 2])], + [Box(low=0, high=10, shape=[2, 2, 2]), Box(low=0, high=10, shape=[2, 2, 2])], + + [Tuple([Discrete(1), MultiDiscrete([1, 1])]), Tuple([Discrete(1), MultiDiscrete([1, 1])])], + [Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])}), + Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])})] + ] + + def test_init(self): + for ob, ac in self.list_fail: + with pytest.raises(TypeError): + self.network(observation_space=ob, action_space=ac) + + for ob, ac in self.list_work: + self.network(observation_space=ob, action_space=ac) + + def test_forward(self): + for ob, ac in self.list_work: + network = self.network(observation_space=ob, action_space=ac) + network.forward(torch.rand((1, flatdim(ob)))) + + def test_str_(self): + for ob, ac in self.list_work: + network = self.network(observation_space=ob, action_space=ac) + + assert 'C51Network-' + str(ob) + "-" + str(ac) == network.__str__() diff --git a/tests/networks/test_simple_dueling_network.py b/tests/networks/test_simple_dueling_network.py index 
9e7f033..63e49b6 100644 --- a/tests/networks/test_simple_dueling_network.py +++ b/tests/networks/test_simple_dueling_network.py @@ -1,45 +1,45 @@ import pytest import torch -from gym.spaces import Discrete +from gym.spaces import flatdim, flatten -from blobrl.networks import SimpleDuelingNetwork +from blobrl.networks import SimpleDuelingNetwork, SimpleNetwork +from tests.networks import TestBaseDuelingNetwork -def test_simple_network_init(): - list_fail = [[None, None], - ["dedrfe", "qdzq"], - [1215.4154, 157.48], - ["zdzd", (Discrete(1))], - [Discrete(1), "zdzd"], - ["zdzd", (1, 4, 7)], - [(1, 4, 7), "zdzd"], - [Discrete(1), Discrete(1)]] +class TestSimpleDuelingNetwork(TestBaseDuelingNetwork): + __test__ = True - for ob, ac in list_fail: - with pytest.raises(TypeError): - SimpleDuelingNetwork(observation_shape=ob, action_shape=ac) + network = SimpleDuelingNetwork + net = SimpleNetwork - list_work = [[(454), (874)], - [(454, 54), (48, 44)], - [(454, 54, 45), (48, 44, 47)]] - for ob, ac in list_work: - SimpleDuelingNetwork(observation_shape=ob, action_shape=ac) + list_fail = [1, + 0.1, + "string", + object(), + network(net(TestBaseDuelingNetwork.list_work[0][0], TestBaseDuelingNetwork.list_work[0][1])) + ] + def test_init(self): + for net in self.list_fail: + with pytest.raises(TypeError): + self.network(net) -def test_forward(): - list_work = [[(454, 54), (48, 44)], - [(454, 54, 45), (48, 44, 47)]] - for ob, ac in list_work: - simple_network = SimpleDuelingNetwork(observation_shape=ob, action_shape=ac) - simple_network.forward(torch.rand((1, *ob))) + for ob, ac in self.list_work: + self.network(self.net(observation_space=ob, action_space=ac)) + def test_forward(self): + for ob, ac in self.list_work: + network = self.network(self.net(observation_space=ob, action_space=ac)) + network.forward(torch.rand((1, flatdim(ob)))) + def test_str_(self): + for ob, ac in self.list_work: + net = self.net(observation_space=ob, action_space=ac) + network = self.network(net) -def 
test__str__(): - list_work = [[(454, 54), (48, 44)], - [(454, 54, 45), (48, 44, 47)]] + assert 'SimpleDuelingNetwork-' + str(net) == network.__str__() - for ob, ac in list_work: - network = SimpleDuelingNetwork(observation_shape=ob, action_shape=ac) - - assert 'SimpleDuelingNetwork-' + str(ob) + "-" + str(ac) == network.__str__() \ No newline at end of file + def test_call_network(self): + for ob, ac in self.list_work: + self.network(SimpleNetwork(observation_space=ob, action_space=ac))( + torch.tensor([flatten(ob, ob.sample())]).float()) diff --git a/tests/networks/test_simple_network.py b/tests/networks/test_simple_network.py index 685e604..78bce4f 100644 --- a/tests/networks/test_simple_network.py +++ b/tests/networks/test_simple_network.py @@ -1,34 +1,36 @@ import pytest import torch -from gym.spaces import Discrete +from gym.spaces import flatdim, flatten from blobrl.networks import SimpleNetwork +from tests.networks import TestBaseNetwork -def test_simple_network_init(): - list_fail = [[None, None], - ["dedrfe", "qdzq"], - [1215.4154, 157.48], - ["zdzd", (Discrete(1))], - [Discrete(1), "zdzd"], - ["zdzd", (1, 4, 7)], - [(1, 4, 7), "zdzd"], - [Discrete(1), Discrete(1)]] - - for ob, ac in list_fail: - with pytest.raises(TypeError): - SimpleNetwork(observation_shape=ob, action_shape=ac) - - list_work = [[(454), (874)], - [(454, 54), (48, 44)], - [(454, 54, 45), (48, 44, 47)]] - for ob, ac in list_work: - SimpleNetwork(observation_shape=ob, action_shape=ac) - - -def test_forward(): - list_work = [[(454, 54), (48, 44)], - [(454, 54, 45), (48, 44, 47)]] - for ob, ac in list_work: - simple_network = SimpleNetwork(observation_shape=ob, action_shape=ac) - simple_network.forward(torch.rand((1, *ob))) +class TestSimpleNetwork(TestBaseNetwork): + __test__ = True + + network = SimpleNetwork + + def test_init(self): + for ob, ac in self.list_fail: + with pytest.raises(TypeError): + self.network(observation_space=ob, action_space=ac) + + for ob, ac in self.list_work: + 
self.network(observation_space=ob, action_space=ac) + + def test_forward(self): + for ob, ac in self.list_work: + network = self.network(observation_space=ob, action_space=ac) + network.forward(torch.rand((1, flatdim(ob)))) + + def test_str_(self): + for ob, ac in self.list_work: + network = self.network(observation_space=ob, action_space=ac) + + assert 'SimpleNetwork-' + str(ob) + "-" + str(ac) == network.__str__() + + def test_call_network(self): + for ob, ac in self.list_work: + self.network(observation_space=ob, action_space=ac)( + torch.tensor([flatten(ob, ob.sample())]).float()) diff --git a/tests/networks/test_utils.py b/tests/networks/test_utils.py new file mode 100644 index 0000000..36df394 --- /dev/null +++ b/tests/networks/test_utils.py @@ -0,0 +1,98 @@ +from blobrl.networks import get_last_layers +from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict +import torch.nn as nn + + +def valid_dim(out_v, out_g): + if isinstance(out_v, list): + assert len(out_v) == len(out_g) + for o, g in zip(out_v, out_g): + valid_dim(o, g) + else: + assert len(out_v.state_dict()) == len(out_g.state_dict()) + + +def test_get_last_layers(): + in_values = [ + Discrete(10), + Discrete(1), + Discrete(100), + Discrete(5), + + MultiDiscrete([1]), + MultiDiscrete([10, 110, 3, 50]), + MultiDiscrete([1, 1, 1]), + MultiDiscrete([100, 3, 3, 5]), + + MultiDiscrete([[100, 3], [3, 5]]), + MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]), + + MultiBinary(1), + MultiBinary(3), + MultiBinary([3, 2]), + + Box(low=0, high=10, shape=[1]), + Box(low=0, high=10, shape=[2, 2]), + Box(low=0, high=10, shape=[2, 2, 2]), + + Tuple([Discrete(1), MultiDiscrete([1, 1])]), + Dict({"first": Discrete(1), "second": MultiDiscrete([1, 1])}) + + ] + + out_values = [ + + nn.Linear(10, 10), + nn.Linear(10, 1), + nn.Linear(10, 100), + nn.Linear(10, 5), + + [nn.Sequential(*[nn.Linear(10, 10), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 10), nn.Softmax()]), 
nn.Sequential(*[nn.Linear(10, 110), nn.Softmax()]), + nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 50), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 100), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), + nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 5), nn.Softmax()])], + + [[nn.Sequential(*[nn.Linear(10, 100), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 5), nn.Softmax()])]], + [ + [[nn.Sequential(*[nn.Linear(10, 100), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 5), nn.Softmax()])]], + [[nn.Sequential(*[nn.Linear(10, 100), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()])], + [nn.Sequential(*[nn.Linear(10, 3), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 5), nn.Softmax()])]] + ], + + [nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()])], + [nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()])], + + [ + [nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()])], + [nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()])], + [nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()]), + nn.Sequential(*[nn.Linear(10, 1), nn.Sigmoid()])] + ], + + [nn.Linear(10, 1)] + , + [[nn.Linear(10, 1), nn.Linear(10, 1)], [nn.Linear(10, 1), nn.Linear(10, 1)]] + , + [[[nn.Linear(10, 1), nn.Linear(10, 1)], [nn.Linear(10, 1), nn.Linear(10, 1)]], + [[nn.Linear(10, 1), nn.Linear(10, 1)], [nn.Linear(10, 1), nn.Linear(10, 1)]]], + + [nn.Linear(10, 
1), + [nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()])]], + + [nn.Linear(10, 1), + [nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()]), nn.Sequential(*[nn.Linear(10, 1), nn.Softmax()])]], + + ] + + for in_value, out_value in zip(in_values, out_values): + out_value_gen = get_last_layers(in_value, 10) + valid_dim(out_value, out_value_gen) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1aef187..9837c99 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -7,6 +7,8 @@ import pytest from ipykernel import iostream +from gym.spaces import Discrete + from blobrl import Trainer, Logger from blobrl.agents import AgentInterface from blobrl.trainer import arg_to_agent @@ -14,7 +16,7 @@ def test_arg_to_agent(): fail_list = ["dzdzqd", None, 123, 123.123, [], {}, object] - work_list = ["agent_random", "dqn", "double_dqn", "categorical_dqn", "dueling_dqn"] + work_list = ["agent_random", "dqn", "double_dqn", "categorical_dqn"] for agent in fail_list: with pytest.raises(ValueError): @@ -71,7 +73,7 @@ def save(self, file_name, dire_name="."): pass def __init__(self, observation_space, action_space, device=None): - super().__init__(device) + super().__init__(Discrete(1), Discrete(1), device) self.get_action_done = 0 self.learn_done = 0 self.episode_finished_done = 0 @@ -252,6 +254,14 @@ def test_trainer_train(): assert fake_env.step_done == number_episode + eval and fake_env.reset_done == number_episode + eval + 1 assert fake_env.render_done == number_episode + eval + # test nb_evaluation + fake_env = FakeEnv() + fake_agent = FakeAgent(observation_space=None, action_space=None) + trainer = Trainer(environment=fake_env, agent=fake_agent) + + for i in range(10): + trainer.train(max_episode=100, nb_evaluation=i) + class FakeOutStream(iostream.OutStream): diff --git a/tests/train.py b/tests/train.py new file mode 100644 index 0000000..db66190 --- /dev/null +++ b/tests/train.py @@ -0,0 +1,22 @@ +from blobrl 
import Trainer, Record +from blobrl.agents import CategoricalDQN, DQN, DoubleDQN + +import gym + +if __name__ == "__main__": + + for agent in [CategoricalDQN, DQN, DoubleDQN]: + + env = gym.make("CartPole-v1") + a = agent(env.observation_space, env.action_space) + trainer = Trainer(environment=env, agent=agent) + + for i in range(100): + + trainer.train(max_episode=50, render=False, nb_evaluation=0) + m = max([Record.sum_records(e) for e in trainer.logger.episodes]) + print(agent.__name__, i, m) + if m > 200: + break + + print("####### ", agent.__name__, i, m, " #######")