Skip to content

Commit

Permalink
[#136] muzero first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
root authored and root committed Mar 10, 2022
1 parent a39d7ee commit b87d94b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
37 changes: 37 additions & 0 deletions jorldy/config/muzero/cartpole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
### MuZero CartPole Config ###

env = {
"name": "cartpole",
"action_type": "discrete",
"render": False,
}

agent = {
"name": "muzero",
"network": "muzero",
"gamma": 0.99,
"epsilon_init": 1.0,
"epsilon_min": 0.01,
"explore_ratio": 0.2,
"buffer_size": 50000,
"batch_size": 32,
"start_train_step": 2000,
"target_update_period": 500,
}

optim = {
"name": "adam",
"lr": 0.0001,
}

train = {
"training": True,
"load_path": None,
"run_step": 100000,
"print_period": 1000,
"save_period": 10000,
"eval_iteration": 10,
# distributed setting
"update_period": 32,
"num_workers": 8,
}
85 changes: 85 additions & 0 deletions jorldy/core/agent/muzero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from collections import deque
from itertools import islice
import torch
import torch.nn.functional as F

torch.backends.cudnn.benchmark = True
import numpy as np

from .base import BaseAgent

class MuZero(BaseAgent):
"""MuZero agent.
Args:
-
"""

def __init__(
self,
# MuZero
-
**kwargs
):
super(MuZero, self).__init__(network=network, **kwargs)


@torch.no_grad()
def act(self, state, training=True):
self.network.train(training)
pass

def learn(self):
pass
self.num_learn += 1

result = {
"loss": loss.item(),
}

return result

def process(self, transitions, step):
pass

return result


def interact_callback(self, transition):
pass

return _transition


def save(self, path):
print(f"...Save model to {path}...")
torch.save(
{
"network": self.network.state_dict(),
"optimizer": self.optimizer.state_dict(),
},
os.path.join(path, "ckpt"),
)

def load(self, path):
print(f"...Load model from {path}...")
checkpoint = torch.load(os.path.join(path, "ckpt"), map_location=self.device)
self.network.load_state_dict(checkpoint["network"])
self.target_network.load_state_dict(checkpoint["network"])
self.optimizer.load_state_dict(checkpoint["optimizer"])


class MCTS():
def __init__(self):
pass

def selection(self):
pass

def expansion(self):
pass

def backup(self):
pass


18 changes: 18 additions & 0 deletions jorldy/core/network/muzero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F

from .base import BaseNetwork
from .utils import orthogonal_init


class MuZero(BaseNetwork):
def __init__(self, D_in, D_out, D_hidden=512, head="resnet"):
D_head_out = super(MuZero, self).__init__(D_in, D_hidden, head)
self.l = torch.nn.Linear(D_head_out, D_hidden)
self.pi = torch.nn.Linear(D_hidden, D_out)

orthogonal_init(self.l)
orthogonal_init(self.pi, "tanh")

def forward(self, x):
pass

0 comments on commit b87d94b

Please sign in to comment.