Skip to content

Commit

Permalink
Auto Generate benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
WwZzz committed Apr 29, 2023
1 parent 5416732 commit ea352ff
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion flgo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils.fflow import init, gen_task, gen_task_by_para, gen_benchmark_from_file,tune, run_in_parallel, multi_init_and_run
from .utils.fflow import init, gen_task, gen_task_from_para, gen_benchmark_from_file,tune, run_in_parallel, multi_init_and_run

communicator = None

Expand Down
8 changes: 5 additions & 3 deletions flgo/utils/fflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,13 @@ def gen_benchmark_from_file(benchmark:str, config_file:str, target_path='.',data
if task_type == 'classification':
temp_path = os.path.join(flgo.benchmark.path, 'toolkits', 'cv', 'classification', 'temp')
shutil.copytree(temp_path, bmk_path)
else:
raise NotImplementedError('FLGo currently only support automatically generate cv.classification task. More other types are comming soon...')
shutil.copyfile(config_file, os.path.join(bmk_path, 'config.py'))
bmk_module = '.'.join(os.path.relpath(bmk_path, os.getcwd()).split(os.path.sep))
return bmk_module

def gen_task_by_para(benchmark, bmk_para:dict={}, Partitioner=None, par_para:dict={}, task_path: str='', rawdata_path:str='', seed:int=0):
def gen_task_from_para(benchmark, bmk_para:dict={}, Partitioner=None, par_para:dict={}, task_path: str= '', rawdata_path:str= '', seed:int=0):
r"""
Generate a federated task according to the parameters of this function. The formats and meanings of the inputs are listed as below:
Expand All @@ -207,9 +209,9 @@ def gen_task_by_para(benchmark, bmk_para:dict={}, Partitioner=None, par_para:dic
>>> import flgo.benchmark.mnist_classification as mnist
>>> from flgo.benchmark.toolkits.partition import IIDPartitioner
>>> # GENERATE TASK BY PASSING THE MODULE OF BENCHMARK AND THE CLASS OF THE PARTITIOENR
>>> flgo.gen_task_by_para(benchmark=mnist, Partitioner = IIDPartitioner, par_para={'num_clients':100}, task_path='./mnist_gen_by_para1')
>>> flgo.gen_task_from_para(benchmark=mnist, Partitioner = IIDPartitioner, par_para={'num_clients':100}, task_path='./mnist_gen_by_para1')
>>> # GENERATE THE SAME TASK BY PASSING THE STRING
>>> flgo.gen_task_by_para(benchmark='flgo.benchmark.mnist_classification', Partitioner='IIDPartitioner', par_para={'num_clients':100}, task_path='./mnist_gen_by_para2')
>>> flgo.gen_task_from_para(benchmark='flgo.benchmark.mnist_classification', Partitioner='IIDPartitioner', par_para={'num_clients':100}, task_path='./mnist_gen_by_para2')
```
"""
random.seed(3 + seed)
Expand Down

0 comments on commit ea352ff

Please sign in to comment.