Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first commit #72

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
// https://github.com/microsoft/vscode-dev-containers/tree/v0.245.2/containers/python-3
{
"name": "Python 3",
"build": {
"dockerfile": "Dockerfile",
"context": "..",
"args": {
// Update 'VARIANT' to pick a Python version: 3, 3.10, 3.9, 3.8, 3.7, 3.6
// Append -bullseye or -buster to pin to an OS version.
// Use -bullseye variants when running locally on arm64/Apple Silicon.
"VARIANT": "3.10-bullseye",
// Options
"NODE_VERSION": "lts/*"
}
},

// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
// Set *default* container specific settings.json values on container create.
"settings": {
"python.defaultInterpreterPath": "/usr/local/bin/python",
"python.linting.enabled": true,
"python.linting.pylintEnabled": true,
"python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8",
"python.formatting.blackPath": "/usr/local/py-utils/bin/black",
"python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf",
"python.linting.banditPath": "/usr/local/py-utils/bin/bandit",
"python.linting.flake8Path": "/usr/local/py-utils/bin/flake8",
"python.linting.mypyPath": "/usr/local/py-utils/bin/mypy",
"python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle",
"python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle",
"python.linting.pylintPath": "/usr/local/py-utils/bin/pylint"
},

// Add the IDs of extensions you want installed when the container is created.
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance"
]
}
},

// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "pip3 install --user -r requirements.txt",

// Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
"remoteUser": "vscode",
"features": {
"azure-cli": "latest",
"sshd": "latest",
"powershell": "7.1",
"jupyterlab": "3.6.2"
}
}
1 change: 1 addition & 0 deletions Minigridcustom
Submodule Minigridcustom added at 8fdebe
97 changes: 97 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,111 @@ def __init__(self, obs_space, action_space, use_memory=False, use_text=False):
self.use_text = use_text
self.use_memory = use_memory


# Define image embedding
self.image_conv = nn.Sequential(
nn.Conv2d(3, 16, (2, 2)),
nn.ReLU(),
nn.MaxPool2d((2, 2)),
nn.Conv2d(16, 32, (2, 2)),
nn.ReLU(),
nn.Conv2d(32, 64, (2, 2)),
nn.ReLU()
)
n = obs_space["image"][0]
m = obs_space["image"][1]
self.image_embedding_size = ((n-1)//2-2)*((m-1)//2-2)*64

# Define memory
if self.use_memory:
self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size)

# Define text embedding
if self.use_text:
self.word_embedding_size = 32
self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)
self.text_embedding_size = 128
self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)

# Resize image embedding
self.embedding_size = self.semi_memory_size
if self.use_text:
self.embedding_size += self.text_embedding_size

# Define actor's model
self.actor = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, action_space.n)
)

# Define critic's model
self.critic = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, 1)
)

# Initialize parameters correctly
self.apply(init_params)

@property
def memory_size(self):
    """Return the width of the recurrent memory vector.

    The memory is twice ``semi_memory_size`` wide; ``forward`` splits it
    into two halves of ``semi_memory_size`` columns each.
    """
    half_width = self.semi_memory_size
    return 2 * half_width

@property
def semi_memory_size(self):
    """Return the size of one half of the memory vector.

    This is the same value as ``image_embedding_size``.
    """
    half_size = self.image_embedding_size
    return half_size

def forward(self, obs, memory):
    """Compute the action distribution, state value and next memory.

    Args:
        obs: observation batch exposing ``image`` and, when ``use_text`` is
            set, ``text``. The image's axis 3 is moved to position 1 before
            the convolution (presumably channels-last input; confirm
            against the caller).
        memory: 2-D tensor whose first ``semi_memory_size`` columns are the
            first element of the recurrent state and whose remaining
            columns are the second. Only read when ``use_memory`` is set.

    Returns:
        A ``(dist, value, memory)`` tuple: a ``Categorical`` built from the
        actor's log-softmaxed output, the critic output with dimension 1
        squeezed away, and the memory (updated when ``use_memory`` is set,
        otherwise the tensor that was passed in).
    """
    # Reorder axes (0, 1, 2, 3) -> (0, 3, 1, 2) for the conv stack.
    image = obs.image.transpose(1, 3).transpose(2, 3)
    features = self.image_conv(image)
    # Flatten everything except the batch axis.
    features = features.reshape(features.shape[0], -1)

    if not self.use_memory:
        embedding = features
    else:
        half = self.semi_memory_size
        previous_state = (memory[:, :half], memory[:, half:])
        first_state, second_state = self.memory_rnn(features, previous_state)
        embedding = first_state
        # Pack both halves back into a single tensor for the caller.
        memory = torch.cat((first_state, second_state), dim=1)

    if self.use_text:
        text_embedding = self._get_embed_text(obs.text)
        embedding = torch.cat((embedding, text_embedding), dim=1)

    actor_out = self.actor(embedding)
    dist = Categorical(logits=F.log_softmax(actor_out, dim=1))

    value = self.critic(embedding).squeeze(1)

    return dist, value, memory

def _get_embed_text(self, text):
    """Embed ``text`` and return the last entry of the RNN's final state.

    The text is passed through ``word_embedding`` and then ``text_rnn``;
    the RNN's first return value is discarded.
    """
    word_vectors = self.word_embedding(text)
    _outputs, final_state = self.text_rnn(word_vectors)
    return final_state[-1]

class ACModelDrop(nn.Module, torch_ac.RecurrentACModel):
def __init__(self, obs_space, action_space, use_memory=False, use_text=False):
super().__init__()

# Decide which components are enabled
self.use_text = use_text
self.use_memory = use_memory


# Define image embedding
self.image_conv = nn.Sequential(
nn.Conv2d(3, 16, (2, 2)),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.MaxPool2d((2, 2)),
nn.Conv2d(16, 32, (2, 2)),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.Conv2d(32, 64, (2, 2)),
nn.Dropout(p=0.5),
nn.ReLU()
)
n = obs_space["image"][0]
Expand Down
11 changes: 11 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


import argparse
import time
import datetime
Expand All @@ -8,6 +10,7 @@
import utils
from utils import device
from model import ACModel
from model import ACModelDrop


# Parse arguments
Expand All @@ -31,6 +34,8 @@
help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=10**7,
                    help="number of frames of training (default: 1e7)")
# --drop is an integer switch: 1 selects the dropout variant of the model.
parser.add_argument("--drop", type=int, default=0,
                    help="set to 1 to use the dropout model ACModelDrop (default: 0)")

# Parameters for main algorithm
parser.add_argument("--epochs", type=int, default=4,
Expand Down Expand Up @@ -118,6 +123,9 @@

# Load model

# Pick the architecture: --drop 1 selects the dropout variant, anything else
# the plain model. The else branch matters: an unconditional ACModel
# assignment after the if would always overwrite the dropout model.
if args.drop == 1:
    acmodel = ACModelDrop(obs_space, envs[0].action_space, args.mem, args.text)
else:
    acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.text)
# Resume from saved weights when the status holds a model state.
if "model_state" in status:
    acmodel.load_state_dict(status["model_state"])
Expand All @@ -131,6 +139,7 @@
algo = torch_ac.A2CAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
args.optim_alpha, args.optim_eps, preprocess_obss)

elif args.algo == "ppo":
algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
Expand Down Expand Up @@ -201,3 +210,5 @@
status["vocab"] = preprocess_obss.vocab.vocab
utils.save_status(status, model_dir)
txt_logger.info("Status saved")


1 change: 1 addition & 0 deletions torch-ac
Submodule torch-ac added at b6602c
3 changes: 3 additions & 0 deletions utils/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import gymnasium as gym
# from Minigridcustom.minigrid.minigrid_env import MiniGridEnv
# from Minigridcustom.minigrid.envs.custom import CustomEnv



def make_env(env_key, seed=None, render_mode=None):
Expand Down