Skip to content

Commit

Permalink
multiple agents equal to the number of policies learned, but no commo…
Browse files Browse the repository at this point in the history
…n storage yet
  • Loading branch information
joannapng committed Mar 16, 2024
1 parent 03c5ffb commit cfa01c8
Show file tree
Hide file tree
Showing 38 changed files with 40 additions and 1,146 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
data/*
pretrain/data/*
pretrain/checkpoints/*
__pycache__
nohup.out
4 changes: 2 additions & 2 deletions pretrain/pretrain.py → pretrain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import torchvision
from trainer import Trainer
from utils import get_model_config
from pretrain.trainer import Trainer
from pretrain.utils import get_model_config

model_names = sorted(name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
Expand Down
Empty file added pretrain/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion pretrain/logger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["Logger"]

from .Logger import *
from .Logger import Logger
Binary file modified pretrain/logger/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion pretrain/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["LeNet5"]

from .LeNet5 import *
from .LeNet5 import *
Binary file modified pretrain/models/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
8 changes: 4 additions & 4 deletions pretrain/trainer/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST
import torchvision.models
from logger import Logger
from models import LeNet5
from utils import get_torchvision_model
from ..logger import Logger
from ..models import LeNet5
from ..utils import *

networks = {'LeNet5' : LeNet5}

Expand Down Expand Up @@ -139,7 +139,7 @@ def init_model(self):
self.model = builder(num_classes = self.num_classes, in_channels = self.in_channels).to(self.device)
else:
# if model-path is specified, it will be loaded below
self.model = get_torchvision_model(self.args.model_name, self.num_classes, self.device, self.args.pretrained and self.args.model_path is None)
self.model = utils.get_torchvision_model(self.args.model_name, self.num_classes, self.device, self.args.pretrained and self.args.model_path is None)

if self.args.resume_from is not None and not self.args.pretrained: # resume training from checkpoint
print('Loading model from checkpoint at: {}'.format(self.args.resume_from))
Expand Down
2 changes: 1 addition & 1 deletion pretrain/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["Trainer"]

from .Trainer import *
from .Trainer import Trainer
Binary file modified pretrain/trainer/__pycache__/Trainer.cpython-311.pyc
Binary file not shown.
Binary file modified pretrain/trainer/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
File renamed without changes.
6 changes: 0 additions & 6 deletions pretrain/utils/__init__.py

This file was deleted.

Binary file removed pretrain/utils/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file removed pretrain/utils/__pycache__/utils.cpython-311.pyc
Binary file not shown.
19 changes: 16 additions & 3 deletions train/train.py → train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import argparse
import torchvision
from agent import Agent
import numpy as np
from train.env import ModelEnv
from pretrain.utils import get_model_config
from stable_baselines3 import DDPG

model_names = sorted(name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
Expand Down Expand Up @@ -53,11 +56,21 @@
parser.add_argument('--max-bit', type=int, default=8, help = 'Maximum bit width (default: 8)')

### ----- AGENT ------ ###
parser.add_argument('--num_agents', default = 10, type = int, help = 'Number of agents')
parser.add_argument('--num_agents', default = 5, type = int, help = 'Number of agents')

def main():
args = parser.parse_args()
agent = Agent(args)
envs = []
agents = []
weights = [[1.0, 0.0000], [1.0, 0.0025], [1.0, 0.0050], [1.0, 0.0075], [1.0, 0.0100]]

for i in range(args.num_agents):
envs.append(ModelEnv(args, np.array(weights[i]), get_model_config(args.model_name, args.custom_model_name)))
agents.append(DDPG("MlpPolicy", envs[-1], verbose = 1))

for i, agent in enumerate(agents):
agent.learn(total_timesteps = 20, log_interval = 10)
agent.save("agent_{}_{}".format(weights[i][0], weights[i][1]))

if __name__ == "__main__":
main()
Empty file added train/__init__.py
Empty file.
6 changes: 0 additions & 6 deletions train/agent/Agent.py

This file was deleted.

Loading

0 comments on commit cfa01c8

Please sign in to comment.