Make train on lunar lander #14

Merged · 17 commits · Mar 29, 2022
7 changes: 2 additions & 5 deletions emote/callbacks.py
@@ -128,10 +128,8 @@ def log_scalars(self, step, suffix=None):
k = "/".join(k_split)
self._writer.add_scalar(k, v, step)

- def end_cycle(self, bp_step, bp_samples):
+ def end_cycle(self, bp_step):
self.log_scalars(bp_step, suffix="bp_step")
- if self._log_samples:
- self.log_scalars(bp_samples, suffix="bp_samples")


class TerminalLogger(Callback):
@@ -196,7 +194,7 @@ def __init__(
self._opts: List[optim.Optimizer] = optimizers if optimizers else []
self._nets: List[nn.Module] = networks if networks else []

- def end_cycle(self, inf_step, bp_step, bp_samples):
+ def end_cycle(self, inf_step, bp_step):
state_dict = {}
state_dict["callback_state_dicts"] = [cb.state_dict() for cb in self._cbs]
state_dict["network_state_dicts"] = [net.state_dict() for net in self._nets]
@@ -205,7 +203,6 @@ def end_cycle(self, inf_step, bp_step, bp_samples):
"checkpoint_index": self._checkpoint_index,
"inf_step": inf_step,
"bp_step": bp_step,
- "bp_samples": bp_samples,
Contributor
I've always found samples to be by far the most useful stat to log against; it removes the batch-size dependency. Can we keep this, or perhaps default to it?

Contributor Author
That's fair enough. What if we don't keep samples in the state_dict, as Martin suggested, but still log them? Since, as you said, they're useful for logging, we can derive the sample count at logging time by multiplying the step count by the batch size, rather than reading it from the state_dict.

Member
Fair enough. I propose we let Riley merge this and then try to come up with a really nice solution when we address #18.

Contributor
That sounds reasonable; let's do it later then.
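For reference, a minimal sketch of the idea discussed in this thread: derive the sample count from the backprop step count at logging time instead of checkpointing it. The class and method names below are hypothetical, not code from this PR.

```python
class StepAndSampleLogger:
    """Hypothetical logger sketch: reconstructs bp_samples as bp_step * batch_size."""

    def __init__(self, batch_size: int):
        self._batch_size = batch_size

    def log_scalar(self, value: int, suffix: str):
        # Stand-in for a real writer (e.g. TensorBoard); just prints here.
        print(f"{suffix}: {value}")

    def end_cycle(self, bp_step: int):
        self.log_scalar(bp_step, suffix="bp_step")
        # bp_samples is derived on the fly, so it never needs to live in a state_dict.
        self.log_scalar(bp_step * self._batch_size, suffix="bp_samples")


# Usage: with a batch size of 500, step 1000 corresponds to 500_000 samples.
logger = StepAndSampleLogger(batch_size=500)
logger.end_cycle(bp_step=1000)
```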

}
torch.save(state_dict, f"{self._path}.{self._checkpoint_index}.tar")
self._checkpoint_index += 1
10 changes: 6 additions & 4 deletions emote/nn/action_value_mlp.py
@@ -2,6 +2,8 @@
from torch import nn
from torch import Tensor

+ from emote.nn.initialization import ortho_init_


class ActionValue(nn.Module):
def __init__(self, observation_dim, action_dim, hidden_dims):
@@ -14,9 +16,10 @@ def __init__(self, observation_dim, action_dim, hidden_dims):
for n_in, n_out in zip(
[observation_dim + action_dim] + hidden_dims, hidden_dims
)
- ]
+ ],
+ nn.Linear(hidden_dims[-1], 1)
)
- self.head = nn.Linear(hidden_dims[-1], 1)
+ self.seq.apply(ortho_init_)

def forward(self, action: Tensor, obs: Tensor) -> Tensor:
bsz, obs_d = obs.shape
@@ -25,7 +28,6 @@ def forward(self, action: Tensor, obs: Tensor) -> Tensor:
assert obs_d == self.obs_d
assert act_d == self.act_d
x = torch.cat([obs, action], dim=1)
- x = self.seq(x)
- out = self.head(x)
+ out = self.seq(x)
assert (bsz, 1) == out.shape
return out
11 changes: 6 additions & 5 deletions emote/nn/gaussian_policy.py
@@ -8,6 +8,8 @@
import torch.nn.functional as F
from torch import Tensor

+ from emote.nn.initialization import ortho_init_


class SquashStretchTransform(transforms.Transform):
r"""
@@ -111,11 +113,10 @@ def __init__(self, observation_dim, action_dim, hidden_dims):
*[
nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
for n_in, n_out in zip([observation_dim] + hidden_dims, hidden_dims)
- ]
+ ],
+ GaussianPolicyHead(hidden_dims[-1], action_dim)
)
- self.head = GaussianPolicyHead(hidden_dims[-1], action_dim)
+ self.seq.apply(ortho_init_)

def forward(self, obs):
- x = self.seq(obs)
- pre_actions, neg_log_probs = self.head(x)
- return pre_actions, neg_log_probs
+ return self.seq(obs)
3 changes: 0 additions & 3 deletions pip-requirements.txt

This file was deleted.

18 changes: 18 additions & 0 deletions setup.py
@@ -0,0 +1,18 @@
from distutils.core import setup

setup(
name='emote',
version='0.1',
description='A modular reinforcement learning library',
author ='Martin Singh-Blom, Tom Solberg, Jack Harmer, Jorge Del Val, Riley Miladi',
author_email='[email protected], [email protected], [email protected], [email protected], [email protected]',
packages=[],
install_requires=[
'gym',
'gym[atari]',
'gym[box2d]',
'gym[classic_control]',
'sphinx-rtd-theme',
'black'
]
)
3 changes: 2 additions & 1 deletion tests/gym/__init__.py
@@ -1,9 +1,10 @@
from .hit_the_middle import HitTheMiddle
- from .collector import SimpleGymCollector
+ from .collector import SimpleGymCollector, ThreadedGymCollector
from .dict_gym_wrapper import DictGymWrapper

__all__ = [
"HitTheMiddle",
"SimpleGymCollector",
"DictGymWrapper",
+ "ThreadedGymCollector",
]
31 changes: 20 additions & 11 deletions tests/gym/collector.py
@@ -19,6 +19,7 @@ def __init__(
agent: AgentProxy,
memory: MemoryProxy,
render: bool = True,
+ warmup_steps: int = 0,
):
super().__init__()
self._agent = agent
@@ -27,9 +28,12 @@ def __init__(
self._render = render
self._last_environment_rewards = deque(maxlen=1000)
self.num_envs = env.num_envs
+ self._warmup_steps = warmup_steps

def collect_data(self):
"""Collect a single rollout"""
+ if self._render:
+ self._env.render()
actions = self._agent(self._obs)
next_obs = self._env.dict_step(actions)
self._memory.add(self._obs, actions)
@@ -44,22 +48,29 @@ def collect_multiple(self, count: int):
self.collect_data()

def begin_training(self):
- "Runs through the init, step cycle once on main thread to make sure all envs work."
+ "Make sure all envs work and collect warmup steps."
+ # Runs through the init, step cycle once on main thread to make sure all envs work.
self._obs = self._env.dict_reset()
actions = self._agent(self._obs)
- _ = self._env.step(actions)
+ _ = self._env.dict_step(actions)
self._obs = self._env.dict_reset()

+ # Collect trajectories for warmup steps before starting training
+ iterations_required = self._warmup_steps
+ self.collect_multiple(iterations_required)


class ThreadedGymCollector(GymCollector):
def __init__(
self,
env: DictGymWrapper,
agent: AgentProxy,
memory: MemoryProxy,
- render=True,
+ render: bool = True,
+ warmup_steps: int = 0,
):
- super().__init__(env, agent, memory, render)
+ super().__init__(env, agent, memory, render, warmup_steps)
+ self._warmup_steps = warmup_steps
self._stop = False
self._thread = None

@@ -81,6 +92,7 @@ def collect_forever(self):
self.collect_data()

def begin_training(self):
+ super().begin_training()
self._thread = threading.Thread(target=self.collect_forever)
self._thread.start()

@@ -102,18 +114,15 @@ def __init__(
env: DictGymWrapper,
agent: AgentProxy,
memory: MemoryProxy,
- render=True,
- bp_steps_per_inf=10,
- warmup_steps=0,
+ render: bool = True,
+ warmup_steps: int = 0,
+ bp_steps_per_inf: int = 10,
):
- super().__init__(env, agent, memory, render)
- self._warmup_steps = warmup_steps
+ super().__init__(env, agent, memory, render, warmup_steps)
self._bp_steps_per_inf = bp_steps_per_inf

def begin_training(self):
super().begin_training()
- iterations_required = self._warmup_steps
- self.collect_multiple(iterations_required)
return {"inf_step": self._warmup_steps}

def begin_batch(self, inf_step, bp_step):
4 changes: 2 additions & 2 deletions tests/test_checkpoints.py
Expand Up @@ -52,7 +52,7 @@ def test_networks_checkpoint():

t1 = Trainer(c1, onestep_dataloader())
t1.state["inf_step"] = 0
t1.state["bp_samples"] = 0
t1.state["bp_step"] = 0
t1.train()
n2 = nn.Linear(1, 1)
test_data = torch.rand(5, 1)
@@ -90,7 +90,7 @@ def test_qloss_checkpoints():

t1 = Trainer(c1, random_onestep_dataloader())
t1.state["inf_step"] = 0
- t1.state["bp_samples"] = 0
+ t1.state["bp_step"] = 0
t1.train()
q2 = QNet(2, 1)
test_obs = torch.rand(5, 2)
2 changes: 1 addition & 1 deletion tests/test_htm.py
@@ -78,7 +78,7 @@ def test_htm():
]

callbacks = logged_cbs + [
- SimpleGymCollector(env, agent_proxy, memory_proxy, warmup_steps=500),
+ SimpleGymCollector(env, agent_proxy, memory_proxy, warmup_steps=500, render=False),
TerminalLogger(logged_cbs, 400),
FinalLossTestCheck([logged_cbs[2]], [10.0], 2000),
]
112 changes: 112 additions & 0 deletions tests/test_lunar_lander.py
@@ -0,0 +1,112 @@
import torch
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.optim import Adam
from gym.vector import SyncVectorEnv
import gym

from emote import Trainer
from emote.callbacks import (
FinalLossTestCheck,
TensorboardLogger
)
from emote.nn import GaussianPolicyHead
from emote.nn.initialization import ortho_init_
from emote.memory.builder import DictObsTable
from emote.sac import (
QLoss,
QTarget,
PolicyLoss,
AlphaLoss,
FeatureAgentProxy,
)
from emote.memory import TableMemoryProxy, MemoryLoader

from .gym import SimpleGymCollector, DictGymWrapper


class QNet(nn.Module):
def __init__(self, num_obs, num_actions, num_hidden):
super().__init__()
self.q = nn.Sequential(
nn.Linear(num_obs + num_actions, num_hidden),
nn.ReLU(),
nn.Linear(num_hidden, num_hidden),
nn.ReLU(),
nn.Linear(num_hidden, 1),
)
self.q.apply(ortho_init_)

def forward(self, action, obs):
x = torch.cat([obs, action], dim=1)
return self.q(x)


class Policy(nn.Module):
def __init__(self, num_obs, num_actions, num_hidden):
super().__init__()
self.pi = nn.Sequential(
nn.Linear(num_obs, num_hidden),
nn.ReLU(),
nn.Linear(num_hidden, num_hidden),
nn.ReLU(),
GaussianPolicyHead(num_hidden, num_actions),
)
self.pi.apply(ortho_init_)

def forward(self, obs):
return self.pi(obs)


def test_lunar_lander():

experiment_name = "Lunar-lander_test2"

hidden_layer = 256

batch_size = 500
rollout_len = 2

n_env = 60

learning_rate = 1e-3

env = DictGymWrapper(SyncVectorEnv([_make_env(i) for i in range(n_env)]))
table = DictObsTable(spaces=env.dict_space, maxlen=4_000_000)
memory_proxy = TableMemoryProxy(table)
dataloader = MemoryLoader(table, batch_size, rollout_len, "batch_size")

num_actions = env.dict_space.actions.shape[0]
num_obs = list(env.dict_space.state.spaces.values())[0].shape[0]

q1 = QNet(num_obs, num_actions, hidden_layer)
q2 = QNet(num_obs, num_actions, hidden_layer)
policy = Policy(num_obs, num_actions, hidden_layer)

ln_alpha = torch.tensor(1.0, requires_grad=True)
agent_proxy = FeatureAgentProxy(policy)

logged_cbs = [
QLoss(name="q1", q=q1, opt=Adam(q1.parameters(), lr=learning_rate)),
QLoss(name="q2", q=q2, opt=Adam(q2.parameters(), lr=learning_rate)),
PolicyLoss(pi=policy, ln_alpha=ln_alpha, q=q1, opt=Adam(policy.parameters(), lr=learning_rate)),
AlphaLoss(pi=policy, ln_alpha=ln_alpha, opt=Adam([ln_alpha]), n_actions=num_actions),
QTarget(pi=policy, ln_alpha=ln_alpha, q1=q1, q2=q2),
]

callbacks = logged_cbs + [
SimpleGymCollector(env, agent_proxy, memory_proxy, warmup_steps=batch_size*rollout_len),
TensorboardLogger(logged_cbs, SummaryWriter("runs/"+experiment_name), 2000),
FinalLossTestCheck([logged_cbs[2]], [10.0], 300_000_000),
]

trainer = Trainer(callbacks, dataloader)
trainer.train()

def _make_env(rank):
def _thunk():
env = gym.make("LunarLander-v2", continuous=True)
env.seed(rank)
return env
return _thunk