Make train on lunar lander #14
Merged
Changes from all 17 commits (all authored by riley-mld):

- babc9de  Add setup file
- ccde355  Fix stepping bug
- 4128388  Add ortho weight init
- 469c36d  Add bp_samples to state and correct bp_step to bp_samples in trainer
- fddd812  Add lunarlander test
- 5509196  Add rendering and warmup steps to collectors
- 528d498  Disable rendering for hit the middle
- 8e8719b  Sync to main
- 21b2531  Revert bp_step to bp_sample change
- 36f1e2c  Renaming
- bc52e3b  remove req file
- 930f241  Address comments and build network in test file
- fbb740a  Move collecting warmup steps to GymCollector
- a1e631b  Remove bp_samples
- 4445bfa  Fix bug
- 7c2ae64  address comments
- 46b243b  Remove unused code
One file in this pull request was deleted; its diff failed to load.
The new setup file (+18 lines):

```python
from distutils.core import setup

setup(
    name='emote',
    version='0.1',
    description='A modular reinforcement learning library',
    author='Martin Singh-Blom, Tom Solberg, Jack Harmer, Jorge Del Val, Riley Miladi',
    author_email='[email protected], [email protected], [email protected], [email protected], [email protected]',
    packages=[],
    install_requires=[
        'gym',
        'gym[atari]',
        'gym[box2d]',
        'gym[classic_control]',
        'sphinx-rtd-theme',
        'black',
    ],
)
```
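A side note on the setup file above: `distutils` is deprecated and was removed from the standard library in Python 3.12, so on newer interpreters the same metadata would need `setuptools`. A minimal sketch, assuming the same fields (the combined-extras spelling is an illustrative simplification, not what the PR does):

```python
# Hypothetical setuptools equivalent of the distutils-based setup.py above.
from setuptools import find_packages, setup

setup(
    name="emote",
    version="0.1",
    description="A modular reinforcement learning library",
    packages=find_packages(),  # replaces the empty packages=[] list
    install_requires=[
        # gym extras can be combined into a single requirement
        "gym[atari,box2d,classic_control]",
        "sphinx-rtd-theme",
        "black",
    ],
)
```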
One package `__init__.py` exports the new `ThreadedGymCollector`:

```diff
@@ -1,9 +1,10 @@
 from .hit_the_middle import HitTheMiddle
-from .collector import SimpleGymCollector
+from .collector import SimpleGymCollector, ThreadedGymCollector
 from .dict_gym_wrapper import DictGymWrapper

 __all__ = [
     "HitTheMiddle",
     "SimpleGymCollector",
     "DictGymWrapper",
+    "ThreadedGymCollector",
 ]
```
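For context on the warmup logic these collectors gained (per commit 5509196, "Add rendering and warmup steps to collectors"): the usual idea is to fill the replay memory with random actions before consulting the policy. A dependency-free sketch of that pattern — the env, policy, and memory here are stand-ins, not emote's actual API:

```python
class ToyCollector:
    """Sketch of a collector that takes random actions for the first
    `warmup_steps` environment steps, then switches to the learned policy."""

    def __init__(self, env, policy, memory, warmup_steps):
        self.env = env
        self.policy = policy
        self.memory = memory
        self.warmup_steps = warmup_steps
        self.steps = 0

    def collect_step(self, obs):
        if self.steps < self.warmup_steps:
            action = self.env.sample_action()  # random action during warmup
        else:
            action = self.policy(obs)  # learned policy afterwards
        next_obs = self.env.step(action)
        self.memory.append((obs, action, next_obs))
        self.steps += 1
        return next_obs
```

In the lunar-lander test below, `warmup_steps=batch_size * rollout_len` plays this role: it guarantees at least one full batch of transitions exists before training starts.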
The new lunar-lander test (+112 lines):

```python
import torch
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.optim import Adam
from gym.vector import SyncVectorEnv
import gym

from emote import Trainer
from emote.callbacks import FinalLossTestCheck, TensorboardLogger
from emote.nn import GaussianPolicyHead
from emote.nn.initialization import ortho_init_
from emote.memory.builder import DictObsTable
from emote.sac import (
    QLoss,
    QTarget,
    PolicyLoss,
    AlphaLoss,
    FeatureAgentProxy,
)
from emote.memory import TableMemoryProxy, MemoryLoader

from .gym import SimpleGymCollector, DictGymWrapper


class QNet(nn.Module):
    def __init__(self, num_obs, num_actions, num_hidden):
        super().__init__()
        self.q = nn.Sequential(
            nn.Linear(num_obs + num_actions, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, 1),
        )
        self.q.apply(ortho_init_)

    def forward(self, action, obs):
        x = torch.cat([obs, action], dim=1)
        return self.q(x)


class Policy(nn.Module):
    def __init__(self, num_obs, num_actions, num_hidden):
        super().__init__()
        self.pi = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            GaussianPolicyHead(num_hidden, num_actions),
        )
        self.pi.apply(ortho_init_)

    def forward(self, obs):
        return self.pi(obs)


def test_lunar_lander():
    experiment_name = "Lunar-lander_test2"

    hidden_layer = 256

    batch_size = 500
    rollout_len = 2

    n_env = 60

    learning_rate = 1e-3

    env = DictGymWrapper(SyncVectorEnv([_make_env(i) for i in range(n_env)]))
    table = DictObsTable(spaces=env.dict_space, maxlen=4_000_000)
    memory_proxy = TableMemoryProxy(table)
    dataloader = MemoryLoader(table, batch_size, rollout_len, "batch_size")

    num_actions = env.dict_space.actions.shape[0]
    num_obs = list(env.dict_space.state.spaces.values())[0].shape[0]

    q1 = QNet(num_obs, num_actions, hidden_layer)
    q2 = QNet(num_obs, num_actions, hidden_layer)
    policy = Policy(num_obs, num_actions, hidden_layer)

    ln_alpha = torch.tensor(1.0, requires_grad=True)
    agent_proxy = FeatureAgentProxy(policy)

    logged_cbs = [
        QLoss(name="q1", q=q1, opt=Adam(q1.parameters(), lr=learning_rate)),
        QLoss(name="q2", q=q2, opt=Adam(q2.parameters(), lr=learning_rate)),
        PolicyLoss(
            pi=policy,
            ln_alpha=ln_alpha,
            q=q1,
            opt=Adam(policy.parameters(), lr=learning_rate),
        ),
        AlphaLoss(pi=policy, ln_alpha=ln_alpha, opt=Adam([ln_alpha]), n_actions=num_actions),
        QTarget(pi=policy, ln_alpha=ln_alpha, q1=q1, q2=q2),
    ]

    callbacks = logged_cbs + [
        SimpleGymCollector(
            env, agent_proxy, memory_proxy, warmup_steps=batch_size * rollout_len
        ),
        TensorboardLogger(logged_cbs, SummaryWriter("runs/" + experiment_name), 2000),
        FinalLossTestCheck([logged_cbs[2]], [10.0], 300_000_000),
    ]

    trainer = Trainer(callbacks, dataloader)
    trainer.train()


def _make_env(rank):
    def _thunk():
        env = gym.make("LunarLander-v2", continuous=True)
        env.seed(rank)
        return env

    return _thunk
```
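The `ortho_init_` applied to both networks above was added in commit 4128388 ("Add ortho weight init") and presumably performs orthogonal weight initialization in the style of `torch.nn.init.orthogonal_`. A dependency-free NumPy sketch of the underlying idea — the QR-based construction here is illustrative, not emote's exact implementation:

```python
import numpy as np


def orthogonal_init(shape, gain=1.0, rng=None):
    """Return a `shape` matrix whose rows (or columns) are orthonormal,
    built from the QR decomposition of a random Gaussian matrix."""
    rng = rng or np.random.default_rng(0)
    rows, cols = shape
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    # Sign correction so the result is uniform over orthogonal matrices.
    q *= np.sign(np.diag(r))
    if rows < cols:
        q = q.T
    return gain * q[:rows, :cols]
```

Orthogonal initialization keeps the singular values of each weight matrix at `gain`, which helps preserve gradient norms through the stacked `Linear` layers.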
Review discussion (on reverting `bp_samples` back to `bp_step`):

> I've always found samples to be by far the most useful stat to log against; it removes the batch-size dependency. Can we keep this, or perhaps default to it?

> That's fair enough. What if, as Martin suggested, we don't keep samples in `state_dict`, but, since it's useful for logging, still log samples — computed by multiplying steps by the batch size rather than read from `state_dict`?

> Fair enough. I propose we let Riley merge this and then try to come up with a really nice solution when we address #18.

> That sounds reasonable, let's do it later then.
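The resolution sketched in this thread — log samples without storing them in `state_dict`, deriving them from the step counter instead — comes down to one multiplication. A sketch, assuming a fixed batch size per backprop step (the function name is hypothetical):

```python
def bp_samples(bp_steps: int, batch_size: int) -> int:
    """Derive the cumulative sample count from the backprop step count,
    so `samples` never needs to live in the checkpoint state."""
    return bp_steps * batch_size
```

With the test's `batch_size = 500`, 2000 backprop steps would be logged as `bp_samples(2000, 500) == 1_000_000` samples.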