From 2e202c90191447f281adf275ed29403fc1c50015 Mon Sep 17 00:00:00 2001
From: salomaestro
Date: Tue, 4 Feb 2025 20:09:03 +0100
Subject: [PATCH] Update docstrings for documentation (still lots to do)

Added a bit more information to some of the docstrings, but there is
still a lot of room for improvement ;)
---
 utils/dataloaders/datasources.py |  7 +++++
 utils/dataloaders/mnist_0_3.py   |  3 ++
 utils/dataloaders/usps_0_6.py    | 54 +++++++++++++++++++++++++++-----
 utils/load_data.py               | 31 +++++++++++++++++-
 utils/load_metric.py             | 42 +++++++++++++++++++++++--
 utils/load_model.py              | 39 +++++++++++++++++++++++--
 6 files changed, 164 insertions(+), 12 deletions(-)

diff --git a/utils/dataloaders/datasources.py b/utils/dataloaders/datasources.py
index 5559031..f0d2e01 100644
--- a/utils/dataloaders/datasources.py
+++ b/utils/dataloaders/datasources.py
@@ -1,3 +1,10 @@
+"""This module contains the data sources for the datasets used in the experiments.
+
+The data sources are defined as dictionaries with the following keys:
+- train: A list containing the URL, filename, and MD5 hash of the training data.
+- test: A list containing the URL, filename, and MD5 hash of the test data.
+"""
+
 USPS_SOURCE = {
     "train": [
         "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
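The module docstring above fixes the schema every source dictionary follows: each
split name maps to [URL, filename, MD5 hash]. For illustration, a new source added
under that schema would look like this; the URL, filename, and hash are placeholders,
not real values:

    EXAMPLE_SOURCE = {
        "train": [
            "https://example.com/example-train.bz2",  # URL (placeholder)
            "example-train.bz2",                      # filename (placeholder)
            "0123456789abcdef0123456789abcdef",       # MD5 hash (placeholder)
        ],
        "test": [
            "https://example.com/example-test.bz2",   # URL (placeholder)
            "example-test.bz2",                       # filename (placeholder)
            "fedcba9876543210fedcba9876543210",       # MD5 hash (placeholder)
        ],
    }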
diff --git a/utils/dataloaders/mnist_0_3.py b/utils/dataloaders/mnist_0_3.py
index 5e5a935..0a82909 100644
--- a/utils/dataloaders/mnist_0_3.py
+++ b/utils/dataloaders/mnist_0_3.py
@@ -10,6 +10,7 @@ class MNISTDataset0_3(Dataset):
     """
     A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
+
     Parameters
     ----------
     data_path : Path
@@ -20,6 +21,7 @@ class MNISTDataset0_3(Dataset):
         A function/transform that takes in an image and returns a transformed version. Default is None.
     download : bool, optional
         If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
+
     Attributes
     ----------
     data_path : Path
@@ -40,6 +42,7 @@ class MNISTDataset0_3(Dataset):
         Indices of the labels that are less than 4.
     length : int
         The number of samples in the dataset.
+
     Methods
     -------
     _parse_labels(train)
diff --git a/utils/dataloaders/usps_0_6.py b/utils/dataloaders/usps_0_6.py
index 6e63d3d..3673fa9 100644
--- a/utils/dataloaders/usps_0_6.py
+++ b/utils/dataloaders/usps_0_6.py
@@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
     Args
     ----
     data_path : pathlib.Path
-        Path to the USPS dataset file.
+        Path to the data directory.
     train : bool, optional
         Mode of the dataset.
     transform : callable, optional
@@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):

     Examples
     --------
+    >>> from torchvision import transforms
     >>> from src.datahandlers import USPSDataset0_6
-    >>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
+    >>> transform = transforms.Compose([
+    ...     transforms.Resize((16, 16)),
+    ...     transforms.ToTensor()
+    ... ])
+    >>> dataset = USPSDataset0_6(
+    ...     data_path="data",
+    ...     transform=transform,
+    ...     download=True,
+    ...     train=True,
+    ... )
     >>> len(dataset)
     5460
     >>> data, target = dataset[0]
     >>> data.shape
-    (16, 16)
+    (1, 16, 16)
     >>> target
-    6
+    tensor([1., 0., 0., 0., 0., 0., 0.])
     """

     filename = "usps.h5"
+    num_classes = 7

     def __init__(
         self,
@@ -85,7 +96,6 @@ def __init__(
         path = data_path if isinstance(data_path, Path) else Path(data_path)
         self.filepath = path / self.filename
         self.transform = transform
-        self.num_classes = 7  # 0-6
         self.mode = "train" if train else "test"

         # Download the dataset if it does not exist in a temporary directory
@@ -116,7 +126,24 @@ def _dataset_ok(self):
         return True

     def download(self, url, filepath, checksum, mode):
-        """Download the USPS dataset."""
+        """Download the USPS dataset and save it as an HDF5 file.
+
+        Args
+        ----
+        url : str
+            URL to download the dataset from.
+        filepath : pathlib.Path
+            Path to save the downloaded dataset.
+        checksum : str
+            MD5 checksum of the downloaded file.
+        mode : str
+            Mode of the dataset, either train or test.
+
+        Raises
+        ------
+        ValueError
+            If the checksum of the downloaded file does not match the expected checksum.
+        """

         def reporthook(blocknum, blocksize, totalsize):
             """Report download progress."""
@@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):

     @staticmethod
     def check_integrity(filepath, checksum):
-        """Check the integrity of the USPS dataset file."""
+        """Check the integrity of the USPS dataset file.
+
+        Args
+        ----
+        filepath : pathlib.Path
+            Path to the USPS dataset file.
+        checksum : str
+            MD5 checksum of the dataset file.
+
+        Returns
+        -------
+        bool
+            True if the checksum of the file matches the expected checksum, False otherwise.
+        """

         file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
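The expanded check_integrity docstring pins down the contract: hash the file on
disk, compare against the expected MD5, return a bool. A self-contained sketch of
the same check, reusing the exact hashlib call visible in the diff (the path and
checksum literals below are hypothetical):

    import hashlib
    from pathlib import Path

    def md5_matches(filepath: Path, expected: str) -> bool:
        # Hash the whole file in one read; the USPS HDF5 file is small enough.
        return hashlib.md5(filepath.read_bytes()).hexdigest() == expected

    # Hypothetical usage; the hex string stands in for the real checksum:
    # md5_matches(Path("data/usps.h5"), "0123456789abcdef0123456789abcdef")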
diff --git a/utils/load_data.py b/utils/load_data.py
index d4d5795..9060013 100644
--- a/utils/load_data.py
+++ b/utils/load_data.py
@@ -4,6 +4,35 @@


 def load_data(dataset: str, *args, **kwargs) -> Dataset:
+    """
+    Load a dataset based on its name.
+
+    Args
+    ----
+    dataset : str
+        Name of the dataset to load.
+    *args : list
+        Additional positional arguments for the dataset class.
+    **kwargs : dict
+        Additional keyword arguments for the dataset class.
+
+    Returns
+    -------
+    dataset : torch.utils.data.Dataset
+        Dataset object.
+
+    Raises
+    ------
+    NotImplementedError
+        If the dataset is not implemented.
+
+    Examples
+    --------
+    >>> from utils import load_data
+    >>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
+    >>> len(dataset)
+    5460
+    """
     match dataset.lower():
         case "usps_0-6":
             return USPSDataset0_6(*args, **kwargs)
@@ -12,4 +41,4 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
         case "usps_7-9":
             return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
         case _:
-            raise ValueError(f"Dataset: {dataset} not implemented.")
+            raise NotImplementedError(f"Dataset: {dataset} not implemented.")
diff --git a/utils/load_metric.py b/utils/load_metric.py
index 9c942d1..a8aacc0 100644
--- a/utils/load_metric.py
+++ b/utils/load_metric.py
@@ -7,6 +7,44 @@


 class MetricWrapper(nn.Module):
+    """
+    Wrapper class that runs multiple metrics on the same data.
+
+    Args
+    ----
+    *metrics : str
+        Names of the metrics to run on the data.
+
+    Attributes
+    ----------
+    metrics : dict
+        Dictionary containing the metric functions.
+    tmp_scores : dict
+        Dictionary containing the temporary scores of the metrics.
+
+    Methods
+    -------
+    __call__(y_true, y_pred)
+        Call the metric functions on the true and predicted labels.
+    accumulate()
+        Get the average scores of the metrics.
+    reset()
+        Reset the temporary scores of the metrics.
+
+    Examples
+    --------
+    >>> from utils import MetricWrapper
+    >>> metrics = MetricWrapper("entropy", "f1", "precision")
+    >>> y_true = [0, 1, 0, 1]
+    >>> y_pred = [0, 1, 1, 0]
+    >>> metrics(y_true, y_pred)
+    >>> metrics.accumulate()
+    {'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
+    >>> metrics.reset()
+    >>> metrics.tmp_scores
+    {'entropy': [], 'f1': [], 'precision': []}
+    """
+
     def __init__(self, *metrics):
         super().__init__()
         self.metrics = {}
@@ -49,13 +87,13 @@ def __call__(self, y_true, y_pred):
         for key in self.metrics:
             self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))

-    def __getmetrics__(self):
+    def accumulate(self):
         return_metrics = {}
         for key in self.metrics:
             return_metrics[key] = np.mean(self.tmp_scores[key])

         return return_metrics

-    def __resetvalues__(self):
+    def reset(self):
         for key in self.tmp_scores:
             self.tmp_scores[key] = []
diff --git a/utils/load_model.py b/utils/load_model.py
index b8f96e2..7d09e15 100644
--- a/utils/load_model.py
+++ b/utils/load_model.py
@@ -4,6 +4,38 @@


 def load_model(modelname: str, *args, **kwargs) -> nn.Module:
+    """
+    Load a model based on its name.
+
+    Args
+    ----
+    modelname : str
+        Name of the model to load.
+    *args : list
+        Additional positional arguments for the model class.
+    **kwargs : dict
+        Additional keyword arguments for the model class.
+
+    Returns
+    -------
+    model : torch.nn.Module
+        Model object.
+
+    Raises
+    ------
+    NotImplementedError
+        If the model is not implemented.
+
+    Examples
+    --------
+    >>> from utils import load_model
+    >>> model = load_model("magnusmodel", num_classes=10)
+    >>> model
+    MagnusModel(
+      (fc1): Linear(in_features=784, out_features=100, bias=True)
+      (fc2): Linear(in_features=100, out_features=10, bias=True)
+    )
+    """
     match modelname.lower():
         case "magnusmodel":
             return MagnusModel(*args, **kwargs)
@@ -14,6 +45,9 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
         case "solveigmodel":
             return SolveigModel(*args, **kwargs)
         case _:
-            raise ValueError(
-                f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
+            errmsg = (
+                f"Model: {modelname} not implemented. "
+                "Check the documentation for implemented models, "
+                "or check your spelling."
             )
+            raise NotImplementedError(errmsg)
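Taken together, the docstrings added here describe entry points that compose into a
short evaluation loop. A hedged end-to-end sketch: the names "usps_0-6" and
"magnusmodel" come from the match arms above, and the metric names and call
signature from the MetricWrapper docstring, while the DataLoader settings and the
dataset-model pairing are illustrative assumptions, not something this patch
establishes:

    from torch.utils.data import DataLoader

    from utils import MetricWrapper, load_data, load_model

    dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
    loader = DataLoader(dataset, batch_size=16)       # batch size is illustrative

    model = load_model("magnusmodel", num_classes=7)  # 7 classes: digits 0-6
    metrics = MetricWrapper("f1", "precision")

    for data, target in loader:
        preds = model(data)
        metrics(target, preds)       # __call__(y_true, y_pred) appends per-batch scores
    print(metrics.accumulate())      # mean of each metric over all batches
    metrics.reset()                  # clear tmp_scores before the next epoch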