Skip to content

Commit

Permalink
Update docstrings for documentation (still lots to do)
Browse files Browse the repository at this point in the history
Added a bit more information to some of the docstrings, but there is
still a lot of room for improvement ;)
  • Loading branch information
salomaestro committed Feb 4, 2025
1 parent d6128d7 commit 2e202c9
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 12 deletions.
7 changes: 7 additions & 0 deletions utils/dataloaders/datasources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""This module contains the data sources for the datasets used in the experiments.
The data sources are defined as dictionaries with the following keys:
- train: A list containing the URL, filename, and MD5 hash of the training data.
- test: A list containing the URL, filename, and MD5 hash of the test data.
"""

USPS_SOURCE = {
"train": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
Expand Down
3 changes: 3 additions & 0 deletions utils/dataloaders/mnist_0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class MNISTDataset0_3(Dataset):
"""
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
Parameters
----------
data_path : Path
Expand All @@ -20,6 +21,7 @@ class MNISTDataset0_3(Dataset):
A function/transform that takes in an image and returns a transformed version. Default is None.
download : bool, optional
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
Attributes
----------
data_path : Path
Expand All @@ -40,6 +42,7 @@ class MNISTDataset0_3(Dataset):
Indices of the labels that are less than 4.
length : int
The number of samples in the dataset.
Methods
-------
_parse_labels(train)
Expand Down
54 changes: 47 additions & 7 deletions utils/dataloaders/usps_0_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class USPSDataset0_6(Dataset):
Args
----
data_path : pathlib.Path
Path to the USPS dataset file.
Path to the data directory.
train : bool, optional
Mode of the dataset.
transform : callable, optional
Expand Down Expand Up @@ -60,18 +60,29 @@ class USPSDataset0_6(Dataset):
Examples
--------
>>> from torchvision import transforms
>>> from src.datahandlers import USPSDataset0_6
>>> dataset = USPSDataset0_6(path="data/usps.h5", mode="train")
>>> transform = transforms.Compose([
... transforms.Resize((16, 16)),
... transforms.ToTensor()
... ])
>>> dataset = USPSDataset0_6(
... data_path="data",
... transform=transform,
... download=True,
... train=True,
... )
>>> len(dataset)
5460
>>> data, target = dataset[0]
>>> data.shape
(16, 16)
(1, 16, 16)
>>> target
6
tensor([1., 0., 0., 0., 0., 0., 0.])
"""

filename = "usps.h5"
num_classes = 7

def __init__(
self,
Expand All @@ -85,7 +96,6 @@ def __init__(
path = data_path if isinstance(data_path, Path) else Path(data_path)
self.filepath = path / self.filename
self.transform = transform
self.num_classes = 7 # 0-6
self.mode = "train" if train else "test"

# Download the dataset if it does not exist in a temporary directory
Expand Down Expand Up @@ -116,7 +126,24 @@ def _dataset_ok(self):
return True

def download(self, url, filepath, checksum, mode):
"""Download the USPS dataset."""
"""Download the USPS dataset, and save it as an HDF5 file.
Args
----
url : str
URL to download the dataset from.
filepath : pathlib.Path
Path to save the downloaded dataset.
checksum : str
MD5 checksum of the downloaded file.
mode : str
Mode of the dataset, either train or test.
Raises
------
ValueError
If the checksum of the downloaded file does not match the expected checksum.
"""

def reporthook(blocknum, blocksize, totalsize):
"""Report download progress."""
Expand Down Expand Up @@ -164,7 +191,20 @@ def reporthook(blocknum, blocksize, totalsize):

@staticmethod
def check_integrity(filepath, checksum):
"""Check the integrity of the USPS dataset file."""
"""Check the integrity of the USPS dataset file.
Args
----
filepath : pathlib.Path
Path to the USPS dataset file.
checksum : str
MD5 checksum of the dataset file.
Returns
-------
bool
True if the checksum of the file matches the expected checksum, False otherwise.
"""

file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()

Expand Down
31 changes: 30 additions & 1 deletion utils/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,35 @@


def load_data(dataset: str, *args, **kwargs) -> Dataset:
"""
Load the dataset based on the dataset name.
Args
----
dataset : str
Name of the dataset to load.
*args : list
Additional arguments for the dataset class.
**kwargs : dict
Additional keyword arguments for the dataset class.
Returns
-------
dataset : torch.utils.data.Dataset
Dataset object.
Raises
------
NotImplementedError
If the dataset is not implemented.
Examples
--------
>>> from utils import load_data
>>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
>>> len(dataset)
5460
"""
match dataset.lower():
case "usps_0-6":
return USPSDataset0_6(*args, **kwargs)
Expand All @@ -12,4 +41,4 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
case "usps_7-9":
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
case _:
raise ValueError(f"Dataset: {dataset} not implemented.")
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
42 changes: 40 additions & 2 deletions utils/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,44 @@


class MetricWrapper(nn.Module):
"""
Wrapper class for metrics, that runs multiple metrics on the same data.
Args
----
*metrics : str
Names of the metrics to run on the data, passed as separate positional arguments.
Attributes
----------
metrics : dict
Dictionary containing the metric functions.
tmp_scores : dict
Dictionary containing the temporary scores of the metrics.
Methods
-------
__call__(y_true, y_pred)
Call the metric functions on the true and predicted labels.
accumulate()
Get the average scores of the metrics.
reset()
Reset the temporary scores of the metrics.
Examples
--------
>>> from utils import MetricWrapper
>>> metrics = MetricWrapper("entropy", "f1", "precision")
>>> y_true = [0, 1, 0, 1]
>>> y_pred = [0, 1, 1, 0]
>>> metrics(y_true, y_pred)
>>> metrics.accumulate()
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
>>> metrics.reset()
>>> metrics.accumulate()
{'entropy': nan, 'f1': nan, 'precision': nan}
"""

def __init__(self, *metrics):
super().__init__()
self.metrics = {}
Expand Down Expand Up @@ -49,13 +87,13 @@ def __call__(self, y_true, y_pred):
for key in self.metrics:
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))

def __getmetrics__(self):
def accumulate(self):
    """Average the scores collected so far for every metric.

    Returns
    -------
    dict
        Maps each key of ``self.metrics`` to ``np.mean`` of the values
        stored under that same key in ``self.tmp_scores``.
    """
    return {name: np.mean(self.tmp_scores[name]) for name in self.metrics}

def __resetvalues__(self):
def reset(self):
    """Discard every stored score, leaving each metric with an empty list.

    The ``self.tmp_scores`` dictionary is updated in place, so its keys
    are kept and only the values are replaced.
    """
    # Build one distinct empty list per key so metrics never share storage.
    cleared = {name: [] for name in self.tmp_scores}
    self.tmp_scores.update(cleared)
38 changes: 36 additions & 2 deletions utils/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,37 @@


def load_model(modelname: str, *args, **kwargs) -> nn.Module:
"""
Load the model based on the model name.
Args
----
modelname : str
Name of the model to load.
*args : list
Additional arguments for the model class.
**kwargs : dict
Additional keyword arguments for the model class.
Returns
-------
model : torch.nn.Module
Model object.
Raises
------
NotImplementedError
If the model is not implemented.
Examples
--------
>>> from utils import load_model
>>> model = load_model("magnusmodel", num_classes=10)
>>> model
MagnusModel(
(fc1): Linear(in_features=784, out_features=100, bias=True)
(fc2): Linear(in_features=100, out_features=10, bias=True)
)
"""
match modelname.lower():
case "magnusmodel":
return MagnusModel(*args, **kwargs)
Expand All @@ -14,6 +45,9 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
case "solveigmodel":
return SolveigModel(*args, **kwargs)
case _:
raise ValueError(
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
errmsg = (
f"Model: {modelname} not implemented. "
"Check the documentation for implemented models, "
"or check your spelling."
)
raise NotImplementedError(errmsg)

0 comments on commit 2e202c9

Please sign in to comment.