diff --git a/flgo/__init__.py b/flgo/__init__.py index f473b861..23e70397 100644 --- a/flgo/__init__.py +++ b/flgo/__init__.py @@ -1,4 +1,4 @@ -from .utils.fflow import init, gen_task, gen_task_by_para, tune, run_in_parallel, multi_init_and_run +from .utils.fflow import init, gen_task, gen_task_by_para, gen_benchmark_from_file,tune, run_in_parallel, multi_init_and_run communicator = None diff --git a/flgo/benchmark/mnist_classification/dataset.py b/flgo/benchmark/mnist_classification/config.py similarity index 100% rename from flgo/benchmark/mnist_classification/dataset.py rename to flgo/benchmark/mnist_classification/config.py diff --git a/flgo/benchmark/mnist_classification/core.py b/flgo/benchmark/mnist_classification/core.py index 4a3f50c3..b8029a8f 100644 --- a/flgo/benchmark/mnist_classification/core.py +++ b/flgo/benchmark/mnist_classification/core.py @@ -1,6 +1,6 @@ import os from flgo.benchmark.toolkits.cv.classification import GeneralCalculator, FromDatasetPipe, FromDatasetGenerator -from .dataset import train_data, test_data +from .config import train_data, test_data class TaskGenerator(FromDatasetGenerator): def __init__(self): diff --git a/flgo/benchmark/toolkits/cv/classification/temp/__init__.py b/flgo/benchmark/toolkits/cv/classification/temp/__init__.py new file mode 100644 index 00000000..b8bdee51 --- /dev/null +++ b/flgo/benchmark/toolkits/cv/classification/temp/__init__.py @@ -0,0 +1,9 @@ +from .model import default_model +import flgo.benchmark.toolkits.visualization +import flgo.benchmark.toolkits.partition + +default_model = default_model +default_partitioner = flgo.benchmark.toolkits.partition.IIDPartitioner +default_partition_para = {'num_clients':100} +visualize = flgo.benchmark.toolkits.visualization.visualize_by_class + diff --git a/flgo/benchmark/toolkits/cv/classification/temp/config.py b/flgo/benchmark/toolkits/cv/classification/temp/config.py new file mode 100644 index 00000000..d967d262 --- /dev/null +++ b/flgo/benchmark/toolkits/cv/classification/temp/config.py @@ -0,0 +1,13 @@ +""" +train_data (torch.utils.data.Dataset), +test_data (torch.utils.data.Dataset), +and the model (torch.nn.Module) should be implemented here. + +""" +import torch.nn + +train_data = None +test_data = None + +def get_model(*args, **kwargs) -> torch.nn.Module: + raise NotImplementedError \ No newline at end of file diff --git a/flgo/benchmark/toolkits/cv/classification/temp/core.py b/flgo/benchmark/toolkits/cv/classification/temp/core.py new file mode 100644 index 00000000..b8029a8f --- /dev/null +++ b/flgo/benchmark/toolkits/cv/classification/temp/core.py @@ -0,0 +1,14 @@ +import os +from flgo.benchmark.toolkits.cv.classification import GeneralCalculator, FromDatasetPipe, FromDatasetGenerator +from .config import train_data, test_data + +class TaskGenerator(FromDatasetGenerator): + def __init__(self): + super(TaskGenerator, self).__init__(benchmark=os.path.split(os.path.dirname(__file__))[-1], + train_data=train_data, test_data=test_data) + +class TaskPipe(FromDatasetPipe): + def __init__(self, task_path): + super(TaskPipe, self).__init__(task_path, train_data, test_data) + +TaskCalculator = GeneralCalculator \ No newline at end of file diff --git a/flgo/benchmark/toolkits/cv/classification/temp/model/default_model.py b/flgo/benchmark/toolkits/cv/classification/temp/model/default_model.py new file mode 100644 index 00000000..4e019879 --- /dev/null +++ b/flgo/benchmark/toolkits/cv/classification/temp/model/default_model.py @@ -0,0 +1,17 @@ +from ..config import get_model +from flgo.utils.fmodule import FModule + +class Model(FModule): + def __init__(self): + super().__init__() + self.model = get_model() + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +def init_local_module(object): + pass + +def init_global_module(object): + if 'Server' in object.__class__.__name__: + object.model = Model().to(object.device) \ No newline at end of file diff --git a/flgo/utils/fflow.py b/flgo/utils/fflow.py index 9d1cb314..e72a04c1 100644 --- a/flgo/utils/fflow.py +++ b/flgo/utils/fflow.py @@ -1,4 +1,5 @@ import collections +import shutil import sys import copy import multiprocessing @@ -174,6 +175,19 @@ def load_configuration(config={}): else: raise TypeError('The input config should be either a dict or a filename.') +def gen_benchmark_from_file(benchmark:str, config_file:str, target_path='.',data_type:str='cv', task_type:str='classification'): + if not os.path.exists(config_file): raise FileNotFoundError('File {} not found.'.format(config_file)) + target_path = os.path.abspath(target_path) + bmk_path = os.path.join(target_path, benchmark) + if os.path.exists(bmk_path): raise FileExistsError('Task {} already exists'.format(bmk_path)) + if data_type.lower() =='cv': + if task_type == 'classification': + temp_path = os.path.join(flgo.benchmark.path, 'toolkits', 'cv', 'classification', 'temp') + shutil.copytree(temp_path, bmk_path) + shutil.copyfile(config_file, os.path.join(bmk_path, 'config.py')) + bmk_module = '.'.join(os.path.relpath(bmk_path, os.getcwd()).split(os.path.sep)) + return bmk_module + def gen_task_by_para(benchmark, bmk_para:dict={}, Partitioner=None, par_para:dict={}, task_path: str='', rawdata_path:str='', seed:int=0): r""" Generate a federated task according to the parameters of this function. The formats and meanings of the inputs are listed as below: