diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 3d7e86f..706211e 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -1,11 +1,13 @@ name: Test Training Code with gRPC on: - workflow_dispatch: + # used for debugging purposes + # workflow_dispatch: push: branches: - # - main - - "*" + # run test on push to main only + - main + # - "*" pull_request: branches: - main @@ -16,6 +18,8 @@ env: jobs: train-check: runs-on: ubuntu-latest + env: + DEVICE: cpu steps: # Step 1: Checkout the code @@ -35,7 +39,7 @@ jobs: sudo apt install -y libopenmpi-dev openmpi-bin sudo apt-get install -y libgl1 libglib2.0-0 - pip install -r requirements.txt + pip install -r requirements_cpu.txt # Step 4: Run gRPC server and client - name: Run test diff --git a/.gitignore b/.gitignore index e03029f..5124753 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ di_test/ imgs/ pascal/ data/ +!src/inversefed/data/ notes.txt removeme*.png diff --git a/requirements_cpu.txt b/requirements_cpu.txt new file mode 100644 index 0000000..928c7de --- /dev/null +++ b/requirements_cpu.txt @@ -0,0 +1,161 @@ +anyio==4.3.0 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.1 +async-lru==2.0.4 +attrs==23.2.0 +Babel==2.15.0 +beautifulsoup4==4.12.3 +bleach==6.1.0 +certifi==2024.2.2 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +comm==0.2.2 +contourpy==1.2.1 +cycler==0.12.1 +debugpy==1.8.1 +decorator==5.1.1 +defusedxml==0.7.1 +exceptiongroup==1.2.1 +executing==2.0.1 +fastjsonschema==2.19.1 +filelock==3.14.0 +fire==0.6.0 +fonttools==4.52.1 +fqdn==1.5.1 +fsspec==2024.5.0 +ghp-import==2.1.0 +grpcio==1.64.0 +grpcio-tools==1.64.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.0 +idna==3.7 +imageio==2.34.1 +ipykernel==6.29.4 +ipython==8.24.0 +isoduration==20.11.0 +jedi==0.19.1 +Jinja2==3.1.4 +jmespath==1.0.1 +joblib==1.4.2 +json5==0.9.25 +jsonpointer==2.4 +jsonschema==4.22.0 +jsonschema-specifications==2023.12.1 +jupyter-events==0.10.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +jupyter_server==2.14.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.2.5 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.2 +kiwisolver==1.4.5 +lazy_loader==0.4 +littleutils==0.2.2 +Markdown==3.7 +MarkupSafe==2.1.5 +matplotlib==3.9.0 +matplotlib-inline==0.1.7 +medmnist==3.0.1 +mergedeep==1.3.4 +mistune==3.0.2 +mkdocs==1.6.0 +mkdocs-get-deps==0.2.0 +mkdocs-material==9.5.31 +mkdocs-material-extensions==1.3.1 +mpi4py==3.1.6 +mpmath==1.3.0 +nbclient==0.10.0 +nbconvert==7.16.4 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.3 +notebook_shim==0.2.4 +numpy +# nvidia-cublas-cu12==12.1.3.1 +# nvidia-cuda-cupti-cu12==12.1.105 +# nvidia-cuda-nvrtc-cu12==12.1.105 +# nvidia-cuda-runtime-cu12==12.1.105 +# nvidia-cudnn-cu12==8.9.2.26 +# nvidia-cufft-cu12==11.0.2.54 +# nvidia-curand-cu12==10.3.2.106 +# nvidia-cusolver-cu12==11.4.5.107 +# nvidia-cusparse-cu12==12.1.0.106 +# nvidia-nccl-cu12==2.20.5 +# nvidia-nvjitlink-cu12==12.5.40 +# nvidia-nvtx-cu12==12.1.105 +ogb==1.3.6 +opencv-python==4.10.0.84 +outdated==0.2.2 +overrides==7.7.0 +packaging==24.0 +paginate==0.5.6 +pandas==2.2.2 +pandocfilters==1.5.1 +parso==0.8.4 +pathspec==0.12.1 +pexpect==4.9.0 +pillow==10.3.0 +platformdirs==4.2.2 +prometheus_client==0.20.0 +prompt-toolkit==3.0.43 +protobuf==5.26.1 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycparser==2.22 +Pygments==2.18.0 +pymdown-extensions==10.9 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +python-json-logger==2.0.7 +pytz==2024.1 +PyYAML==6.0.1 +pyyaml_env_tag==0.1 +pyzmq==26.0.3 +referencing==0.35.1 +regex==2024.7.24 +requests==2.32.2 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rpds-py==0.18.1 +scikit-image==0.23.2 +scikit-learn==1.5.0 +scipy==1.13.1 +Send2Trash==1.8.3 +six==1.16.0 +sniffio==1.3.1 +soupsieve==2.5 +stack-data==0.6.3 +sympy==1.12 +tensorboardX==2.6.2.2 +termcolor==2.4.0 +terminado==0.18.1 +threadpoolctl==3.5.0 +tifffile==2024.5.22 +tinycss2==1.3.0 +tomli==2.0.1 +torch @ https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.18.0%2Bcpu-cp310-cp310-linux_x86_64.whl +tornado==6.4 +tqdm==4.66.4 +traitlets==5.14.3 +#triton +types-python-dateutil==2.9.0.20240316 +typing_extensions==4.12.0 +tzdata==2024.1 +uri-template==1.3.0 +urllib3==2.2.1 +Wand==0.6.13 +watchdog==4.0.2 +wcwidth==0.2.13 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.8.0 +wilds==2.0.0 diff --git a/src/algos/fl.py b/src/algos/fl.py index 3670cd6..7d5e58a 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -242,7 +242,7 @@ def run_protocol(self): self.round_init() self.local_round_done() - self.single_round() + self.single_round(round) self.test() self.round_finalize() diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index f0f4976..a2242c6 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -31,7 +31,8 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st traditional_fl: ConfigType = { # Collaboration setup "algo": "fedavg", - "rounds": 5, + "rounds": 2, + # Model parameters "model": "resnet10", "model_lr": 3e-4, diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index f357541..4737d16 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -121,6 +121,14 @@ def get_algo_configs( "exp_keys": [], "dropout_dicts": dropout_dicts, "test_samples_per_user": 200, + "log_memory": True, + # "streaming_aggregation": True, # Make it true for fedstatic + "assign_based_on_host": True, + "hostname_to_device_ids": { + "matlaber1": [2, 3, 4, 5, 6, 7], + "matlaber12": [0, 1, 2, 3], + "matlaber3": [0, 1, 2, 3], + "matlaber4": [0, 2, 3, 4, 5, 6, 7], + } } - current_config = grpc_system_config \ No newline at end of file diff --git a/src/inversefed/data/README.md b/src/inversefed/data/README.md new file mode 100644 index 0000000..0cf811e --- /dev/null +++ b/src/inversefed/data/README.md @@ -0,0 +1,3 @@ +# Data Processing + +This module implements ```construct_dataloaders```. \ No newline at end of file diff --git a/src/inversefed/data/__init__.py b/src/inversefed/data/__init__.py index e69de29..be87c57 100644 --- a/src/inversefed/data/__init__.py +++ b/src/inversefed/data/__init__.py @@ -0,0 +1,6 @@ +"""Data stuff that I usually don't want to see.""" + +from .data_processing import construct_dataloaders + + +__all__ = ['construct_dataloaders'] diff --git a/src/inversefed/data/data.py b/src/inversefed/data/data.py index e69de29..d2d7cd6 100644 --- a/src/inversefed/data/data.py +++ b/src/inversefed/data/data.py @@ -0,0 +1,96 @@ +"""This is data.py from pytorch-examples. + +Refer to +https://github.com/pytorch/examples/blob/master/super_resolution/data.py. +""" + +from os.path import exists, join, basename +from os import makedirs, remove +from six.moves import urllib +import tarfile +from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, RandomCrop + + +from .datasets import DatasetFromFolder + +def _build_bsds_sr(data_path, augmentations=True, normalize=True, upscale_factor=3, RGB=True): + root_dir = _download_bsd300(dest=data_path) + train_dir = join(root_dir, "train") + crop_size = _calculate_valid_crop_size(256, upscale_factor) + print(f'Crop size is {crop_size}. Upscaling factor is {upscale_factor} in mode {RGB}.') + + trainset = DatasetFromFolder(train_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), RGB=RGB) + + test_dir = join(root_dir, "test") + validset = DatasetFromFolder(test_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), RGB=RGB) + return trainset, validset + +def _build_bsds_dn(data_path, augmentations=True, normalize=True, upscale_factor=1, noise_level=25 / 255, RGB=True): + root_dir = _download_bsd300(dest=data_path) + train_dir = join(root_dir, "train") + + crop_size = _calculate_valid_crop_size(256, upscale_factor) + patch_size = 64 + print(f'Crop size is {crop_size} for patches of size {patch_size}. ' + f'Upscaling factor is {upscale_factor} in mode RGB={RGB}.') + + trainset = DatasetFromFolder(train_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor, patch_size=patch_size), + target_transform=_target_transform(crop_size, patch_size=patch_size), + noise_level=noise_level, RGB=RGB) + + test_dir = join(root_dir, "test") + validset = DatasetFromFolder(test_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), + noise_level=noise_level, RGB=RGB) + return trainset, validset + + +def _download_bsd300(dest="dataset"): + output_image_dir = join(dest, "BSDS300/images") + + if not exists(output_image_dir): + makedirs(dest, exist_ok=True) + url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" + print("downloading url ", url) + + data = urllib.request.urlopen(url) + + file_path = join(dest, basename(url)) + with open(file_path, 'wb') as f: + f.write(data.read()) + + print("Extracting data") + with tarfile.open(file_path) as tar: + for item in tar: + tar.extract(item, dest) + + remove(file_path) + + return output_image_dir + + +def _calculate_valid_crop_size(crop_size, upscale_factor): + return crop_size - (crop_size % upscale_factor) + + +def _input_transform(crop_size, upscale_factor, patch_size=None): + return Compose([ + CenterCrop(crop_size), + Resize(crop_size // upscale_factor), + RandomCrop(patch_size if patch_size is not None else crop_size // upscale_factor), + ToTensor(), + ]) + + +def _target_transform(crop_size, patch_size=None): + return Compose([ + CenterCrop(crop_size), + RandomCrop(patch_size if patch_size is not None else crop_size), + ToTensor(), + ]) diff --git a/src/inversefed/data/datasets.py b/src/inversefed/data/datasets.py index e69de29..dd28cbf 100644 --- a/src/inversefed/data/datasets.py +++ b/src/inversefed/data/datasets.py @@ -0,0 +1,62 @@ +"""This is dataset.py from pytorch-examples. + +Refer to + +https://github.com/pytorch/examples/blob/master/super_resolution/dataset.py. +""" +import torch +import torch.utils.data as data + +from os import listdir +from os.path import join +from PIL import Image + + +def _is_image_file(filename): + return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) + + +def _load_img(filepath, RGB=True): + img = Image.open(filepath) + if RGB: + pass + else: + img = img.convert('YCbCr') + img, _, _ = img.split() + return img + + +class DatasetFromFolder(data.Dataset): + """Generate an image-to-image dataset from images from the given folder.""" + + def __init__(self, image_dir, replicate=1, input_transform=None, target_transform=None, RGB=True, noise_level=0.0): + """Init with directory, transforms and RGB switch.""" + super(DatasetFromFolder, self).__init__() + self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if _is_image_file(x)] + + self.input_transform = input_transform + self.target_transform = target_transform + + self.replicate = replicate + self.classes = [None] + self.RGB = RGB + self.noise_level = noise_level + + def __getitem__(self, index): + """Index into dataset.""" + input = _load_img(self.image_filenames[index % len(self.image_filenames)], RGB=self.RGB) + target = input.copy() + if self.input_transform: + input = self.input_transform(input) + if self.target_transform: + target = self.target_transform(target) + + if self.noise_level > 0: + # Add noise + input += self.noise_level * torch.randn_like(input) + + return input, target + + def __len__(self): + """Length is amount of files found.""" + return len(self.image_filenames) * self.replicate diff --git a/src/scheduler.py b/src/scheduler.py index 55da449..36aa1ae 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -129,6 +129,7 @@ def initialize(self, copy_souce_code: bool = True) -> None: rank=self.communication.get_rank(), comm_utils=self.communication, ) + self.communication.send_quorum() def run_job(self) -> None: diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index e9b2004..0a9216c 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -153,7 +153,6 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): - print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items