Skip to content

Commit

Permalink
Auto Generate benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
WwZzz committed Apr 29, 2023
1 parent 289b4c4 commit 5416732
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flgo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils.fflow import init, gen_task, gen_task_by_para, tune, run_in_parallel, multi_init_and_run
from .utils.fflow import init, gen_task, gen_task_by_para, gen_benchmark_from_file,tune, run_in_parallel, multi_init_and_run

communicator = None

Expand Down
2 changes: 1 addition & 1 deletion flgo/benchmark/mnist_classification/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from flgo.benchmark.toolkits.cv.classification import GeneralCalculator, FromDatasetPipe, FromDatasetGenerator
from .dataset import train_data, test_data
from .config import train_data, test_data

class TaskGenerator(FromDatasetGenerator):
def __init__(self):
Expand Down
9 changes: 9 additions & 0 deletions flgo/benchmark/toolkits/cv/classification/temp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .model import default_model
import flgo.benchmark.toolkits.visualization
import flgo.benchmark.toolkits.partition

default_model = default_model
default_partitioner = flgo.benchmark.toolkits.partition.IIDPartitioner
default_partition_para = {'num_clients':100}
visualize = flgo.benchmark.toolkits.visualization.visualize_by_class

13 changes: 13 additions & 0 deletions flgo/benchmark/toolkits/cv/classification/temp/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
train_data (torch.utils.data.Dataset),
test_data (torch.utils.data.Dataset),
and the model (torch.nn.Module) should be implemented here.
"""
import torch.nn

train_data = None
test_data = None

def get_model(*args, **kwargs) -> torch.nn.Module:
raise NotImplementedError
14 changes: 14 additions & 0 deletions flgo/benchmark/toolkits/cv/classification/temp/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os
from flgo.benchmark.toolkits.cv.classification import GeneralCalculator, FromDatasetPipe, FromDatasetGenerator
from .config import train_data, test_data

class TaskGenerator(FromDatasetGenerator):
def __init__(self):
super(TaskGenerator, self).__init__(benchmark=os.path.split(os.path.dirname(__file__))[-1],
train_data=train_data, test_data=test_data)

class TaskPipe(FromDatasetPipe):
def __init__(self, task_path):
super(TaskPipe, self).__init__(task_path, train_data, test_data)

TaskCalculator = GeneralCalculator
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ..config import get_model
from flgo.utils.fmodule import FModule

class Model(FModule):
def __init__(self):
super().__init__()
self.model = get_model()

def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def init_local_module(object):
pass

def init_global_module(object):
if 'Server' in object.__class__.__name__:
object.model = Model().to(object.device)
14 changes: 14 additions & 0 deletions flgo/utils/fflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import shutil
import sys
import copy
import multiprocessing
Expand Down Expand Up @@ -174,6 +175,19 @@ def load_configuration(config={}):
else:
raise TypeError('The input config should be either a dict or a filename.')

def gen_benchmark_from_file(benchmark:str, config_file:str, target_path='.',data_type:str='cv', task_type:str='classification'):
if not os.path.exists(config_file): raise FileNotFoundError('File {} not found.'.format(config_file))
target_path = os.path.abspath(target_path)
bmk_path = os.path.join(target_path, benchmark)
if os.path.exists(bmk_path): raise FileExistsError('Task {} already exists'.format(bmk_path))
if data_type.lower() =='cv':
if task_type == 'classification':
temp_path = os.path.join(flgo.benchmark.path, 'toolkits', 'cv', 'classification', 'temp')
shutil.copytree(temp_path, bmk_path)
shutil.copyfile(config_file, os.path.join(bmk_path, 'config.py'))
bmk_module = '.'.join(os.path.relpath(bmk_path, os.getcwd()).split(os.path.sep))
return bmk_module

def gen_task_by_para(benchmark, bmk_para:dict={}, Partitioner=None, par_para:dict={}, task_path: str='', rawdata_path:str='', seed:int=0):
r"""
Generate a federated task according to the parameters of this function. The formats and meanings of the inputs are listed as below:
Expand Down

0 comments on commit 5416732

Please sign in to comment.