learning-to-collaborate

The implementation of Learning to Collaborate.


Description of all files

  1. main.py: the main entry point for all experiments;
  2. hyper_model/models.py: defines all necessary model structures;
  3. solvers: implements two optimization schemes for learning the whole Pareto front, linear scalarization and EPO (see the sketch after this list);
  4. train.py: defines the training function for all experiments (including evaluation, model saving, model loading, etc.);
  5. utils/utils_data.py: pre-processes all datasets used in the subsequent training and evaluation;
  6. utils/utils_func.py: auxiliary helper functions;
  7. utils/utils_sampling.py: generates non-i.i.d. data for all experiments.
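
For context, the linear solver amounts to scalarizing the per-client losses along a direction vector $d$. Below is a minimal PyTorch sketch of that idea; the function and variable names are illustrative, not the repository's actual API.

import torch

def linear_scalarization(losses, ray):
    """Combine per-client losses into one objective along direction d.

    losses: 1-D tensor of K client losses
    ray:    1-D tensor of K non-negative weights summing to 1
    """
    return torch.sum(ray * losses)

# Example: three clients, preference tilted toward client 0.
losses = torch.tensor([0.9, 0.4, 0.7], requires_grad=True)
ray = torch.tensor([0.6, 0.2, 0.2])
linear_scalarization(losses, ray).backward()
print(losses.grad)  # each client's gradient is scaled by its weight in d

The EPO solver, by contrast, steers updates toward the exact Pareto-optimal point matched to $d$; it is not sketched here.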

Preparations

Construct Conda Environment

Python 3.6 is required; the needed libraries are listed in requirements.txt.

conda create -n E8519 python=3.6
conda activate E8519
conda install --yes --file requirements.txt

Datasets

  1. Synthetic: the source code generates the needed data when the synthetic experiments are run.
  2. Adult: the Adult dataset is pre-processed following the work in [1, 2]; the processed dataset is in data/Adult.
  3. CIFAR10: CIFAR10 is downloaded automatically the first time the CIFAR10 experiments are run, and the downloaded data is saved in data/CIFAR10 (see the sketch after this list).
  4. eICU: access to the eICU dataset requires approval.
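
For reference, the automatic CIFAR10 download behaves like the standard torchvision loader sketched below; the transform and exact options are assumptions, not the repository's loading code.

from torchvision import datasets, transforms

# The first run downloads CIFAR10 into data/CIFAR10; later runs reuse the cached copy.
transform = transforms.ToTensor()
train_set = datasets.CIFAR10(root="data/CIFAR10", train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root="data/CIFAR10", train=False, download=True, transform=transform)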

Get Started

Description of the main.py parameters for running all experiments

  1. dataset: the dataset used for the experiment;
  2. trainN: the number of synthetic data samples generated for training;
  3. std: the standard deviation $\rho$ used when generating synthetic data;
  4. sample_ray: whether to sample direction vectors $d$ while training the Pareto front of all objectives;
  5. target_usr: the target client on which the learned personalized model will be deployed;
  6. total_hnet_epoch: the number of epochs for training the Pareto front;
  7. total_ray_epoch: the number of epochs for training the direction vector $d$;
  8. lr: the learning rate for updating the hypernetwork;
  9. lr_prefer: the learning rate for updating the direction vector $d$;
  10. gpus: the GPU device;
  11. n_hidden: the number of hidden layers of the hypernetwork;
  12. num_users: the number of clients in each experiment.
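
A minimal argparse sketch of how these flags could be declared; the defaults shown are illustrative and may differ from those in main.py.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="synthetic")    # dataset to run
parser.add_argument("--trainN", type=int, default=2000)            # synthetic training samples
parser.add_argument("--std", type=float, default=0.1)              # rho for synthetic data
parser.add_argument("--sample_ray", action="store_true")           # sample direction vectors d
parser.add_argument("--target_usr", type=int, default=0)           # client receiving the model
parser.add_argument("--total_hnet_epoch", type=int, default=1000)  # epochs for the Pareto front
parser.add_argument("--total_ray_epoch", type=int, default=200)    # epochs for tuning d
parser.add_argument("--lr", type=float, default=0.01)              # hypernetwork learning rate
parser.add_argument("--lr_prefer", type=float, default=0.01)       # learning rate for d
parser.add_argument("--gpus", type=str, default="0")               # GPU device id
parser.add_argument("--n_hidden", type=int, default=1)             # hypernetwork hidden layers
parser.add_argument("--num_users", type=int, default=10)           # clients per experiment
args = parser.parse_args()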

Example Synthetic Experiment

python main.py --dataset synthetic --trainN 2000 --std 0.1 --sample_ray --target_dir synthetic --target_usr 4 --total_hnet_epoch 1000 --epochs_per_valid 1 --total_ray_epoch 200 --total_epoch 1 --gpus 0 --n_hidden 1 --lr 0.01 --lr_prefer 0.01 --seed 1 --solver_type linear
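
This run sets --sample_ray, so direction vectors $d$ on the probability simplex are drawn repeatedly while training the Pareto front. One common way to draw such rays is from a Dirichlet distribution; the sampler below is an assumed illustration, not necessarily the one used in the code.

import numpy as np

def sample_ray(num_users, alpha=0.2, rng=None):
    # Returns d with non-negative entries summing to 1 (a point on the simplex).
    rng = rng or np.random.default_rng()
    return rng.dirichlet(alpha * np.ones(num_users))

d = sample_ray(num_users=5)
print(d, d.sum())  # five non-negative weights summing to 1

A smaller alpha concentrates rays near single clients; a larger alpha yields more balanced rays.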

Example Adult Experiment

python main.py --dataset adult --target_dir adult --target_usr 0 --total_hnet_epoch 20 --epochs_per_valid 100 --total_ray_epoch 1 --total_epoch 2000 --gpus 0 --n_hidden 3 --lr 0.05 --lr_prefer 0.01 --seed 0 --solver_type epo

Example CIFAR10 Experiment

python main.py --dataset cifar10 --num_users 10 --target_usr 9 --total_hnet_epoch 10000 --total_ray_epoch 1000 --total_epoch 1 --seed 3 --local_bs 512 --lr 0.01 --lr_prefer 0.01 --solver_type linear --sample_ray --n_hidden 3 --embedding_dim 5 --input_dim 20 --output_dim 2 --hidden_dim 100 --gpus 0
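
The CIFAR10 command exposes the hypernetwork's shape flags (--n_hidden, --embedding_dim, --input_dim, --output_dim, --hidden_dim). As general context, a hypernetwork of this kind maps a preference ray to the weights of a target model; the PyTorch sketch below is an assumption about that design, not the code in hyper_model/models.py.

import torch
import torch.nn as nn

class TinyHyperNet(nn.Module):
    # Maps a num_users-dim ray to the weights of a linear target model.
    def __init__(self, num_users, hidden_dim, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.body = nn.Sequential(
            nn.Linear(num_users, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.w_head = nn.Linear(hidden_dim, output_dim * input_dim)
        self.b_head = nn.Linear(hidden_dim, output_dim)

    def forward(self, ray, x):
        h = self.body(ray)
        W = self.w_head(h).view(self.output_dim, self.input_dim)
        return x @ W.t() + self.b_head(h)  # the target model's prediction for this ray

# Shapes mirroring the CIFAR10 flags: 10 clients, 20-dim inputs, 2 outputs.
net = TinyHyperNet(num_users=10, hidden_dim=100, input_dim=20, output_dim=2)
ray = torch.rand(10)
ray = ray / ray.sum()
print(net(ray, torch.randn(4, 20)).shape)  # torch.Size([4, 2])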

Reference

[1] Tian Li, Maziar Sanjabi, Ahmad Beirami, and Virginia Smith. Fair resource allocation in federated learning. arXiv preprint arXiv:1905.10497, 2019.
[2] Mehryar Mohri, Gary Sivek, and Ananda Theertha Suresh. Agnostic federated learning. In International Conference on Machine Learning, pages 4615–4625. PMLR, 2019.
