From 8aae6bdc73498f5b278b5fb5cd073bc3820dd541 Mon Sep 17 00:00:00 2001
From: Sri Harsha
Date: Fri, 12 Jul 2024 16:43:52 -0700
Subject: [PATCH 1/3] separated data loader classes

---
 src/data_loaders/cifar.py     |  65 ++++++
 src/data_loaders/domainnet.py | 101 +++++++++
 src/data_loaders/medmnist.py  | 106 +++++++++
 src/data_loaders/mnist.py     |  25 +++
 src/data_loaders/wilds.py     |  71 ++++++
 src/utils/data_utils.py       | 403 +++------------------------------
 6 files changed, 400 insertions(+), 371 deletions(-)
 create mode 100644 src/data_loaders/cifar.py
 create mode 100644 src/data_loaders/domainnet.py
 create mode 100644 src/data_loaders/medmnist.py
 create mode 100644 src/data_loaders/mnist.py
 create mode 100644 src/data_loaders/wilds.py

diff --git a/src/data_loaders/cifar.py b/src/data_loaders/cifar.py
new file mode 100644
index 0000000..ccd8f61
--- /dev/null
+++ b/src/data_loaders/cifar.py
@@ -0,0 +1,65 @@
+import numpy as np
+import torch
+import torchvision.transforms as T
+from torchvision.datasets import CIFAR10
+
+
+class CIFAR10Dataset:
+    """
+    CIFAR-10 Dataset Class.
+    """
+    def __init__(self, dpath: str, rot_angle: int = 0) -> None:
+        self.image_size = 32
+        self.NUM_CLS = 10
+        self.mean = np.array((0.4914, 0.4822, 0.4465))
+        self.std = np.array((0.2023, 0.1994, 0.2010))
+        self.num_channels = 3
+
+        self.train_transform = T.Compose([
+            T.RandomCrop(32, padding=4),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+        self.test_transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(self.mean, self.std),
+        ])
+
+        if rot_angle != 0:
+            self.train_transform.transforms.insert(1, T.RandomVerticalFlip())
+            self.train_transform.transforms.append(
+                T.Lambda(lambda img: T.functional.rotate(img, rot_angle))
+            )
+            self.test_transform.transforms.append(
+                T.Lambda(lambda img: T.functional.rotate(img, rot_angle))
+            )
+
+        self.train_dset = CIFAR10(root=dpath, train=True, download=True, transform=self.train_transform)
+        self.test_dset = CIFAR10(root=dpath, train=False, download=True, transform=self.test_transform)
+        self.image_bound_l = torch.tensor((-self.mean / self.std).reshape(1, -1, 1, 1)).float()
+        self.image_bound_u = torch.tensor(((1 - self.mean) / self.std).reshape(1, -1, 1, 1)).float()
+
+
+class CIFAR10R90Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 90 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
+        super().__init__(dpath, rot_angle=90)
+
+
+class CIFAR10R180Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 180 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
+        super().__init__(dpath, rot_angle=180)
+
+
+class CIFAR10R270Dataset(CIFAR10Dataset):
+    """
+    CIFAR-10 Dataset Class with 270 degrees rotation.
+    """
+    def __init__(self, dpath: str) -> None:
+        super().__init__(dpath, rot_angle=270)
diff --git a/src/data_loaders/domainnet.py b/src/data_loaders/domainnet.py
new file mode 100644
index 0000000..e9d0a87
--- /dev/null
+++ b/src/data_loaders/domainnet.py
@@ -0,0 +1,101 @@
+import os
+import numpy as np
+from PIL import Image
+import torchvision.transforms as T
+
+
+def read_domainnet_data(dataset_path: str, domain_name: str, split: str = "train", labels_to_keep=None):
+    """
+    Reads DomainNet data.
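+
+    Each line of the split file "splits/<domain>_<split>.txt" is expected
+    to look like "clipart/teapot/img_001.jpg 42" (a hypothetical example):
+    an image path relative to dataset_path, a space, and an integer label.
+    The class name is the second path component; when labels_to_keep is
+    given, only those classes are kept and each label is re-indexed to the
+    class name's position in that list.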
+ """ + data_paths = [] + data_labels = [] + split_file = os.path.join(dataset_path, "splits", f"{domain_name}_{split}.txt") + + with open(split_file, "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + data_path, label = line.split(" ") + label_name = data_path.split("/")[1] + if labels_to_keep is None or label_name in labels_to_keep: + data_path = os.path.join(dataset_path, data_path) + if labels_to_keep is not None: + label = labels_to_keep.index(label_name) + else: + label = int(label) + data_paths.append(data_path) + data_labels.append(label) + + return data_paths, data_labels + + +class DomainNet: + """ + DomainNet Dataset Class. + """ + def __init__(self, data_paths, data_labels, transforms, domain_name, cache=False): + self.data_paths = data_paths + self.data_labels = data_labels + self.transforms = transforms + self.domain_name = domain_name + self.cached_data = [] + + if cache: + for idx, _ in enumerate(data_paths): + self.cached_data.append(self.__read_data__(idx)) + + def __read_data__(self, index): + img = Image.open(self.data_paths[index]) + if img.mode != "RGB": + img = img.convert("RGB") + label = self.data_labels[index] + img = T.ToTensor()(img) + return img, label + + def __getitem__(self, index): + if self.cached_data: + img, label = self.cached_data[index] + else: + img, label = self.__read_data__(index) + img = self.transforms(img) + return img, label + + def __len__(self): + return len(self.data_paths) + + +class DomainNetDataset: + """ + DomainNet Dataset Class. + """ + def __init__(self, dpath: str, domain_name: str) -> None: + self.image_size = 32 + self.crop_scale = 0.75 + self.image_resize = int(np.ceil(self.image_size / self.crop_scale)) + + labels_to_keep = [ + "suitcase", "teapot", "pillow", "streetlight", "table", + "bathtub", "wine_glass", "vase", "umbrella", "bench" + ] + self.num_cls = len(labels_to_keep) + self.num_channels = 3 + + train_transform = T.Compose([ + T.Resize((self.image_resize, self.image_resize), antialias=True), + ]) + test_transform = T.Compose([ + T.Resize((self.image_size, self.image_size), antialias=True), + ]) + train_data_paths, train_data_labels = read_domainnet_data( + dpath, domain_name, split="train", labels_to_keep=labels_to_keep + ) + test_data_paths, test_data_labels = read_domainnet_data( + dpath, domain_name, split="test", labels_to_keep=labels_to_keep + ) + self.train_dset = DomainNet( + train_data_paths, train_data_labels, train_transform, domain_name + ) + self.test_dset = DomainNet( + test_data_paths, test_data_labels, test_transform, domain_name + ) diff --git a/src/data_loaders/medmnist.py b/src/data_loaders/medmnist.py new file mode 100644 index 0000000..0f0d18f --- /dev/null +++ b/src/data_loaders/medmnist.py @@ -0,0 +1,106 @@ +import os +import medmnist +import numpy as np +import torchvision.transforms as T + + +class MEDMNISTDataset: + """ + MEDMNIST Dataset Class. 
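+
+    Base class for the MedMNIST collections: data_flag must be a key of
+    medmnist.INFO (e.g. "pathmnist"), from which the concrete dataset
+    class and its channel count are looked up. The target transform
+    unwraps each label from its 1-element array.
+
+    A minimal usage sketch (the path is hypothetical):
+
+        dset = PathMNISTDataset("./datasets/medmnist")
+        img, label = dset.train_dset[0]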
+ """ + def __init__(self, dpath: str, data_flag: str) -> None: + self.mean = np.array([0.5]) + self.std = np.array([0.5]) + info = medmnist.INFO[data_flag] + self.num_channels = info["n_channels"] + self.data_class = getattr(medmnist, info["python_class"]) + + self.transform = T.Compose([ + T.ToTensor(), + T.Normalize(self.mean, self.std) + ]) + + if not os.path.exists(dpath): + os.makedirs(dpath) + + def target_transform(x): + return x[0] + + self.train_dset = self.data_class( + root=dpath, split="train", transform=self.transform, + target_transform=target_transform, download=True + ) + self.test_dset = self.data_class( + root=dpath, split="test", transform=self.transform, + target_transform=target_transform, download=True + ) + + +class PathMNISTDataset(MEDMNISTDataset): + """ + PathMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "pathmnist") + self.image_size = 28 + self.num_cls = 9 + + +class DermaMNISTDataset(MEDMNISTDataset): + """ + DermaMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "dermamnist") + self.image_size = 28 + self.num_cls = 7 + + +class BloodMNISTDataset(MEDMNISTDataset): + """ + BloodMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "bloodmnist") + self.image_size = 28 + self.num_cls = 8 + + +class TissueMNISTDataset(MEDMNISTDataset): + """ + TissueMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "tissuemnist") + self.image_size = 28 + self.num_cls = 8 + + +class OrganAMNISTDataset(MEDMNISTDataset): + """ + OrganAMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "organamnist") + self.image_size = 28 + self.num_cls = 11 + + +class OrganCMNISTDataset(MEDMNISTDataset): + """ + OrganCMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "organcmnist") + self.image_size = 28 + self.num_cls = 11 + + +class OrganSMNISTDataset(MEDMNISTDataset): + """ + OrganSMNIST Dataset Class. + """ + def __init__(self, dpath: str) -> None: + super().__init__(dpath, "organsmnist") + self.image_size = 28 + self.num_cls = 11 diff --git a/src/data_loaders/mnist.py b/src/data_loaders/mnist.py new file mode 100644 index 0000000..082f11d --- /dev/null +++ b/src/data_loaders/mnist.py @@ -0,0 +1,25 @@ +import torchvision.transforms as T +from torchvision.datasets import MNIST + + +class MNISTDataset: + """ + MNIST Dataset Class. 
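+
+    Uses the standard MNIST normalization constants (mean 0.1307,
+    std 0.3081) and downloads the dataset into dpath on first use.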
+ """ + def __init__(self, dpath: str) -> None: + self.image_size = 28 + self.num_cls = 10 + self.mean = 0.1307 + self.std = 0.3081 + self.num_channels = 1 + + self.train_transform = T.Compose([ + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) + self.test_transform = T.Compose([ + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) + self.train_dset = MNIST(root=dpath, train=True, download=True, transform=self.train_transform) + self.test_dset = MNIST(root=dpath, train=False, download=True, transform=self.test_transform) diff --git a/src/data_loaders/wilds.py b/src/data_loaders/wilds.py new file mode 100644 index 0000000..ab470fc --- /dev/null +++ b/src/data_loaders/wilds.py @@ -0,0 +1,71 @@ +import numpy as np +import wilds +import torchvision.transforms as T +from wilds.datasets.wilds_dataset import WILDSSubset +from utils.data_utils import CacheDataset + + +WILDS_DOMAINS_DICT = { + "iwildcam": "location", + "camelyon17": "hospital", + "rxrx1": "experiment", + "fmow": "region", +} + + +class WildsDset: + """ + WILDS Dataset Class. + """ + def __init__(self, dset, transform=None): + self.dset = dset + self.transform = transform + self.targets = [t.item() for t in list(dset.y_array)] + + def __getitem__(self, index): + img, label, _ = self.dset[index] + if self.transform is not None: + img = self.transform(img) + return img, label.item() + + def __len__(self): + return len(self.dset) + + +class WildsDataset: + """ + WILDS Dataset Class. + """ + def __init__(self, dset_name: str, dpath: str, domain: int) -> None: + dset = wilds.get_dataset(dset_name, download=False, root_dir=dpath) + self.num_cls = len(list(np.unique(dset.y_array))) + + domain_key = WILDS_DOMAINS_DICT[dset_name] + (idx,) = np.where( + (dset.metadata_array[:, dset.metadata_fields.index(domain_key)].numpy() == domain) & + (dset.split_array == 0) + ) + + self.mean = np.array((0.4914, 0.4822, 0.4465)) + self.std = np.array((0.2023, 0.1994, 0.2010)) + self.num_channels = 3 + + train_transform = T.Compose([ + T.RandomResizedCrop(32), + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) + test_transform = T.Compose([ + T.Resize(32), + T.ToTensor(), + T.Normalize(self.mean, self.std), + ]) + + num_samples_domain = len(idx) + train_samples = int(num_samples_domain * 0.8) + idx = np.random.permutation(idx) + train_dset = WILDSSubset(dset, idx[:train_samples], transform=None) + test_dset = WILDSSubset(dset, idx[train_samples:], transform=None) + self.train_dset = WildsDset(train_dset, transform=train_transform) + self.test_dset = CacheDataset(WildsDset(test_dset, transform=test_transform)) diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py index b167ae8..1ef7aab 100644 --- a/src/utils/data_utils.py +++ b/src/utils/data_utils.py @@ -1,200 +1,10 @@ -import os +import importlib import numpy as np import torch import torchvision.transforms as T from torch.utils.data import Subset -from torchvision.datasets import CIFAR10, MNIST -from PIL import Image -import medmnist -import wilds -from wilds.datasets.wilds_dataset import WILDSSubset -class CIFAR10Dataset: - """ - CIFAR-10 Dataset Class. 
- """ - def __init__(self, dpath: str, rot_angle: int = 0) -> None: - self.image_size = 32 - self.NUM_CLS = 10 - self.mean = np.array((0.4914, 0.4822, 0.4465)) - self.std = np.array((0.2023, 0.1994, 0.2010)) - self.num_channels = 3 - - self.train_transform = T.Compose([ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - self.test_transform = T.Compose([ - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - - if rot_angle != 0: - self.train_transform.transforms.insert(1, T.RandomVerticalFlip()) - self.train_transform.transforms.append( - T.Lambda(lambda img: T.functional.rotate(img, rot_angle)) - ) - self.test_transform.transforms.append( - T.Lambda(lambda img: T.functional.rotate(img, rot_angle)) - ) - - self.train_dset = CIFAR10(root=dpath, train=True, download=True, transform=self.train_transform) - self.test_dset = CIFAR10(root=dpath, train=False, download=True, transform=self.test_transform) - self.image_bound_l = torch.tensor((-self.mean / self.std).reshape(1, -1, 1, 1)).float() - self.image_bound_u = torch.tensor(((1 - self.mean) / self.std).reshape(1, -1, 1, 1)).float() - - -class CIFAR10R90Dataset(CIFAR10Dataset): - """ - CIFAR-10 Dataset Class with 90 degrees rotation. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, rot_angle=90) - - -class CIFAR10R180Dataset(CIFAR10Dataset): - """ - CIFAR-10 Dataset Class with 180 degrees rotation. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, rot_angle=180) - - -class CIFAR10R270Dataset(CIFAR10Dataset): - """ - CIFAR-10 Dataset Class with 270 degrees rotation. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, rot_angle=270) - - -class MNISTDataset: - """ - MNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - self.image_size = 28 - self.num_cls = 10 - self.mean = 0.1307 - self.std = 0.3081 - self.num_channels = 1 - - self.train_transform = T.Compose([ - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - self.test_transform = T.Compose([ - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - self.train_dset = MNIST(root=dpath, train=True, download=True, transform=self.train_transform) - self.test_dset = MNIST(root=dpath, train=False, download=True, transform=self.test_transform) - - -class MEDMNISTDataset: - """ - MEDMNIST Dataset Class. - """ - def __init__(self, dpath: str, data_flag: str) -> None: - self.mean = np.array([0.5]) - self.std = np.array([0.5]) - info = medmnist.INFO[data_flag] - self.num_channels = info["n_channels"] - self.data_class = getattr(medmnist, info["python_class"]) - - self.transform = T.Compose([ - T.ToTensor(), - T.Normalize(self.mean, self.std) - ]) - - if not os.path.exists(dpath): - os.makedirs(dpath) - - def target_transform(x): - return x[0] - - self.train_dset = self.data_class( - root=dpath, split="train", transform=self.transform, - target_transform=target_transform, download=True - ) - self.test_dset = self.data_class( - root=dpath, split="test", transform=self.transform, - target_transform=target_transform, download=True - ) - - -class PathMNISTDataset(MEDMNISTDataset): - """ - PathMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "pathmnist") - self.image_size = 28 - self.num_cls = 9 - - -class DermaMNISTDataset(MEDMNISTDataset): - """ - DermaMNIST Dataset Class. 
- """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "dermamnist") - self.image_size = 28 - self.num_cls = 7 - - -class BloodMNISTDataset(MEDMNISTDataset): - """ - BloodMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "bloodmnist") - self.image_size = 28 - self.num_cls = 8 - - -class TissueMNISTDataset(MEDMNISTDataset): - """ - TissueMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "tissuemnist") - self.image_size = 28 - self.num_cls = 8 - - -class OrganAMNISTDataset(MEDMNISTDataset): - """ - OrganAMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "organamnist") - self.image_size = 28 - self.num_cls = 11 - - -class OrganCMNISTDataset(MEDMNISTDataset): - """ - OrganCMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "organcmnist") - self.image_size = 28 - self.num_cls = 11 - - -class OrganSMNISTDataset(MEDMNISTDataset): - """ - OrganSMNIST Dataset Class. - """ - def __init__(self, dpath: str) -> None: - super().__init__(dpath, "organsmnist") - self.image_size = 28 - self.num_cls = 11 - class CacheDataset: """ @@ -230,197 +40,48 @@ def __len__(self): return len(self.dset) -def read_domainnet_data(dataset_path: str, domain_name: str, split: str = "train", labels_to_keep=None): - """ - Reads DomainNet data. - """ - data_paths = [] - data_labels = [] - split_file = os.path.join(dataset_path, "splits", f"{domain_name}_{split}.txt") - - with open(split_file, "r", encoding="utf-8") as f: - lines = f.readlines() - for line in lines: - line = line.strip() - data_path, label = line.split(" ") - label_name = data_path.split("/")[1] - if labels_to_keep is None or label_name in labels_to_keep: - data_path = os.path.join(dataset_path, data_path) - if labels_to_keep is not None: - label = labels_to_keep.index(label_name) - else: - label = int(label) - data_paths.append(data_path) - data_labels.append(label) - - return data_paths, data_labels - - -class DomainNet: - """ - DomainNet Dataset Class. - """ - def __init__(self, data_paths, data_labels, transforms, domain_name, cache=False): - self.data_paths = data_paths - self.data_labels = data_labels - self.transforms = transforms - self.domain_name = domain_name - self.cached_data = [] - - if cache: - for idx, _ in enumerate(data_paths): - self.cached_data.append(self.__read_data__(idx)) - - def __read_data__(self, index): - img = Image.open(self.data_paths[index]) - if img.mode != "RGB": - img = img.convert("RGB") - label = self.data_labels[index] - img = T.ToTensor()(img) - return img, label - - def __getitem__(self, index): - if self.cached_data: - img, label = self.cached_data[index] - else: - img, label = self.__read_data__(index) - img = self.transforms(img) - return img, label - - def __len__(self): - return len(self.data_paths) - - -class DomainNetDataset: - """ - DomainNet Dataset Class. 
- """ - def __init__(self, dpath: str, domain_name: str) -> None: - self.image_size = 32 - self.crop_scale = 0.75 - self.image_resize = int(np.ceil(self.image_size / self.crop_scale)) - - labels_to_keep = [ - "suitcase", "teapot", "pillow", "streetlight", "table", - "bathtub", "wine_glass", "vase", "umbrella", "bench" - ] - self.num_cls = len(labels_to_keep) - self.num_channels = 3 - - train_transform = T.Compose([ - T.Resize((self.image_resize, self.image_resize), antialias=True), - ]) - test_transform = T.Compose([ - T.Resize((self.image_size, self.image_size), antialias=True), - ]) - train_data_paths, train_data_labels = read_domainnet_data( - dpath, domain_name, split="train", labels_to_keep=labels_to_keep - ) - test_data_paths, test_data_labels = read_domainnet_data( - dpath, domain_name, split="test", labels_to_keep=labels_to_keep - ) - self.train_dset = DomainNet( - train_data_paths, train_data_labels, train_transform, domain_name - ) - self.test_dset = DomainNet( - test_data_paths, test_data_labels, test_transform, domain_name - ) - - -WILDS_DOMAINS_DICT = { - "iwildcam": "location", - "camelyon17": "hospital", - "rxrx1": "experiment", - "fmow": "region", -} - - -class WildsDset: - """ - WILDS Dataset Class. - """ - def __init__(self, dset, transform=None): - self.dset = dset - self.transform = transform - self.targets = [t.item() for t in list(dset.y_array)] - - def __getitem__(self, index): - img, label, _ = self.dset[index] - if self.transform is not None: - img = self.transform(img) - return img, label.item() - - def __len__(self): - return len(self.dset) - - -class WildsDataset: - """ - WILDS Dataset Class. - """ - def __init__(self, dset_name: str, dpath: str, domain: int) -> None: - dset = wilds.get_dataset(dset_name, download=False, root_dir=dpath) - self.num_cls = len(list(np.unique(dset.y_array))) - - domain_key = WILDS_DOMAINS_DICT[dset_name] - (idx,) = np.where( - (dset.metadata_array[:, dset.metadata_fields.index(domain_key)].numpy() == domain) & - (dset.split_array == 0) - ) - - self.mean = np.array((0.4914, 0.4822, 0.4465)) - self.std = np.array((0.2023, 0.1994, 0.2010)) - self.num_channels = 3 - - train_transform = T.Compose([ - T.RandomResizedCrop(32), - T.RandomHorizontalFlip(), - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - test_transform = T.Compose([ - T.Resize(32), - T.ToTensor(), - T.Normalize(self.mean, self.std), - ]) - - num_samples_domain = len(idx) - train_samples = int(num_samples_domain * 0.8) - idx = np.random.permutation(idx) - train_dset = WILDSSubset(dset, idx[:train_samples], transform=None) - test_dset = WILDSSubset(dset, idx[train_samples:], transform=None) - self.train_dset = WildsDset(train_dset, transform=train_transform) - self.test_dset = CacheDataset(WildsDset(test_dset, transform=test_transform)) - - def get_dataset(dname: str, dpath: str): """ Returns the appropriate dataset class based on the dataset name. 
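+
+    Dataset classes are resolved lazily: each dset_mapping entry is a
+    (module_path, class_name) pair that is imported with importlib only
+    when the corresponding dataset is requested, so optional dependencies
+    such as wilds or medmnist are not imported unless actually needed.
+    Names of the form "wilds_<dset>_<domain>" and "domainnet_<domain>"
+    are parsed to pass the sub-dataset and domain to the underlying class.
+
+    A usage sketch (the path is hypothetical):
+
+        dset = get_dataset("cifar10_r90", "./datasets")
+        train, test = dset.train_dset, dset.test_dset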
""" dset_mapping = { - "cifar10": CIFAR10Dataset, - "cifar10_r0": CIFAR10Dataset, - "cifar10_r90": CIFAR10R90Dataset, - "cifar10_r180": CIFAR10R180Dataset, - "cifar10_r270": CIFAR10R270Dataset, - "mnist": MNISTDataset, - "pathmnist": PathMNISTDataset, - "dermamnist": DermaMNISTDataset, - "bloodmnist": BloodMNISTDataset, - "tissuemnist": TissueMNISTDataset, - "organamnist": OrganAMNISTDataset, - "organcmnist": OrganCMNISTDataset, - "organsmnist": OrganSMNISTDataset, + "cifar10": ("data_loaders.cifar", "CIFAR10Dataset"), + "cifar10_r0": ("data_loaders.cifar", "CIFAR10Dataset"), + "cifar10_r90": ("data_loaders.cifar", "CIFAR10R90Dataset"), + "cifar10_r180": ("data_loaders.cifar", "CIFAR10R180Dataset"), + "cifar10_r270": ("data_loaders.cifar", "CIFAR10R270Dataset"), + "mnist": ("data_loaders.mnist", "MNISTDataset"), + "pathmnist": ("data_loaders.medmnist", "PathMNISTDataset"), + "dermamnist": ("data_loaders.medmnist", "DermaMNISTDataset"), + "bloodmnist": ("data_loaders.medmnist", "BloodMNISTDataset"), + "tissuemnist": ("data_loaders.medmnist", "TissueMNISTDataset"), + "organamnist": ("data_loaders.medmnist", "OrganAMNISTDataset"), + "organcmnist": ("data_loaders.medmnist", "OrganCMNISTDataset"), + "organsmnist": ("data_loaders.medmnist", "OrganSMNISTDataset"), + "domainnet": ("data_loaders.domainnet", "DomainNetDataset"), + "wilds": ("data_loaders.wilds", "WildsDataset"), } + if dname not in dset_mapping: + raise ValueError(f"Unknown dataset name: {dname}") + if dname.startswith("wilds"): dname_parts = dname.split("_") - return WildsDataset(dname_parts[1], dpath, int(dname_parts[2])) + module_path, class_name = dset_mapping["wilds"] + module = importlib.import_module(module_path) + dataset_class = getattr(module, class_name) + return dataset_class(dname_parts[1], dpath, int(dname_parts[2])) elif dname.startswith("domainnet"): dname_parts = dname.split("_") - return DomainNetDataset(dpath, dname_parts[1]) + module_path, class_name = dset_mapping["domainnet"] + module = importlib.import_module(module_path) + dataset_class = getattr(module, class_name) + return dataset_class(dpath, dname_parts[1]) else: - return dset_mapping[dname](dpath) + module_path, class_name = dset_mapping[dname] + module = importlib.import_module(module_path) + dataset_class = getattr(module, class_name) + return dataset_class(dpath) def filter_by_class(dataset, classes): @@ -448,7 +109,7 @@ def extr_noniid(train_dataset, samples_per_client, classes): return Subset(all_data, perm[:samples_per_client]) -def cifar_extr_noniid( +def cifar_extr_noniid( train_dataset, test_dataset, num_users, n_class, num_samples, rate_unbalance ): """ From 7276723e8a1605da18cde45002980e0b87a392bc Mon Sep 17 00:00:00 2001 From: Sri Harsha Date: Wed, 17 Jul 2024 12:38:41 -0700 Subject: [PATCH 2/3] Added dockerfile; updated requirements; updated readme --- Dockerfile | 63 ++++++++++++++++++++++++++++++++++++++++++++++++ docker_run.sh | 1 + requirements.txt | 25 +++++++++---------- src/README.md | 42 ++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 Dockerfile create mode 100644 docker_run.sh diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0762d31 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,63 @@ +# Use the latest LTS (Long-Term Support) version of Ubuntu +FROM ubuntu:latest + +# Set working directory to /sonar +WORKDIR /sonar + +RUN apt-get update && \ + apt-get install -y software-properties-common && \ + add-apt-repository ppa:deadsnakes/ppa + +# Install base 
utilities +#RUN apt-get install -y \ +# build-essential \ +# wget \ +# python3.11 \ +# python3-pip \ +# && apt-get clean \ +# && rm -rf /var/lib/apt/lists/* + +# Copy requirements.txt to the working directory +#COPY requirements.txt . + +# Install Python development tools (needed for creating virtual environments) +#RUN apt-get update && apt-get install -y python3-venv + +#RUN apt-get update && apt-get install -y openmpi-bin + +#RUN apt-get update && apt-get install -y libopenmpi-dev + +# Create a virtual environment +#RUN python3 -m venv env + +# Activate the virtual environment (source this command in subsequent RUN steps) +#RUN . env/bin/activate + +# Install dependencies from requirements.txt within the virtual environment +#RUN pip3 install -r requirements.txt +#Install base utilities and Python with required dependencies +RUN apt-get update && \ + apt-get install -y software-properties-common && \ + add-apt-repository ppa:deadsnakes/ppa && \ + apt-get update && \ + apt-get install -y \ + build-essential \ + wget \ + python3.11 \ + python3-pip \ + python3-venv \ + openmpi-bin \ + libopenmpi-dev \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements.txt to the working directory +COPY requirements.txt . + +# Create a virtual environment and install dependencies +RUN python3 -m venv env && \ + . env/bin/activate && \ + env/bin/pip install -r requirements.txt + +# To activate the virtual environment by default, use ENTRYPOINT +ENTRYPOINT ["/bin/bash", "-c", "source /sonar/env/bin/activate && exec \"$@\"", "--"] \ No newline at end of file diff --git a/docker_run.sh b/docker_run.sh new file mode 100644 index 0000000..0a04409 --- /dev/null +++ b/docker_run.sh @@ -0,0 +1 @@ +docker run -it --rm -v $(pwd)/src:/sonar/src sonar_image /bin/bash -c "cd /sonar/src && mpirun --allow-run-as-root -np 4 -host localhost:11 ../env/bin/python main.py" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7ed2977..b4866e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,18 +23,18 @@ mpi4py==3.1.6 mpmath==1.3.0 networkx==3.3 numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.5.40 -nvidia-nvtx-cu12==12.1.105 +nvidia-cublas-cu12 +nvidia-cuda-cupti-cu12>=12.1.105 +nvidia-cuda-nvrtc-cu12>=12.1.105 +nvidia-cuda-runtime-cu12>=12.1.105 +nvidia-cudnn-cu12>=8.9.2.26 +nvidia-cufft-cu12>=11.0.2.54 +nvidia-curand-cu12>=10.3.2.106 +nvidia-cusolver-cu12>=11.4.5.107 +nvidia-cusparse-cu12>=12.1.0.106 +nvidia-nccl-cu12>=2.20.5 +nvidia-nvjitlink-cu12>=12.5.40 +nvidia-nvtx-cu12>=12.1.105 ogb==1.3.6 outdated==0.2.2 packaging==24.0 @@ -57,7 +57,6 @@ tifffile==2024.5.22 torch==2.3.0 torchvision==0.18.0 tqdm==4.66.4 -triton==2.3.0 typing_extensions==4.12.0 tzdata==2024.1 urllib3==2.2.1 diff --git a/src/README.md b/src/README.md index 29f6ad0..5e1055e 100644 --- a/src/README.md +++ b/src/README.md @@ -1,3 +1,45 @@ +## Setting Up the Project with Docker + +### Prerequisites + +Ensure you have Docker installed on your machine. You can download it from [Docker's official website](https://www.docker.com/get-started). + +### Building the Docker Image + +1. Clone the repository: + ```sh + git clone https://github.com/aidecentralized/sonar + cd sonar + ``` + +2. 
Build the Docker image:
+   ```sh
+   docker build -t sonar_image:latest .
+   ```

### Running the Container
+
+We provide a `docker_run.sh` script to simplify running the Docker container.
+
+1. Ensure the script has execution permissions:
+   ```sh
+   chmod +x docker_run.sh
+   ```
+
+2. Run the script:
+- Using `./`:
+   ```sh
+   ./docker_run.sh
+   ```
+- Using `bash`:
+   ```sh
+   bash docker_run.sh
+   ```
+
+The `docker_run.sh` script handles the necessary Docker commands to start the container with the appropriate settings.
+
+Note: if you used a different image name when building the Docker image, update the name accordingly in the `docker_run.sh` script.
+
 ### Running the code
 
 Let's say you want to run model training with 3 nodes on a machine. That means there will be 4 nodes in total, because there is one more node in addition to the clients: the server. The whole point of this project is to eventually transition to a distributed system where each node can be a separate machine and the server is simply another node. But for now, this is how things are done.

From 9d5d4f5381d7d4988f0ecd10703a67d62ce2002c Mon Sep 17 00:00:00 2001
From: Sri Harsha
Date: Thu, 18 Jul 2024 10:17:35 -0700
Subject: [PATCH 3/3] cleaned Dockerfile

---
 Dockerfile | 31 -------------------------------
 1 file changed, 31 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 0762d31..025ba57 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,37 +4,6 @@ FROM ubuntu:latest
 # Set working directory to /sonar
 WORKDIR /sonar
 
-RUN apt-get update && \
-    apt-get install -y software-properties-common && \
-    add-apt-repository ppa:deadsnakes/ppa
-
-# Install base utilities
-#RUN apt-get install -y \
-#    build-essential \
-#    wget \
-#    python3.11 \
-#    python3-pip \
-#    && apt-get clean \
-#    && rm -rf /var/lib/apt/lists/*
-
-# Copy requirements.txt to the working directory
-#COPY requirements.txt .
-
-# Install Python development tools (needed for creating virtual environments)
-#RUN apt-get update && apt-get install -y python3-venv
-
-#RUN apt-get update && apt-get install -y openmpi-bin
-
-#RUN apt-get update && apt-get install -y libopenmpi-dev
-
-# Create a virtual environment
-#RUN python3 -m venv env
-
-# Activate the virtual environment (source this command in subsequent RUN steps)
-#RUN . env/bin/activate
-
-# Install dependencies from requirements.txt within the virtual environment
-#RUN pip3 install -r requirements.txt
 #Install base utilities and Python with required dependencies
 RUN apt-get update && \
     apt-get install -y software-properties-common && \