diff --git a/main.py b/main.py new file mode 100644 index 0000000..f78f70b --- /dev/null +++ b/main.py @@ -0,0 +1,17 @@ +import sys +sys.path.insert(0,"/content/federated-learning/") + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import copy +import numpy as np +from torchvision import datasets, transforms +import torch + +from utils.sampling import mnist_iid, mnist_noniid, cifar_iid +from utils.options import args_parser +from models.Update import LocalUpdate +from models.Nets import MLP, CNNMnist, CNNCifar +from models.Fed import FedAvg +from models.test import test_img diff --git a/main_fed.py b/main_fed.py index d170fc8..9b47861 100644 --- a/main_fed.py +++ b/main_fed.py @@ -20,10 +20,14 @@ if __name__ == '__main__': # parse args + # parse the command-line options defined in utils/options args = args_parser() - args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') + # pick the compute device: fall back to CPU when CUDA is unavailable or the device id is -1 - # load dataset and split users + args.device = torch.device('cuda:{}'.format(args.tpu) if torch.cuda.is_available() and args.tpu != -1 else 'cpu') + + # load dataset and split users + # load the dataset and sample user partitions (IID); extend here to support other datasets/samplings if args.dataset == 'mnist': trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) diff --git a/models/test.py b/models/test.py index 0e46fd7..e66e04a 100644 --- a/models/test.py +++ b/models/test.py @@ -16,14 +16,14 @@ def test_img(net_g, datatest, args): data_loader = DataLoader(datatest, batch_size=args.bs) l = len(data_loader) for idx, (data, target) in enumerate(data_loader): - if args.gpu != -1: + if args.tpu != -1: data, target = data.cuda(), target.cuda() log_probs = net_g(data) # sum up batch loss test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() # get the index 
of the max log-probability y_pred = log_probs.data.max(1, keepdim=True)[1] - correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() + correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() test_loss /= len(data_loader.dataset) accuracy = 100.00 * correct / len(data_loader.dataset) diff --git a/utils/options.py b/utils/options.py index b29af3b..b29ef6a 100644 --- a/utils/options.py +++ b/utils/options.py @@ -32,10 +32,12 @@ def args_parser(): parser.add_argument('--iid', action='store_true', help='whether i.i.d or not') parser.add_argument('--num_classes', type=int, default=10, help="number of classes") parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges") - parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") + parser.add_argument('--tpu', type=int, default=0, help="TPU ID, -1 for CPU") parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping') parser.add_argument('--verbose', action='store_true', help='verbose print') parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients') - args = parser.parse_args() + # in Colab/Jupyter parse an empty argument list so kernel-injected argv does not break argparse (NOTE: this ignores real CLI arguments) + #args=[] + args = parser.parse_args(args=[]) return args