diff --git a/.gitignore b/.gitignore
index 6448e76c..f70be632 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 # Created by .ignore support plugin (hsz.mobi)
### Example user template
-
+.data
+runs
 # IntelliJ project files
 .idea
 *.iml
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 00000000..8266b39c
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,13 @@
+Arch:
+  name: enet
+  num_classes: 4
+
+Optim:
+  name: Adam
+  lr: 0.0001
+  weight_decay: 0.00005
+
+Scheduler:
+  name: StepLR
+  step_size: 30
+  gamma: 0.1
\ No newline at end of file
diff --git a/contrastyou/__init__.py b/contrastyou/__init__.py
index e69de29b..1b700ec0 100644
--- a/contrastyou/__init__.py
+++ b/contrastyou/__init__.py
@@ -0,0 +1,27 @@
+from enum import Enum
+from pathlib import Path
+
+PROJECT_PATH = str(Path(__file__).parents[1])
+DATA_PATH = str(Path(PROJECT_PATH) / ".data")
+Path(DATA_PATH).mkdir(exist_ok=True, parents=True)
+CONFIG_PATH = str(Path(PROJECT_PATH, "config"))
+
+
+class ModelState(Enum):
+    TRAIN = "TRAIN"
+    TEST = "TEST"
+    EVAL = "EVAL"
+
+    @staticmethod
+    def from_str(mode_str):
+        """ Init from string
+        :param mode_str: one of ['train', 'test', 'eval']
+        """
+        if mode_str == "train":
+            return ModelState.TRAIN
+        elif mode_str == "test":
+            return ModelState.TEST
+        elif mode_str == "eval":
+            return ModelState.EVAL
+        else:
+            raise ValueError("Invalid argument mode_str {}".format(mode_str))
diff --git a/contrastyou/arch/__init__.py b/contrastyou/arch/__init__.py
new file mode 100644
index 00000000..cfb4691d
--- /dev/null
+++ b/contrastyou/arch/__init__.py
@@ -0,0 +1,2 @@
+def get_arch(*args, **kwargs):
+    pass
diff --git a/contrastyou/augment/__init__.py b/contrastyou/augment/__init__.py
new file mode 100644
index 00000000..acc7fe57
--- /dev/null
+++ b/contrastyou/augment/__init__.py
@@ -0,0 +1,11 @@
+class SequentialWrapperTwice:
+    def __init__(self, transform=None) -> None:
+        self._transform = transform
+
+    def __call__(
+        self, *imgs, random_seed=None
+    ):
+        return [
+            self._transform.__call__(*imgs, random_seed=random_seed),
+            self._transform.__call__(*imgs, random_seed=random_seed),
+        ]
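A usage sketch, not part of the patch: `SequentialWrapperTwice` calls the wrapped transform twice on the same inputs, producing the two augmented views needed for contrastive training. This assumes the wrapped transform accepts a `random_seed` keyword, as the wrapper's `__call__` requires; passing an explicit seed makes both views share the same transform parameters, while leaving it `None` re-randomizes each call.

```python
from contrastyou.augment import SequentialWrapperTwice
from deepclustering.augment import SequentialWrapper

twice_transform = SequentialWrapperTwice(SequentialWrapper())
# view_a, view_b = twice_transform(img, gt)               # two random views
# same_a, same_b = twice_transform(img, gt, random_seed=1)  # identical pair
```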
diff --git a/contrastyou/dataloader/_seg_datset.py b/contrastyou/dataloader/_seg_datset.py
new file mode 100644
index 00000000..5305d6d0
--- /dev/null
+++ b/contrastyou/dataloader/_seg_datset.py
@@ -0,0 +1,103 @@
+import random
+from abc import abstractmethod
+from copy import deepcopy as dcp
+from typing import Union, List, Set
+
+from torch.utils.data.sampler import Sampler
+
+
+class ContrastDataset:
+    """
+    Each patient has two codes: the first is the group_name, i.e. the patient id;
+    the second is the partition code, indicating the position of the image slice.
+    All patients should have the same number of partitions so that they can be aligned.
+    For the ACDC dataset, the ED and ES ventricular volumes should be considered further.
+    """
+
+    @abstractmethod
+    def _get_partition(self, *args) -> Union[str, int]:
+        """get the partition of a 2D slice given its index or filename"""
+        pass
+
+    @abstractmethod
+    def _get_group(self, *args) -> Union[str, int]:
+        """get the group name of a 2D slice given its index or filename"""
+        pass
+
+    @abstractmethod
+    def show_partitions(self) -> List[Union[str, int]]:
+        """show the partitions of all 2D slices in the dataset"""
+        pass
+
+    def show_partition_set(self) -> Set[Union[str, int]]:
+        """show the set of partitions of the dataset"""
+        return set(self.show_partitions())
+
+    @abstractmethod
+    def show_groups(self) -> List[Union[str, int]]:
+        """show the groups of all 2D slices in the dataset"""
+        pass
+
+    def show_group_set(self) -> Set[Union[str, int]]:
+        """show the set of groups of the dataset"""
+        return set(self.show_groups())
+
+
+class ContrastBatchSampler(Sampler):
+    """
+    This sampler draws batches both across different patients and within the same patient:
+    `we form batches by first randomly sampling m < M volumes. Then, for each sampled volume, we sample one image per
+    partition resulting in S images per volume. Next, we apply a pair of random transformations on each sampled image and
+    add them to the batch`
+    """
+
+    class _SamplerIterator:
+
+        def __init__(self, group2index, partition2index, group_sample_num=4, partition_sample_num=1) -> None:
+            self._group2index, self._partition2index = dcp(group2index), dcp(partition2index)
+
+            assert 1 <= group_sample_num <= len(self._group2index.keys()), group_sample_num
+            self._group_sample_num = group_sample_num
+            self._partition_sample_num = partition_sample_num
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            batch_index = []
+            cur_gsamples = random.sample(self._group2index.keys(), self._group_sample_num)
+            assert isinstance(cur_gsamples, list), cur_gsamples
+            # for each sampled group, sample `partition_sample_num` slices per partition
+            for cur_gsample in cur_gsamples:
+                gavailableslices = self._group2index[cur_gsample]
+                for savailableslices in self._partition2index.values():
+                    sampled_slices = random.sample(sorted(set(gavailableslices) & set(savailableslices)),
+                                                   self._partition_sample_num)
+                    batch_index.extend(sampled_slices)
+            return batch_index
+
+    def __init__(self, dataset: ContrastDataset, group_sample_num=4, partition_sample_num=1) -> None:
+        self._dataset = dataset
+        filenames = dcp(list(dataset._filenames.values())[0])
+        group2index = {}
+        partition2index = {}
+        for i, filename in enumerate(filenames):
+            group = dataset._get_group(filename)
+            if group not in group2index:
+                group2index[group] = []
+            group2index[group].append(i)
+            partition = dataset._get_partition(filename)
+            if partition not in partition2index:
+                partition2index[partition] = []
+            partition2index[partition].append(i)
+        self._group2index = group2index
+        self._partition2index = partition2index
+        self._group_sample_num = group_sample_num
+        self._partition_sample_num = partition_sample_num
+
+    def __iter__(self):
+        return self._SamplerIterator(self._group2index, self._partition2index, self._group_sample_num,
+                                     self._partition_sample_num)
+
+    def __len__(self) -> int:
+        return len(self._dataset)  # type: ignore
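A minimal wiring sketch, not part of the patch, assuming `dataset` is any `ContrastDataset` implementation such as the `ACDCDataset` below. Since the sampler yields complete index lists, it is passed as `batch_sampler` and `batch_size`/`shuffle` stay unset.

```python
from torch.utils.data import DataLoader

# with group_sample_num=4 patients and partition_sample_num=1 slice per
# partition, each batch holds 4 * n_partitions slices
sampler = ContrastBatchSampler(dataset, group_sample_num=4, partition_sample_num=1)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)
# note: _SamplerIterator never raises StopIteration, so iteration is endless
# and the training loop must break after a fixed number of batches; also, on
# Python >= 3.9 random.sample over dict keys may need a list(...) wrapper.
```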
diff --git a/contrastyou/dataloader/acdc_dataset.py b/contrastyou/dataloader/acdc_dataset.py
new file mode 100644
index 00000000..7fb9c226
--- /dev/null
+++ b/contrastyou/dataloader/acdc_dataset.py
@@ -0,0 +1,35 @@
+from typing import List, Tuple, Union
+
+from deepclustering.augment import SequentialWrapper
+from deepclustering.dataset import ACDCDataset as _ACDCDataset
+from torch import Tensor
+
+from contrastyou.dataloader._seg_datset import ContrastDataset
+
+
+class ACDCDataset(ContrastDataset, _ACDCDataset):
+    def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = None,
+                 verbose=True) -> None:
+        super().__init__(root_dir, mode, ["img", "gt"], transforms, verbose)
+
+    def __getitem__(self, index) -> Tuple[List[Tensor], str, str, str]:
+        data, filename = super().__getitem__(index)
+        partition = self._get_partition(filename)
+        group = self._get_group(filename)
+        return data, filename, partition, group
+
+    def _get_group(self, filename) -> Union[str, int]:
+        return self._get_group_name(filename)
+
+    def _get_partition(self, filename) -> Union[str, int]:
+        return 0
+
+    def show_partitions(self) -> List[Union[str, int]]:
+        return [self._get_partition(f) for f in list(self._filenames.values())[0]]
+
+    def show_groups(self) -> List[Union[str, int]]:
+        return [self._get_group(f) for f in list(self._filenames.values())[0]]
+
+
+
+
diff --git a/contrastyou/dataloader/mmwhs_dataset.py b/contrastyou/dataloader/mmwhs_dataset.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrastyou/helper/__init__.py b/contrastyou/helper/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrastyou/losses/__init__.py b/contrastyou/losses/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrastyou/meters2/__init__.py b/contrastyou/meters2/__init__.py
new file mode 100644
index 00000000..5f0a9b0b
--- /dev/null
+++ b/contrastyou/meters2/__init__.py
@@ -0,0 +1,5 @@
+
+from .meter_interface import MeterInterface
+from .individual_meters import *
+
+# todo: improve the stability of each meter
diff --git a/contrastyou/meters2/individual_meters/__init__.py b/contrastyou/meters2/individual_meters/__init__.py
new file mode 100644
index 00000000..f94d9627
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/__init__.py
@@ -0,0 +1,30 @@
+from ._metric import _Metric
+# individual package for meters, based on:
+"""
+>>>class _Metric(metaclass=ABCMeta):
+>>>    @abstractmethod
+>>>    def reset(self):
+>>>        pass
+>>>
+>>>    @abstractmethod
+>>>    def add(self, *args, **kwargs):
+>>>        pass
+>>>
+>>>    @abstractmethod
+>>>    def log(self):
+>>>        pass
+>>>
+>>>    @abstractmethod
+>>>    def summary(self) -> dict:
+>>>        pass
+>>>
+>>>    @abstractmethod
+>>>    def detailed_summary(self) -> dict:
+>>>        pass
+"""
+
+from .averagemeter import AverageValueMeter
+from .confusionmatrix import ConfusionMatrix
+from .hausdorff import HaussdorffDistance
+from .instance import InstanceValue
+from .iou import IoU
\ No newline at end of file
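A hypothetical meter, not part of the patch, showing the minimal contract a new individual meter must satisfy per the `_Metric` interface quoted above: `reset`, `add`, `summary`, and `detailed_summary`.

```python
from contrastyou.meters2.individual_meters import _Metric


class MaxValueMeter(_Metric):
    """Illustrative only: tracks the running maximum seen within an epoch."""

    def __init__(self) -> None:
        self._max = float("-inf")

    def reset(self):
        self._max = float("-inf")

    def add(self, value):
        # keep the largest value observed so far this epoch
        self._max = max(self._max, float(value))

    def summary(self) -> dict:
        return {"max": self._max}

    def detailed_summary(self) -> dict:
        return self.summary()
```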
diff --git a/contrastyou/meters2/individual_meters/_metric.py b/contrastyou/meters2/individual_meters/_metric.py
new file mode 100644
index 00000000..b0b8fa3b
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/_metric.py
@@ -0,0 +1,28 @@
+from abc import abstractmethod, ABCMeta
+
+
+class _Metric(metaclass=ABCMeta):
+    """Base class for all metrics.
+    Records the values within a single epoch.
+    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
+    """
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def add(self, *args, **kwargs):
+        pass
+
+    # @abstractmethod
+    # def log(self):
+    #     pass
+
+    @abstractmethod
+    def summary(self) -> dict:
+        pass
+
+    @abstractmethod
+    def detailed_summary(self) -> dict:
+        pass
diff --git a/contrastyou/meters2/individual_meters/averagemeter.py b/contrastyou/meters2/individual_meters/averagemeter.py
new file mode 100644
index 00000000..843aa199
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/averagemeter.py
@@ -0,0 +1,61 @@
+from typing import List
+
+import numpy as np
+
+from ._metric import _Metric
+
+
+class AverageValueMeter(_Metric):
+    def __init__(self):
+        super(AverageValueMeter, self).__init__()
+        self.reset()
+        self.val = 0
+
+    def add(self, value, n=1):
+        self.val = value
+        self.sum += value
+        self.var += value * value
+        self.n += n
+
+        if self.n == 0:
+            self.mean, self.std = np.nan, np.nan
+        elif self.n == 1:
+            self.mean = 0.0 + self.sum  # This is to force a copy in torch/numpy
+            self.std = np.inf
+            self.mean_old = self.mean
+            self.m_s = 0.0
+        else:
+            self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
+            self.m_s += (value - self.mean_old) * (value - self.mean)
+            self.mean_old = self.mean
+            self.std = np.sqrt(self.m_s / (self.n - 1.0))
+
+    def value(self):
+        return self.mean, self.std
+
+    def reset(self):
+        self.n = 0
+        self.sum = 0.0
+        self.var = 0.0
+        self.val = 0.0
+        self.mean = np.nan
+        self.mean_old = 0.0
+        self.m_s = 0.0
+        self.std = 0.0
+
+    def summary(self) -> dict:
+        # this function returns a dict aggregating the historical results.
+        return {"mean": self.value()[0]}
+
+    def detailed_summary(self) -> dict:
+        # this function returns a dict aggregating the historical results.
+        return {"mean": self.value()[0], "std": self.value()[1]}
+
+    def __repr__(self):
+        def _dict2str(value_dict: dict):
+            return "\t".join([f"{k}:{v}" for k, v in value_dict.items()])
+
+        return f"{self.__class__.__name__}: n={self.n} \n \t {_dict2str(self.detailed_summary())}"
+
+    def get_plot_names(self) -> List[str]:
+        return ["mean"]
diff --git a/contrastyou/meters2/individual_meters/cache.py b/contrastyou/meters2/individual_meters/cache.py
new file mode 100644
index 00000000..2bda0f1d
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/cache.py
@@ -0,0 +1,59 @@
+import torch
+from numbers import Number
+from ._metric import _Metric
+import numpy as np
+
+
+class Cache(_Metric):
+    """
+    Cache is a meter that just stores elements in self.log, for statistical purposes.
+ """ + + def __init__(self) -> None: + super().__init__() + self.log = [] + + def reset(self): + self.log = [] + + def add(self, input): + self.log.append(input) + + def value(self, **kwargs): + return len(self.log) + + def summary(self) -> dict: + return {"total elements": self.log.__len__()} + + def detailed_summary(self) -> dict: + return self.summary() + + +class AveragewithStd(Cache): + """ + this Meter is going to return the mean and std_lower, std_high for a list of scalar values + """ + + def add(self, input): + assert ( + isinstance(input, Number) + or (isinstance(input, torch.Tensor) and input.shape.__len__() <= 1) + or (isinstance(input, np.ndarray) and input.shape.__len__() <= 1) + ) + if torch.is_tensor(input): + input = input.cpu().item() + + super().add(input) + + def value(self, **kwargs): + return torch.Tensor(self.log).mean().item() + + def summary(self) -> dict: + torch_log = torch.Tensor(self.log) + mean = torch_log.mean() + std = torch_log.std() + return { + "mean": mean.item(), + "lstd": mean.item() - std.item(), + "hstd": mean.item() + std.item(), + } diff --git a/contrastyou/meters2/individual_meters/confusionmatrix.py b/contrastyou/meters2/individual_meters/confusionmatrix.py new file mode 100644 index 00000000..30bcb93d --- /dev/null +++ b/contrastyou/meters2/individual_meters/confusionmatrix.py @@ -0,0 +1,112 @@ +import warnings + +import numpy as np +import torch + +from ._metric import _Metric + + +class ConfusionMatrix(_Metric): + """Constructs a confusion matrix for a multi-class classification problems. + + Does not support multi-label, multi-class problems. + + Keyword arguments: + - num_classes (int): number of classes in the classification problem. + - normalized (boolean, optional): Determines whether or not the confusion + matrix is normalized or not. Default: False. + + Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py + """ + + def __init__(self, num_classes, ignore_index=255, normalized=False): + super().__init__() + + self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) + self.normalized = normalized + self.num_classes = num_classes + self.ignore_index = ignore_index + self.reset() + + def reset(self): + self.conf.fill(0) + + def add(self, predicted, target): + """Computes the confusion matrix + + The shape of the confusion matrix is K x K, where K is the number + of classes. + + Keyword arguments: + - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of + predicted scores obtained from the model for N examples and K classes, + or an N-tensor/array of integer values between 0 and K-1. + - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of + ground-truth classes for N examples and K classes, or an N-tensor/array + of integer values between 0 and K-1. 
+ + """ + # If target and/or predicted are tensors, convert them to numpy arrays + if torch.is_tensor(predicted): + predicted = predicted.cpu().numpy() + if torch.is_tensor(target): + target = target.cpu().numpy() + assert predicted.shape == target.shape + # assert predicted.shape[0] == target.shape[0], \ + # 'number of targets and predicted outputs do not match' + # + # if np.ndim(predicted) != 1: + # assert predicted.shape[1] == self.num_classes, \ + # 'number of predictions does not match size of confusion matrix' + # predicted = np.argmax(predicted, 1) + # else: + # assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ + # 'predicted values are not between 0 and k-1' + # + # if np.ndim(target) != 1: + # assert target.shape[1] == self.num_classes, \ + # 'Onehot target does not match size of confusion matrix' + # assert (target >= 0).all() and (target <= 1).all(), \ + # 'in one-hot encoding, target values should be 0 or 1' + # assert (target.sum(1) == 1).all(), \ + # 'multi-label setting is not supported' + # target = np.argmax(target, 1) + # else: + # assert (target.max() < self.num_classes) and (target.min() >= 0), \ + # 'target values are not between 0 and k-1' + + mask = (target >= 0) & (target < self.num_classes) + + # hack for bincounting 2 arrays together + x = predicted[mask] + self.num_classes * target[mask] + bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes ** 2) + assert bincount_2d.size == self.num_classes ** 2 + conf = bincount_2d.reshape((self.num_classes, self.num_classes)) + + self.conf += conf + + def value(self): + """ + Returns: + Confustion matrix of K rows and K columns, where rows corresponds + to ground-truth targets and columns corresponds to predicted + targets. + """ + if self.normalized: + conf = self.conf.astype(np.float32) + return conf / conf.sum(1).clip(min=1e-12)[:, None] + else: + return self.conf + + def summary(self) -> dict: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + acc = np.diag(self.conf).sum() / np.sum(self.conf) + return {"acc": acc} + + def detailed_summary(self) -> dict: + acc = np.diag(self.conf).sum() / np.sum(self.conf) + return {"acc": acc} + + def log(self): + return self.conf diff --git a/contrastyou/meters2/individual_meters/dicemeter.py b/contrastyou/meters2/individual_meters/dicemeter.py new file mode 100644 index 00000000..dac50a95 --- /dev/null +++ b/contrastyou/meters2/individual_meters/dicemeter.py @@ -0,0 +1,115 @@ +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + +from ._metric import _Metric +from deepclustering2.loss.dice_loss import dice_coef, dice_batch +from deepclustering2.utils import probs2one_hot, class2one_hot +from deepclustering2.utils.typecheckconvert import to_float + +__all__ = ["SliceDiceMeter", "BatchDiceMeter"] + + +# from deepclustering.decorator.decorator import threaded + + +def toOneHot(pred_logit, mask): + """ + :param pred_logit: logit with b,c, h, w. it is fine to pass simplex prediction or onehot. 
+    :param mask: gt mask with shape b, h, w
+    :return: one-hot representations of prediction and mask, pred.shape == mask.shape == b, c, h, w
+    """
+    oh_predmask = probs2one_hot(F.softmax(pred_logit, 1))
+    oh_mask = class2one_hot(mask.squeeze(1), C=pred_logit.shape[1])
+    assert oh_predmask.shape == oh_mask.shape
+    return oh_predmask, oh_mask
+
+
+class _DiceMeter(_Metric):
+    def __init__(self, call_function, C=4, report_axises=None) -> None:
+        super(_DiceMeter, self).__init__()
+        assert report_axises is None or isinstance(report_axises, (list, tuple))
+        if report_axises is not None:
+            assert max(report_axises) <= C, (
+                "Incompatible parameter of `C`={} and "
+                "`report_axises`={}".format(C, report_axises)
+            )
+        self._C = C
+        self._report_axis = list(range(self._C))
+        if report_axises is not None:
+            self._report_axis = report_axises
+        self._diceCallFunction = call_function
+        self._diceLog = []  # type: ignore
+        self._n = 0
+
+    def reset(self):
+        self._diceLog = []  # type: ignore
+        self._n = 0
+
+    def add(self, pred_logit: Tensor, gt: Tensor):
+        """
+        calls toOneHot to convert both inputs to one-hot encoding.
+        :param pred_logit: prediction, can be simplex or logit, with shape b, c, h, w
+        :param gt: ground truth label with shape b, h, w or b, 1, h, w
+        :return:
+        """
+        assert pred_logit.shape.__len__() == 4, f"pred_logit shape:{pred_logit.shape}"
+        if gt.shape.__len__() == 4:
+            gt = gt.squeeze(1)
+        assert gt.shape.__len__() == 3
+        dice_value = self._diceCallFunction(*toOneHot(pred_logit, gt))
+        if dice_value.shape.__len__() == 1:
+            dice_value = dice_value.unsqueeze(0)
+        assert dice_value.shape.__len__() == 2
+        self._diceLog.append(dice_value)
+        self._n += 1
+
+    def value(self):
+        if self._n > 0:
+            log = torch.cat(self._diceLog)
+            means = log.mean(0)
+            stds = log.std(0)
+            report_means = log[:, self._report_axis].mean(1)
+            report_std = report_means.std()
+            report_mean = report_means.mean()
+            return (report_mean, report_std), (means, stds)
+        else:
+            return (np.nan, np.nan), ([np.nan] * self._C, [np.nan] * self._C)
+
+    def detailed_summary(self) -> dict:
+        _, (means, _) = self.value()
+        return {f"DSC{i}": to_float(means[i]) for i in range(len(means))}
+
+    def summary(self) -> dict:
+        _, (means, _) = self.value()
+        return {f"DSC{i}": to_float(means[i]) for i in self._report_axis}
+
+    def get_plot_names(self) -> List[str]:
+        return [f"DSC{i}" for i in self._report_axis]
+
+    def __repr__(self):
+        string = f"C={self._C}, report_axis={self._report_axis}\n"
+        return (
+            string + "\t" + "\t".join([f"{k}:{v}" for k, v in self.summary().items()])
+        )
+
+
+class SliceDiceMeter(_DiceMeter):
+    """
+    used for 2D dice on sliced input.
+    """
+
+    def __init__(self, C=4, report_axises=None) -> None:
+        super().__init__(call_function=dice_coef, report_axises=report_axises, C=C)
+
+
+class BatchDiceMeter(_DiceMeter):
+    """
+    used for 3D dice on volume-structured input.
+    """
+
+    def __init__(self, C=4, report_axises=None) -> None:
+        super().__init__(call_function=dice_batch, report_axises=report_axises, C=C)
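A usage sketch, not part of the patch. Shapes follow the `add()` contract above: logits of shape (B, C, H, W) and class-coded labels of shape (B, H, W); the random tensors only exercise the shapes, so the resulting dice values are meaningless.

```python
import torch

meter = SliceDiceMeter(C=4, report_axises=[1, 2, 3])  # skip background class 0
logits = torch.randn(8, 4, 256, 256)                  # raw network output
labels = torch.randint(0, 4, (8, 256, 256))           # class-coded ground truth
meter.add(logits, labels)
print(meter.summary())  # {'DSC1': ..., 'DSC2': ..., 'DSC3': ...}
```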
+ """ + + def __init__(self, C=4, report_axises=None) -> None: + super().__init__(call_function=dice_batch, report_axises=report_axises, C=C) diff --git a/contrastyou/meters2/individual_meters/general_dice_meter.py b/contrastyou/meters2/individual_meters/general_dice_meter.py new file mode 100644 index 00000000..7d3e807a --- /dev/null +++ b/contrastyou/meters2/individual_meters/general_dice_meter.py @@ -0,0 +1,184 @@ +from typing import Union, List + +import torch +from torch import Tensor + +from deepclustering.meters._metric import _Metric +from deepclustering.meters import BatchDiceMeter +from deepclustering.loss.dice_loss import MetaDice +from deepclustering.utils import ( + simplex, + one_hot, + class2one_hot, + probs2one_hot, + to_float, +) +from collections.abc import Iterable +import numpy as np + + +class UniversalDice(_Metric): + def __init__(self, C=4, report_axises=None) -> None: + super(UniversalDice, self).__init__() + assert report_axises is None or isinstance( + report_axises, (list, tuple) + ), f"`report_axises` should be either None or an iterator, given {type(report_axises)}" + if report_axises is not None: + assert max(report_axises) <= C, ( + "Incompatible parameter of `C`={} and " + "`report_axises`={}".format(C, report_axises) + ) + self._C = C + self._report_axis = list(range(self._C)) + if report_axises is not None: + self._report_axis = report_axises + self.reset() + + def reset(self): + self._intersections = [] + self._unions = [] + self._group_names = [] + self._n = 0 + + def add( + self, pred: Tensor, target: Tensor, group_name: Union[str, List[str]] = None + ): + """ + add pred and target + :param pred: class- or onehot-coded tensor of the same shape as the target + :param target: class- or onehot-coded tensor of the same shape as the pred + :param group_name: List of names, or a string of a name, or None. + indicating 2D slice dice, batch-based dice + :return: + """ + + assert pred.shape == target.shape, ( + f"incompatible shape of `pred` and `target`, given " + f"{pred.shape} and {target.shape}." + ) + + assert not pred.requires_grad and not target.requires_grad + + if group_name is not None: + if not isinstance(group_name, str): + if isinstance(group_name, Iterable): + assert ( + len(group_name) == pred.shape[0] + ) # number of group name should be the same as the pred batch size + assert isinstance(group_name[0], str) + else: + raise TypeError(f"type of `group_name` wrong {type(group_name)}") + + onehot_pred, onehot_target = self._convert2onehot(pred, target) + B, C, *hw = pred.shape + + # current group name: + current_group_name = [ + str(self._n) + f"_{i:03d}" for i in range(B) + ] # make it like slice based dice + if group_name is not None: + current_group_name = group_name + if isinstance(group_name, str): + # this is too make 3D dice. 
+ current_group_name = [group_name] * B + assert isinstance(current_group_name, list) + interaction, union = ( + self._intersaction(onehot_pred, onehot_target), + self._union(onehot_pred, onehot_target), + ) + self._intersections.append(interaction) + self._unions.append(union) + self._group_names.extend(current_group_name) + self._n += 1 + + @property + def log(self): + if self._n > 0: + group_names = self.group_names + interaction_array = torch.cat(self._intersections, dim=0) + union_array = torch.cat(self._unions, dim=0) + group_name_array = np.asarray(self._group_names) + resulting_dice = [] + for unique_name in group_names: + index = group_name_array == unique_name + group_dice = (2 * interaction_array[index].sum(0) + 1e-6) / ( + union_array[index].sum(0) + 1e-6 + ) + resulting_dice.append(group_dice) + resulting_dice = torch.stack(resulting_dice, dim=0) + return resulting_dice + + def value(self, **kwargs): + if self._n == 0: + return ([np.nan] * self._C, [np.nan] * self._C) + + resulting_dice = self.log + return (resulting_dice.mean(0), resulting_dice.std(0)) + + def summary(self) -> dict: + means, stds = self.value() + return {f"DSC{i}": to_float(means[i]) for i in self._report_axis} + + def detailed_summary(self) -> dict: + means, stds = self.value() + return { + **{f"DSC{i}": to_float(means[i]) for i in self._report_axis}, + **{f"DSC_std{i}": to_float(stds[i]) for i in self._report_axis}, + } + + @property + def group_names(self): + return sorted(set(self._group_names)) + + @staticmethod + def _intersaction(pred: Tensor, target: Tensor): + """ + return the interaction, supposing the two inputs are onehot-coded. + :param pred: onehot pred + :param target: onehot target + :return: tensor of intersaction over classes + """ + assert pred.shape == target.shape + assert one_hot(pred) and one_hot(target) + + B, C, *hw = pred.shape + intersect = (pred * target).sum(list(range(2, 2 + len(hw)))) + assert intersect.shape == (B, C) + return intersect + + @staticmethod + def _union(pred: Tensor, target: Tensor): + """ + return the union, supposing the two inputs are onehot-coded. + :param pred: onehot pred + :param target: onehot target + :return: tensor of intersaction over classes + """ + assert pred.shape == target.shape + assert one_hot(pred) and one_hot(target) + + B, C, *hw = pred.shape + union = (pred + target).sum(list(range(2, 2 + len(hw)))) + assert union.shape == (B, C) + return union + + def _convert2onehot(self, pred: Tensor, target: Tensor): + # only two possibility: both onehot or both class-coded. 
+ assert pred.shape == target.shape + # if they are onehot-coded: + if simplex(pred, 1) and one_hot(target): + return probs2one_hot(pred).long(), target.long() + # here the pred and target are labeled long + return ( + class2one_hot(pred, self._C).long(), + class2one_hot(target, self._C).long(), + ) + + def get_plot_names(self) -> List[str]: + return [f"DSC{i}" for i in self._report_axis] + + def __repr__(self): + string = f"C={self._C}, report_axis={self._report_axis}\n" + return ( + string + "\t" + "\t".join([f"{k}:{v}" for k, v in self.summary().items()]) + ) diff --git a/contrastyou/meters2/individual_meters/hausdorff.py b/contrastyou/meters2/individual_meters/hausdorff.py new file mode 100644 index 00000000..1458bfad --- /dev/null +++ b/contrastyou/meters2/individual_meters/hausdorff.py @@ -0,0 +1,142 @@ +__all__ = ["HaussdorffDistance"] + +import warnings +from typing import * + +import numpy as np +import torch +from medpy.metric.binary import hd +from torch import Tensor + +from ._metric import _Metric +from deepclustering2.utils import one_hot + + +class HaussdorffDistance(_Metric): + default_class_num = 4 + + def __init__(self, C=None, report_axises=None) -> None: + super().__init__() + self._haussdorff_log: List[Tensor] = [] + self._C = C + self._report_axises = report_axises + + def reset(self): + self._haussdorff_log = [] + + def add( + self, + pred: Tensor, + label: Tensor, + voxelspacing: Union[float, List[float]] = None, + ) -> None: + """ + Add function to add torch.Tensor for pred and label, which are all one-hot matrices. + :param pred: one-hot prediction matrix + :param label: one-hot label matrix + :param voxelspacing: voxel space for 2D slices + :return: None + """ + assert one_hot(pred), pred + assert one_hot(label), label + assert ( + len(pred.shape) == 4 + ), f"Input tensor is restricted to 4-D tensor, given {pred.shape}." + assert pred.shape == label.shape, ( + f"The shape of pred and label should be the same, " + f"given {pred.shape, label.shape}" + ) + B, C, _, _ = pred.shape # here we only accept 4 dimensional input. 
+ if self._C is None: + self._C = C + else: + assert ( + self._C == C + ), f"Input dimension C: {C} is not consistent with the registered C:{self._C}" + + res = torch.zeros((B, C), dtype=torch.float32, device=pred.device) + n_pred = pred.cpu().numpy() + n_target = label.cpu().numpy() + for b in range(B): + if C == 2: + res[b, :] = numpy_haussdorf( + n_pred[b, 0], n_target[b, 0], voxelspacing=voxelspacing + ) + continue + + for c in range(C): + res[b, c] = numpy_haussdorf( + n_pred[b, c], n_target[b, c], voxelspacing=voxelspacing + ) + + self._haussdorff_log.append(res) + + def value(self, **kwargs): + log: Tensor = self.log + means = log.mean(0) + stds = log.std(0) + report_means = ( + log.mean(1) + if self._report_axises == "all" + else log[:, self._report_axises].mean(1) + ) + report_std = report_means.std() + report_mean = report_means.mean() + return (report_mean, report_std), (means, stds) + + def summary(self) -> dict: + if self._report_axises is None: + self._report_axises = [ + i + for i in range( + self._C if self._C is not None else self.default_class_num + ) + ] + + _, (means, _) = self.value() + return {f"HD{i}": means[i].item() for i in self._report_axises} + + def detailed_summary(self) -> dict: + if self._report_axises is None: + self._report_axises = [ + i + for i in range( + self._C if self._C is not None else self.default_class_num + ) + ] + _, (means, _) = self.value() + return {f"HD{i}": means[i].item() for i in range(len(means))} + + @property + def log(self): + try: + log = torch.cat(self._haussdorff_log) + except RuntimeError: + warnings.warn(f"No log has been found", RuntimeWarning) + log = torch.Tensor( + tuple( + [ + 0 + for _ in range( + self._C if self._C is not None else self.default_class_num + ) + ] + ) + ) + log = log.unsqueeze(0) + assert len(log.shape) == 2 + return log + + +def numpy_haussdorf( + pred: np.ndarray, target: np.ndarray, voxelspacing: Union[float, List[float]] = None +) -> float: + assert len(pred.shape) == 2 + assert pred.shape == target.shape + + # h = max(directed_hausdorff(pred, target)[0], directed_hausdorff(target, pred)[0]) + try: + h = hd(pred, target, voxelspacing) + except RuntimeError: + h = 0 + return h diff --git a/contrastyou/meters2/individual_meters/instance.py b/contrastyou/meters2/individual_meters/instance.py new file mode 100644 index 00000000..067958cb --- /dev/null +++ b/contrastyou/meters2/individual_meters/instance.py @@ -0,0 +1,25 @@ +from ._metric import _Metric + + +# this meter is to show the instance value, instead of print. + + +class InstanceValue(_Metric): + def __init__(self) -> None: + super().__init__() + self.instance_value = None + + def reset(self): + self.instance_value = None + + def add(self, value): + self.instance_value = value + + def value(self, **kwargs): + return self.instance_value + + def summary(self) -> dict: + return {"value": self.instance_value} + + def detailed_summary(self) -> dict: + return {"value": self.instance_value} diff --git a/contrastyou/meters2/individual_meters/iou.py b/contrastyou/meters2/individual_meters/iou.py new file mode 100644 index 00000000..dca7a756 --- /dev/null +++ b/contrastyou/meters2/individual_meters/iou.py @@ -0,0 +1,127 @@ +import numpy as np +import torch + +from .confusionmatrix import ConfusionMatrix +from ._metric import _Metric + + +class IoU(_Metric): + """Computes the intersection over union (IoU) per class and corresponding + mean (mIoU). + + Intersection over union (IoU) is a common evaluation metric for semantic + segmentation. 
The predictions are first accumulated in a confusion matrix + and the IoU is computed from it as follows: + + IoU = true_positive / (true_positive + false_positive + false_negative). + + Keyword arguments: + - num_classes (int): number of classes in the classification problem + - normalized (boolean, optional): Determines whether or not the confusion + matrix is normalized or not. Default: False. + - ignore_index (int or iterable, optional): Index of the classes to ignore + when computing the IoU. Can be an int, or any iterable of ints. + """ + + def __init__(self, num_classes, normalized=False, ignore_index=255): + super().__init__() + self.num_classes = num_classes + self.conf_metric = ConfusionMatrix( + num_classes, ignore_index=ignore_index, normalized=normalized + ) + + if ignore_index is None: + self.ignore_index = None + elif isinstance(ignore_index, int): + self.ignore_index = (ignore_index,) + else: + try: + self.ignore_index = tuple(ignore_index) + except TypeError: + raise ValueError("'ignore_index' must be an int or iterable") + + def reset(self): + self.conf_metric.reset() + + def add(self, predicted, target): + """Adds the predicted and target pair to the IoU metric. + + Keyword arguments: + - predicted (Tensor): Can be a (N, K, H, W) tensor of + predicted scores obtained from the model for N examples and K classes, + or (N, H, W) tensor of integer values between 0 and K-1. + - target (Tensor): Can be a (N, K, H, W) tensor of + target scores for N examples and K classes, or (N, H, W) tensor of + integer values between 0 and K-1. + + """ + # Dimensions check + assert predicted.size(0) == target.size( + 0 + ), "number of targets and predicted outputs do not match" + assert ( + predicted.dim() == 3 or predicted.dim() == 4 + ), "predictions must be of dimension (N, H, W) or (N, K, H, W)" + assert ( + target.dim() == 3 or target.dim() == 4 + ), "targets must be of dimension (N, H, W) or (N, K, H, W)" + + # If the tensor is in categorical format convert it to integer format + if predicted.dim() == 4: + _, predicted = predicted.max(1) + # if target.dim() == 4: + # _, target = target.max(1) + + self.conf_metric.add(predicted.view(-1), target.long().view(-1)) + + def value(self): + """Computes the IoU and mean IoU. + + The mean computation ignores NaN elements of the IoU array. + + Returns: + Tuple: (IoU, mIoU). The first output is the per class IoU, + for K classes it's numpy.ndarray with K elements. The second output, + is the mean IoU. + """ + hist = self.conf_metric.value() + # if self.ignore_index is not None: + # for index in self.ignore_index: + # conf_matrix[:, self.ignore_index] = 0 + # conf_matrix[self.ignore_index, :] = 0 + # true_positive = np.diag(conf_matrix) + # false_positive = np.sum(conf_matrix, 0) - true_positive + # false_negative = np.sum(conf_matrix, 1) - true_positive + # + # # Just in case we get a division by 0, ignore/hide the error + # with np.errstate(divide='ignore', invalid='ignore'): + # iou = true_positive / (true_positive + false_positive + false_negative) + # + # ## this mean_iou doesn't consider whether the class has been in the gt. 
+
+        acc = np.diag(hist).sum() / hist.sum()
+        acc_cls = np.diag(hist) / hist.sum(axis=1)
+        acc_cls = np.nanmean(acc_cls)
+        iu = (np.diag(hist) + 1e-16) / (
+            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 1e-16
+        )
+        valid = hist.sum(axis=1) > 0  # added: a class is valid if its row sum > 0
+        mean_iu = np.nanmean(iu[valid])  # mean IoU over classes present in the gt
+        freq = hist.sum(axis=1) / hist.sum()
+        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
+        # cls_iu = dict(zip(range(n_class), iu))
+        cls_iu = iu
+        return {
+            "Overall_Acc": acc,
+            "Mean_Acc": acc_cls,
+            "FreqW_Acc": fwavacc,
+            "Validated_Mean_IoU": mean_iu,
+            "Mean_IoU": np.nanmean(iu),
+            "Class_IoU": torch.from_numpy(cls_iu).float(),
+        }
+
+    def summary(self) -> dict:
+        return {
+            f"{k}": v
+            for k, v in zip(range(self.num_classes), self.value()["Class_IoU"])
+        }
diff --git a/contrastyou/meters2/individual_meters/kappa.py b/contrastyou/meters2/individual_meters/kappa.py
new file mode 100644
index 00000000..0ee6aba6
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/kappa.py
@@ -0,0 +1,69 @@
+from typing import List
+
+import torch
+from sklearn.metrics import cohen_kappa_score
+from torch import Tensor
+
+from ._metric import _Metric
+
+
+class KappaMetrics(_Metric):
+    """ SKLearnMetrics computes various classification metrics at the end of a batch.
+    Unfortunately, it doesn't work when used with generators..."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.kappa = []
+
+    def add(
+        self, predicts: List[Tensor], target: Tensor, considered_classes: List[int]
+    ):
+        for predict in predicts:
+            assert predict.shape == target.shape
+        predicts = [predict.detach().data.cpu().numpy().ravel() for predict in predicts]
+        target = target.detach().data.cpu().numpy().ravel()
+        mask = [t in considered_classes for t in target]
+        predicts = [predict[mask] for predict in predicts]
+        target = target[mask]
+        kappa_score = [cohen_kappa_score(predict, target) for predict in predicts]
+        self.kappa.append(kappa_score)
+
+    def reset(self):
+        self.kappa = []
+
+    def value(self):
+        return torch.Tensor(self.kappa).mean(0)
+
+    def summary(self):
+        return {f"kappa{i}": self.value()[i].item() for i in range(len(self.value()))}
+
+    def detailed_summary(self):
+        return {f"kappa{i}": self.value()[i].item() for i in range(len(self.value()))}
+
+
+class Kappa2Annotator(KappaMetrics):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def add(
+        self,
+        predict1: Tensor,
+        predict2: Tensor,
+        gt: Tensor,
+        considered_classes=[1, 2, 3],
+    ):
+        assert predict1.shape == predict2.shape
+        gt = gt.data.cpu().numpy().ravel()
+        predict1 = predict1.detach().data.cpu().numpy().ravel()
+        predict2 = predict2.detach().data.cpu().numpy().ravel()
+
+        if considered_classes is not None:
+            mask = [t in considered_classes for t in gt]
+            predict1 = predict1[mask]
+            predict2 = predict2[mask]
+
+        kappa = cohen_kappa_score(y1=predict1, y2=predict2)
+        self.kappa.append(kappa)
+
+    def value(self, **kwargs):
+        return torch.Tensor(self.kappa).mean()
diff --git a/contrastyou/meters2/individual_meters/surface_distance.py b/contrastyou/meters2/individual_meters/surface_distance.py
new file mode 100644
index 00000000..934984d4
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/surface_distance.py
@@ -0,0 +1,29 @@
+import numpy as np
+from deepclustering.utils.typecheckconvert import to_numpy
+from medpy.metric import assd
+from medpy.metric.binary import __surface_distances
+
+__all__ = ["hausdorff_distance", "mod_hausdorff_distance", "average_surface_distance"]
+
+
+def hausdorff_distance(data1, data2,
voxelspacing=None): + data1, data2 = to_numpy(data1), to_numpy(data2) + hd1 = __surface_distances(data1, data2, voxelspacing, connectivity=1) + hd2 = __surface_distances(data2, data1, voxelspacing, connectivity=1) + hd = max(hd1.max(), hd2.max()) + return hd + + +def mod_hausdorff_distance(data1, data2, voxelspacing=None, percentile=95): + data1, data2 = to_numpy(data1), to_numpy(data2) + hd1 = __surface_distances(data1, data2, voxelspacing, connectivity=1) + hd2 = __surface_distances(data2, data1, voxelspacing, connectivity=1) + hd95_1 = np.percentile(hd1, percentile) + hd95_2 = np.percentile(hd2, percentile) + mhd = max(hd95_1, hd95_2) + return mhd + + +def average_surface_distance(data1, data2, voxelspacing=None): + data1, data2 = to_numpy(data1), to_numpy(data2) + return assd(data1, data2, voxelspacing) diff --git a/contrastyou/meters2/individual_meters/surface_meter.py b/contrastyou/meters2/individual_meters/surface_meter.py new file mode 100644 index 00000000..c5d7283f --- /dev/null +++ b/contrastyou/meters2/individual_meters/surface_meter.py @@ -0,0 +1,145 @@ +from typing import List, Union + +import numpy as np +from deepclustering.meters._metric import _Metric +from deepclustering.utils import ( + simplex, + one_hot, + class2one_hot, + probs2one_hot, + to_float, +) +from torch import Tensor + +from .surface_distance import ( + mod_hausdorff_distance, + hausdorff_distance, + average_surface_distance, +) + + +class SurfaceMeter(_Metric): + meter_choices = { + "mod_hausdorff": mod_hausdorff_distance, + "hausdorff": hausdorff_distance, + "average_surface": average_surface_distance, + } + abbr = {"mod_hausdorff": "MHD", "hausdorff": "HD", "average_surface": "ASD"} + + def __init__(self, C=4, report_axises=None, metername: str = "hausdorff") -> None: + super(SurfaceMeter, self).__init__() + assert report_axises is None or isinstance( + report_axises, (list, tuple) + ), f"`report_axises` should be either None or an iterator, given {type(report_axises)}" + if report_axises is not None: + assert max(report_axises) <= C, ( + "Incompatible parameter of `C`={} and " + "`report_axises`={}".format(C, report_axises) + ) + self._C = C + self._report_axis = list(range(self._C)) + if report_axises is not None: + self._report_axis = report_axises + assert metername in self.meter_choices.keys() + self._surface_name = metername + self._abbr = self.abbr[metername] + self._surface_function = self.meter_choices[metername] + self.reset() + + def reset(self): + self._mhd = [] + self._n = 0 + + def add( + self, + pred: Tensor, + target: Tensor, + voxelspacing: Union[List[float], float] = None, + ): + """ + add pred and target + :param pred: class- or onehot-coded tensor of the same shape as the target + :param target: class- or onehot-coded tensor of the same shape as the pred + : res: resolution for different dimension + :return: + """ + assert pred.shape == target.shape, ( + f"incompatible shape of `pred` and `target`, given " + f"{pred.shape} and {target.shape}." 
+ ) + assert not pred.requires_grad and not target.requires_grad + + onehot_pred, onehot_target = self._convert2onehot(pred, target) + B, C, *hw = pred.shape + mhd = self._evalue(onehot_pred, onehot_target, voxelspacing) + assert mhd.shape == (B, len(self._report_axis)) + self._mhd.append(mhd) + self._n += 1 + + def value(self, **kwargs): + if self._n == 0: + return ([np.nan] * self._C, [np.nan] * self._C) + mhd = np.concatenate(self._mhd, axis=0) + return (mhd.mean(0), mhd.std(0)) + + def summary(self) -> dict: + means, stds = self.value() + return { + f"{self._abbr}{i}": to_float(means[num]) + for num, i in enumerate(self._report_axis) + } + + def detailed_summary(self) -> dict: + means, stds = self.value() + return { + **{ + f"{self._abbr}{i}": to_float(means[num]) + for num, i in enumerate(self._report_axis) + }, + **{ + f"{self._abbr}{i}": to_float(stds[num].item()) + for num, i in enumerate(self._report_axis) + }, + } + + def _evalue(self, pred: Tensor, target: Tensor, voxelspacing): + """ + return the B\times C list + :param pred: onehot pred + :param target: onehot target + :return: tensor of size B x C of type np.array + """ + assert pred.shape == target.shape + assert one_hot(pred, axis=1) and one_hot(target, axis=1) + B, C, *hw = pred.shape + result = np.zeros([B, len(self._report_axis)]) + for b, (one_batch_img, one_batch_gt) in enumerate(zip(pred, target)): + for c, (one_slice_img, one_slice_gt) in enumerate( + zip(one_batch_img[self._report_axis], one_batch_gt[self._report_axis]) + ): + mhd = self._surface_function( + one_slice_img, one_slice_gt, voxelspacing=voxelspacing + ) + result[b, c] = mhd + return result + + def _convert2onehot(self, pred: Tensor, target: Tensor): + # only two possibility: both onehot or both class-coded. + assert pred.shape == target.shape + # if they are onehot-coded: + if simplex(pred, 1) and one_hot(target): + return probs2one_hot(pred).long(), target.long() + # here the pred and target are labeled long + return ( + class2one_hot(pred, self._C).long(), + class2one_hot(target, self._C).long(), + ) + + def get_plot_names(self) -> List[str]: + return [f"{self._abbr}{i}" for num, i in enumerate(self._report_axis)] + + def __repr__(self): + string = f"C={self._C}, report_axis={self._report_axis}\n" + return ( + string + "\t" + "\t".join([f"{k}:{v}" for k, v in self.summary().items()]) + ) diff --git a/contrastyou/meters2/individual_meters/torchnet/__init__.py b/contrastyou/meters2/individual_meters/torchnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/__init__.py b/contrastyou/meters2/individual_meters/torchnet/meter/__init__.py new file mode 100644 index 00000000..ee0b9988 --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/__init__.py @@ -0,0 +1,9 @@ +from .averagevaluemeter import AverageValueMeter +from .classerrormeter import ClassErrorMeter +from .confusionmeter import ConfusionMeter +from .timemeter import TimeMeter +from .msemeter import MSEMeter +from .movingaveragevaluemeter import MovingAverageValueMeter +from .aucmeter import AUCMeter +from .apmeter import APMeter +from .mapmeter import mAPMeter diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/apmeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/apmeter.py new file mode 100644 index 00000000..7b98a459 --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/apmeter.py @@ -0,0 +1,145 @@ +import math +from . 
import meter +import torch + + +class APMeter(meter.Meter): + """ + The APMeter measures the average precision per class. + + The APMeter is designed to operate on `NxK` Tensors `output` and + `target`, and optionally a `Nx1` Tensor weight where (1) the `output` + contains model output scores for `N` examples and `K` classes that ought to + be higher when the model is more convinced that the example should be + positively labeled, and smaller when the model believes the example should + be negatively labeled (for instance, the output of a sigmoid function); (2) + the `target` contains only values 0 (for negative examples) and 1 + (for positive examples); and (3) the `weight` ( > 0) represents weight for + each sample. + + """ + + def __init__(self): + super(APMeter, self).__init__() + self.reset() + + def reset(self): + """Resets the meter with empty member variables""" + self.scores = torch.FloatTensor(torch.FloatStorage()) + self.targets = torch.LongTensor(torch.LongStorage()) + self.weights = torch.FloatTensor(torch.FloatStorage()) + + def add(self, output, target, weight=None): + """Add a new observation + + Args: + output (Tensor): NxK tensor that for each of the N examples + indicates the probability of the example belonging to each of + the K classes, according to the model. The probabilities should + sum to one over all classes + target (Tensor): binary NxK tensort that encodes which of the K + classes are associated with the N-th input + (eg: a row [0, 1, 0, 1] indicates that the example is + associated with classes 2 and 4) + weight (optional, Tensor): Nx1 tensor representing the weight for + each example (each weight > 0) + + """ + if not torch.is_tensor(output): + output = torch.from_numpy(output) + if not torch.is_tensor(target): + target = torch.from_numpy(target) + + if weight is not None: + if not torch.is_tensor(weight): + weight = torch.from_numpy(weight) + weight = weight.squeeze() + if output.dim() == 1: + output = output.view(-1, 1) + else: + assert ( + output.dim() == 2 + ), "wrong output size (should be 1D or 2D with one column \ + per class)" + if target.dim() == 1: + target = target.view(-1, 1) + else: + assert ( + target.dim() == 2 + ), "wrong target size (should be 1D or 2D with one column \ + per class)" + if weight is not None: + assert weight.dim() == 1, "Weight dimension should be 1" + assert weight.numel() == target.size( + 0 + ), "Weight dimension 1 should be the same as that of target" + assert torch.min(weight) >= 0, "Weight should be non-negative only" + assert torch.equal(target ** 2, target), "targets should be binary (0 or 1)" + if self.scores.numel() > 0: + assert target.size(1) == self.targets.size( + 1 + ), "dimensions for output should match previously added examples." 
+ + # make sure storage is of sufficient size + if self.scores.storage().size() < self.scores.numel() + output.numel(): + new_size = math.ceil(self.scores.storage().size() * 1.5) + new_weight_size = math.ceil(self.weights.storage().size() * 1.5) + self.scores.storage().resize_(int(new_size + output.numel())) + self.targets.storage().resize_(int(new_size + output.numel())) + if weight is not None: + self.weights.storage().resize_(int(new_weight_size + output.size(0))) + + # store scores and targets + offset = self.scores.size(0) if self.scores.dim() > 0 else 0 + self.scores.resize_(offset + output.size(0), output.size(1)) + self.targets.resize_(offset + target.size(0), target.size(1)) + self.scores.narrow(0, offset, output.size(0)).copy_(output) + self.targets.narrow(0, offset, target.size(0)).copy_(target) + + if weight is not None: + self.weights.resize_(offset + weight.size(0)) + self.weights.narrow(0, offset, weight.size(0)).copy_(weight) + + def value(self): + """Returns the model's average precision for each class + + Return: + ap (FloatTensor): 1xK tensor, with avg precision for each class k + + """ + + if self.scores.numel() == 0: + return 0 + ap = torch.zeros(self.scores.size(1)) + if hasattr(torch, "arange"): + rg = torch.arange(1, self.scores.size(0) + 1).float() + else: + rg = torch.range(1, self.scores.size(0)).float() + if self.weights.numel() > 0: + weight = self.weights.new(self.weights.size()) + weighted_truth = self.weights.new(self.weights.size()) + + # compute average precision for each class + for k in range(self.scores.size(1)): + # sort scores + scores = self.scores[:, k] + targets = self.targets[:, k] + _, sortind = torch.sort(scores, 0, True) + truth = targets[sortind] + if self.weights.numel() > 0: + weight = self.weights[sortind] + weighted_truth = truth.float() * weight + rg = weight.cumsum(0) + + # compute true positive sums + if self.weights.numel() > 0: + tp = weighted_truth.cumsum(0) + else: + tp = truth.float().cumsum(0) + + # compute precision curve + precision = tp.div(rg) + + # compute average precision + ap[k] = precision[truth.bool()].sum() / max(float(truth.sum()), 1) + return ap diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/aucmeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/aucmeter.py new file mode 100644 index 00000000..5e0b0e24 --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/aucmeter.py @@ -0,0 +1,85 @@ +import numbers +from . import meter +import numpy as np +import torch + + +class AUCMeter(meter.Meter): + """ + The AUCMeter measures the area under the receiver-operating characteristic + (ROC) curve for binary classification problems. The area under the curve (AUC) + can be interpreted as the probability that, given a randomly selected positive + example and a randomly selected negative example, the positive example is + assigned a higher score by the classification model than the negative example. + + The AUCMeter is designed to operate on one-dimensional Tensors `output` + and `target`, where (1) the `output` contains model output scores that ought to + be higher when the model is more convinced that the example should be positively + labeled, and smaller when the model believes the example should be negatively + labeled (for instance, the output of a signoid function); and (2) the `target` + contains only values 0 (for negative examples) and 1 (for positive examples). 
+ """ + + def __init__(self): + super(AUCMeter, self).__init__() + self.reset() + + def reset(self): + self.scores = torch.DoubleTensor(torch.DoubleStorage()).numpy() + self.targets = torch.LongTensor(torch.LongStorage()).numpy() + + def add(self, output, target): + if torch.is_tensor(output): + output = output.cpu().squeeze().numpy() + if torch.is_tensor(target): + target = target.cpu().squeeze().numpy() + elif isinstance(target, numbers.Number): + target = np.asarray([target]) + assert np.ndim(output) == 1, "wrong output size (1D expected)" + assert np.ndim(target) == 1, "wrong target size (1D expected)" + assert ( + output.shape[0] == target.shape[0] + ), "number of outputs and targets does not match" + assert np.all( + np.add(np.equal(target, 1), np.equal(target, 0)) + ), "targets should be binary (0, 1)" + + self.scores = np.append(self.scores, output) + self.targets = np.append(self.targets, target) + + def value(self): + # case when number of elements added are 0 + if self.scores.shape[0] == 0: + return (0.5, 0.0, 0.0) + + # sorting the arrays + scores, sortind = torch.sort( + torch.from_numpy(self.scores), dim=0, descending=True + ) + scores = scores.numpy() + sortind = sortind.numpy() + + # creating the roc curve + tpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) + fpr = np.zeros(shape=(scores.size + 1), dtype=np.float64) + + for i in range(1, scores.size + 1): + if self.targets[sortind[i - 1]] == 1: + tpr[i] = tpr[i - 1] + 1 + fpr[i] = fpr[i - 1] + else: + tpr[i] = tpr[i - 1] + fpr[i] = fpr[i - 1] + 1 + + tpr /= self.targets.sum() * 1.0 + fpr /= (self.targets - 1.0).sum() * -1.0 + + # calculating area under curve using trapezoidal rule + n = tpr.shape[0] + h = fpr[1:n] - fpr[0 : n - 1] + sum_h = np.zeros(fpr.shape) + sum_h[0 : n - 1] = h + sum_h[1:n] += h + area = (sum_h * tpr).sum() / 2.0 + + return (area, tpr, fpr) diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/averagevaluemeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/averagevaluemeter.py new file mode 100644 index 00000000..19f8c10e --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/averagevaluemeter.py @@ -0,0 +1,41 @@ +from . import meter +import numpy as np + + +class AverageValueMeter(meter.Meter): + def __init__(self): + super(AverageValueMeter, self).__init__() + self.reset() + self.val = 0 + + def add(self, value, n=1): + self.val = value + self.sum += value * n + if n <= 0: + raise ValueError("Cannot use a non-positive weight for the running stat.") + elif self.n == 0: + self.mean = 0.0 + value # This is to force a copy in torch/numpy + self.std = np.inf + self.mean_old = self.mean + self.m_s = 0.0 + else: + self.mean = self.mean_old + n * (value - self.mean_old) / float(self.n + n) + self.m_s += n * (value - self.mean_old) * (value - self.mean) + self.mean_old = self.mean + self.std = np.sqrt(self.m_s / (self.n + n - 1.0)) + self.var = self.std ** 2 + + self.n += n + + def value(self): + return self.mean, self.std + + def reset(self): + self.n = 0 + self.sum = 0.0 + self.var = 0.0 + self.val = 0.0 + self.mean = np.nan + self.mean_old = 0.0 + self.m_s = 0.0 + self.std = np.nan diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/classerrormeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/classerrormeter.py new file mode 100644 index 00000000..ba05231e --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/classerrormeter.py @@ -0,0 +1,52 @@ +import numpy as np +import torch +import numbers +from . 
import meter + + +class ClassErrorMeter(meter.Meter): + def __init__(self, topk=[1], accuracy=False): + super(ClassErrorMeter, self).__init__() + self.topk = np.sort(topk) + self.accuracy = accuracy + self.reset() + + def reset(self): + self.sum = {v: 0 for v in self.topk} + self.n = 0 + + def add(self, output, target): + if torch.is_tensor(output): + output = output.cpu().squeeze().numpy() + if torch.is_tensor(target): + target = np.atleast_1d(target.cpu().squeeze().numpy()) + elif isinstance(target, numbers.Number): + target = np.asarray([target]) + if np.ndim(output) == 1: + output = output[np.newaxis] + else: + assert np.ndim(output) == 2, "wrong output size (1D or 2D expected)" + assert np.ndim(target) == 1, "target and output do not match" + assert target.shape[0] == output.shape[0], "target and output do not match" + topk = self.topk + maxk = int(topk[-1]) # seems like Python3 wants int and not np.int64 + no = output.shape[0] + + pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy() + correct = pred == target[:, np.newaxis].repeat(pred.shape[1], 1) + + for k in topk: + self.sum[k] += no - correct[:, 0:k].sum() + self.n += no + + def value(self, k=-1): + if k != -1: + assert ( + k in self.sum.keys() + ), "invalid k (this k was not provided at construction time)" + if self.accuracy: + return (1.0 - float(self.sum[k]) / self.n) * 100.0 + else: + return float(self.sum[k]) / self.n * 100.0 + else: + return [self.value(k_) for k_ in self.topk] diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py new file mode 100644 index 00000000..51fbb4e4 --- /dev/null +++ b/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py @@ -0,0 +1,92 @@ +from . import meter +import numpy as np + + +class ConfusionMeter(meter.Meter): + """Maintains a confusion matrix for a given calssification problem. + + The ConfusionMeter constructs a confusion matrix for a multi-class + classification problems. It does not support multi-label, multi-class problems: + for such problems, please use MultiLabelConfusionMeter. + + Args: + k (int): number of classes in the classification problem + normalized (boolean): Determines whether or not the confusion matrix + is normalized or not + + """ + + def __init__(self, k, normalized=False): + super(ConfusionMeter, self).__init__() + self.conf = np.ndarray((k, k), dtype=np.int32) + self.normalized = normalized + self.k = k + self.reset() + + def reset(self): + self.conf.fill(0) + + def add(self, predicted, target): + """Computes the confusion matrix of K x K size where K is no of classes + + Args: + predicted (tensor): Can be an N x K tensor of predicted scores obtained from + the model for N examples and K classes or an N-tensor of + integer values between 0 and K-1. 
diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py
new file mode 100644
index 00000000..51fbb4e4
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/torchnet/meter/confusionmeter.py
@@ -0,0 +1,92 @@
+from . import meter
+import numpy as np
+
+
+class ConfusionMeter(meter.Meter):
+    """Maintains a confusion matrix for a given classification problem.
+
+    The ConfusionMeter constructs a confusion matrix for multi-class
+    classification problems. It does not support multi-label, multi-class problems:
+    for such problems, please use MultiLabelConfusionMeter.
+
+    Args:
+        k (int): number of classes in the classification problem
+        normalized (boolean): determines whether the confusion matrix is normalized
+
+    """
+
+    def __init__(self, k, normalized=False):
+        super(ConfusionMeter, self).__init__()
+        self.conf = np.ndarray((k, k), dtype=np.int32)
+        self.normalized = normalized
+        self.k = k
+        self.reset()
+
+    def reset(self):
+        self.conf.fill(0)
+
+    def add(self, predicted, target):
+        """Computes the confusion matrix of K x K size where K is no of classes
+
+        Args:
+            predicted (tensor): Can be an N x K tensor of predicted scores obtained
+                from the model for N examples and K classes, or an N-tensor of
+                integer values between 0 and K-1.
+            target (tensor): Can be an N-tensor of integer values assumed to be
+                between 0 and K-1, or an N x K tensor, where targets are
+                provided as one-hot vectors
+
+        """
+        predicted = predicted.cpu().numpy()
+        target = target.cpu().numpy()
+
+        assert (
+            predicted.shape[0] == target.shape[0]
+        ), "number of targets and predicted outputs do not match"
+
+        if np.ndim(predicted) != 1:
+            assert (
+                predicted.shape[1] == self.k
+            ), "number of predictions does not match size of confusion matrix"
+            predicted = np.argmax(predicted, 1)
+        else:
+            assert (predicted.max() < self.k) and (
+                predicted.min() >= 0
+            ), "predicted values are not between 0 and k-1"
+
+        onehot_target = np.ndim(target) != 1
+        if onehot_target:
+            assert (
+                target.shape[1] == self.k
+            ), "Onehot target does not match size of confusion matrix"
+            assert (target >= 0).all() and (
+                target <= 1
+            ).all(), "in one-hot encoding, target values should be 0 or 1"
+            assert (target.sum(1) == 1).all(), "multi-label setting is not supported"
+            target = np.argmax(target, 1)
+        else:
+            assert (target.max() < self.k) and (
+                target.min() >= 0
+            ), "target values are not between 0 and k-1"
+
+        # hack for bincounting 2 arrays together
+        x = predicted + self.k * target
+        bincount_2d = np.bincount(x.astype(np.int32), minlength=self.k ** 2)
+        assert bincount_2d.size == self.k ** 2
+        conf = bincount_2d.reshape((self.k, self.k))
+
+        self.conf += conf
+
+    def value(self):
+        """
+        Returns:
+            Confusion matrix of K rows and K columns, where rows correspond
+            to ground-truth targets and columns correspond to predicted
+            targets.
+        """
+        if self.normalized:
+            conf = self.conf.astype(np.float32)
+            return conf / conf.sum(1).clip(min=1e-12)[:, None]
+        else:
+            return self.conf
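+
+
+# Worked example of the bincount trick above (illustrative, k=3): encoding each
+# (target, predicted) pair as `predicted + k * target` gives a unique id in
+# [0, k*k), so one bincount fills the whole matrix in a single pass.
+#
+#   predicted = np.array([0, 2, 1])
+#   target    = np.array([0, 2, 2])
+#   ids = predicted + 3 * target          # -> [0, 8, 7]
+#   np.bincount(ids, minlength=9).reshape(3, 3)
+#   # row 0 (target 0): [1, 0, 0]; row 2 (target 2): [0, 1, 1]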
diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/mapmeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/mapmeter.py
new file mode 100644
index 00000000..45217384
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/torchnet/meter/mapmeter.py
@@ -0,0 +1,30 @@
+from . import meter, APMeter
+
+
+class mAPMeter(meter.Meter):
+    """
+    The mAPMeter measures the mean average precision over all classes.
+
+    The mAPMeter is designed to operate on `NxK` Tensors `output` and
+    `target`, and optionally a `Nx1` Tensor weight where (1) the `output`
+    contains model output scores for `N` examples and `K` classes that ought to
+    be higher when the model is more convinced that the example should be
+    positively labeled, and smaller when the model believes the example should
+    be negatively labeled (for instance, the output of a sigmoid function); (2)
+    the `target` contains only values 0 (for negative examples) and 1
+    (for positive examples); and (3) the `weight` ( > 0) represents the weight
+    for each sample.
+    """
+
+    def __init__(self):
+        super(mAPMeter, self).__init__()
+        self.apmeter = APMeter()
+
+    def reset(self):
+        self.apmeter.reset()
+
+    def add(self, output, target, weight=None):
+        self.apmeter.add(output, target, weight)
+
+    def value(self):
+        return self.apmeter.value().mean()
diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/meter.py b/contrastyou/meters2/individual_meters/torchnet/meter/meter.py
new file mode 100644
index 00000000..e0f5eef4
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/torchnet/meter/meter.py
@@ -0,0 +1,23 @@
+class Meter(object):
+    """Meters provide a way to keep track of important statistics in an online manner.
+
+    This class is abstract, but provides a standard interface for all meters to follow.
+
+    """
+
+    def reset(self):
+        """Resets the meter to default settings."""
+        pass
+
+    def add(self, value):
+        """Log a new value to the meter
+
+        Args:
+            value: Next result to include.
+
+        """
+        pass
+
+    def value(self):
+        """Get the value of the meter in the current state."""
+        pass
diff --git a/contrastyou/meters2/individual_meters/torchnet/meter/movingaveragevaluemeter.py b/contrastyou/meters2/individual_meters/torchnet/meter/movingaveragevaluemeter.py
new file mode 100644
index 00000000..c927c13b
--- /dev/null
+++ b/contrastyou/meters2/individual_meters/torchnet/meter/movingaveragevaluemeter.py
@@ -0,0 +1,31 @@
+import math
+
+import torch
+
+from . import meter
+
+
+class MovingAverageValueMeter(meter.Meter):
+    def __init__(self, windowsize):
+        super(MovingAverageValueMeter, self).__init__()
+        self.windowsize = windowsize
+        self.valuequeue = torch.Tensor(windowsize)
+        self.reset()
+
+    def reset(self):
+        self.sum = 0.0
+        self.n = 0
+        self.var = 0.0
+        self.valuequeue.fill_(0)
+
+    def add(self, value):
+        queueid = self.n % self.windowsize
+        oldvalue = self.valuequeue[queueid]
+        self.sum += value - oldvalue
+        self.var += value * value - oldvalue * oldvalue
+        self.valuequeue[queueid] = value
+        self.n += 1
+
+    def value(self):
+        n = min(self.n, self.windowsize)
+        mean = self.sum / max(1, n)
+        std = math.sqrt(max((self.var - n * mean * mean) / max(1, n - 1), 0))
+        return mean, std
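+
+
+# A small sketch (illustrative values): the ring buffer keeps only the last
+# `windowsize` entries, so older values drop out of both `sum` and `var`.
+#
+#   m = MovingAverageValueMeter(windowsize=2)
+#   m.add(1.0); m.add(3.0)   # window = [1, 3] -> mean 2.0
+#   m.add(5.0)               # 1.0 is evicted, window = [3, 5] -> mean 4.0
+#   mean, std = m.value()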
+ """ + + def __init__(self, unit): + super(TimeMeter, self).__init__() + self.unit = unit + self.reset() + + def add(self, n=1): + self.n += n + + def reset(self): + self.n = 0 + self.time = time.time() + + def value(self): + if self.unit and self.n == 0: + raise ValueError("Trying to divide by zero in TimeMeter") + elif self.unit: + return (time.time() - self.time) / self.n + else: + return time.time() - self.time diff --git a/contrastyou/meters2/meter_interface.py b/contrastyou/meters2/meter_interface.py new file mode 100644 index 00000000..ef793263 --- /dev/null +++ b/contrastyou/meters2/meter_interface.py @@ -0,0 +1,113 @@ +from collections import OrderedDict +from typing import Dict, List, Optional + +from .individual_meters._metric import _Metric + +_Record_Type = Dict[str, float] + + +class MeterInteractMixin: + individual_meters: Dict[str, _Metric] + _ind_meter_dicts: Dict[str, _Metric] + _group_dicts: Dict[str, List[str]] + group: List[str] + meter_names: List[str] + + def tracking_status( + self, group_name=None, detailed_summary=False + ) -> Dict[str, _Record_Type]: + """ + return current training status from "ind_meters" + :param group_name: + :return: + """ + if group_name: + assert group_name in self.group + return { + k: v.detailed_summary() if detailed_summary else v.summary() + for k, v in self.individual_meters.items() + if k in self._group_dicts[group_name] + } + return { + k: v.detailed_summary() if detailed_summary else v.summary() + for k, v in self.individual_meters.items() + } + + def add(self, meter_name, *args, **kwargs): + assert meter_name in self.meter_names + self._ind_meter_dicts[meter_name].add(*args, **kwargs) + + def reset(self) -> None: + """ + reset individual meters + :return: None + """ + for v in self._ind_meter_dicts.values(): + v.reset() + + +class MeterInterface(MeterInteractMixin): + """ + meter interface only concerns about the situation in one epoch, + without considering historical record and save/load state_dict function. + """ + + def __init__(self) -> None: + """ + :param meter_config: a dict of individual meter configurations + """ + self._ind_meter_dicts: Dict[str, _Metric] = OrderedDict() + self._group_dicts: Dict[str, List[str]] = OrderedDict() + + def __getitem__(self, meter_name: str) -> _Metric: + try: + return self._ind_meter_dicts[meter_name] + except KeyError as e: + print(f"meter_interface.meter_names:{self.meter_names}") + raise e + + def register_meter(self, name: str, meter: _Metric, group_name=None) -> None: + assert isinstance(name, str), name + assert isinstance( + meter, _Metric + ), f"{meter.__class__.__name__} should be a subclass of {_Metric.__class__.__name__}, given {meter}." + # add meters + self._ind_meter_dicts[name] = meter + if group_name is not None: + if group_name not in self._group_dicts: + self._group_dicts[group_name] = [] + self._group_dicts[group_name].append(name) + + def delete_meter(self, name: str) -> None: + assert ( + name in self.meter_names + ), f"{name} should be in `meter_names`: {self.meter_names}, given {name}." + del self._ind_meter_dicts[name] + for group, meter_namelist in self._group_dicts.items(): + if name in meter_namelist: + meter_namelist.remove(name) + + def delete_meters(self, name_list: List[str]): + assert isinstance( + name_list, list + ), f" name_list must be a list of str, given {name_list}." 
+    def delete_meter(self, name: str) -> None:
+        assert (
+            name in self.meter_names
+        ), f"{name} should be in `meter_names`: {self.meter_names}."
+        del self._ind_meter_dicts[name]
+        for group, meter_namelist in self._group_dicts.items():
+            if name in meter_namelist:
+                meter_namelist.remove(name)
+
+    def delete_meters(self, name_list: List[str]):
+        assert isinstance(
+            name_list, list
+        ), f"`name_list` must be a list of str, given {name_list}."
+        for name in name_list:
+            self.delete_meter(name)
+
+    @property
+    def meter_names(self) -> List[str]:
+        if hasattr(self, "_ind_meter_dicts"):
+            return list(self._ind_meter_dicts.keys())
+        return []
+
+    @property
+    def meters(self) -> Optional[Dict[str, _Metric]]:
+        if hasattr(self, "_ind_meter_dicts"):
+            return self._ind_meter_dicts
+
+    @property
+    def group(self) -> List[str]:
+        return sorted(self._group_dicts.keys())
+
+    @property
+    def individual_meters(self):
+        return self._ind_meter_dicts
diff --git a/contrastyou/modules/__init__.py b/contrastyou/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrastyou/modules/model.py b/contrastyou/modules/model.py
new file mode 100644
index 00000000..f95b33a3
--- /dev/null
+++ b/contrastyou/modules/model.py
@@ -0,0 +1,295 @@
+import warnings
+from abc import ABCMeta
+from copy import deepcopy
+from typing import Any, Dict, Optional, Union
+
+import torch
+from deepclustering.arch import get_arch
+from deepclustering.utils import simplex
+from torch import Tensor, optim
+from torch import nn
+from torch.nn import functional as F
+from torch.optim import lr_scheduler
+
+from contrastyou import ModelState
+
+__all__ = ["Model"]
+
+CType = Dict[str, Union[float, int, str, Dict[str, Any]]]  # typing for config
+NetType = nn.Module
+OptimType = optim.Optimizer
+ScheType = optim.lr_scheduler._LRScheduler
+
+
+class NormalGradientBackwardStep(object):
+    """Context manager that calls model.zero_grad() when constructed
+    and model.step() on exit.
+    """
+
+    def __init__(self, loss: Tensor, model):
+        self.model = model
+        self.loss = loss
+        self.model.zero_grad()
+
+    def __enter__(self):
+        return self.loss
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.model.step()
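+
+
+# A minimal usage sketch (`criterion`, `pred` and `gt` are hypothetical):
+#
+#   with NormalGradientBackwardStep(criterion(pred, gt), model) as loss:
+#       loss.backward()
+#   # model.step() has been called on exit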
+class Model(metaclass=ABCMeta):
+    """
+    This is the new class for the model interface.
+    """
+
+    def __init__(
+        self,
+        arch: Union[NetType, CType],
+        optimizer: Union[OptimType, CType] = None,
+        scheduler: Union[ScheType, CType] = None,
+    ):
+        """
+        Create the network from either a configuration dict or an instance.
+        :param arch: network configuration dict or network module
+        :param optimizer: optimizer configuration dict or optimizer instance
+        :param scheduler: scheduler configuration dict or scheduler instance
+        :return:
+        """
+        self._set_arch(arch)
+        self._set_optimizer(optimizer)
+        self._set_scheduler(scheduler)
+
+    def _set_arch(self, arch: Union[NetType, CType]) -> None:
+        self._torchnet: nn.Module
+        self._arch_dict: Optional[CType]
+        if isinstance(arch, dict):
+            self._arch_dict = arch
+            arch_dict = deepcopy(arch)
+            arch_name: str = arch_dict.pop("name")  # type:ignore
+            self._torchnet = get_arch(arch_name, arch_dict)
+        else:
+            self._arch_dict = None
+            self._torchnet = arch
+        assert issubclass(type(self._torchnet), nn.Module), type(self._torchnet)
+
+    def _set_optimizer(self, optimizer: Union[OptimType, CType] = None) -> None:
+        self._optimizer: Optional[OptimType]
+        self._optim_dict: Optional[CType]
+        if optimizer is None:
+            self._optim_dict = None
+            self._optimizer = None
+        elif isinstance(optimizer, dict):
+            self._optim_dict = optimizer
+            optim_dict = deepcopy(optimizer)
+            optim_name: str = optim_dict.pop("name")  # type:ignore
+            self._optimizer = getattr(optim, optim_name)(
+                self.parameters(), **optim_dict
+            )
+        else:
+            self._optim_dict = None
+            self._optimizer = optimizer
+        if optimizer is not None:
+            assert issubclass(type(self._optimizer), optim.Optimizer)
+
+    def _set_scheduler(self, scheduler: Union[ScheType, CType] = None) -> None:
+        self._scheduler: Optional[ScheType]
+        self._scheduler_dict: Optional[CType]
+        if scheduler is None:
+            self._scheduler = None
+            self._scheduler_dict = None
+        elif isinstance(scheduler, dict):
+            self._scheduler_dict = scheduler
+            scheduler_dict = deepcopy(scheduler)
+            scheduler_name: str = scheduler_dict.pop("name")  # type:ignore
+            self._scheduler = getattr(lr_scheduler, scheduler_name)(
+                self._optimizer,
+                **{k: v for k, v in scheduler_dict.items() if k != "warmup"},
+            )
+            if "warmup" in scheduler_dict:
+                # wrap the base scheduler with a warmup scheduler
+                from deepclustering.schedulers import GradualWarmupScheduler
+                self._scheduler = GradualWarmupScheduler(  # type: ignore
+                    optimizer=self._optimizer,
+                    after_scheduler=self._scheduler,
+                    **scheduler_dict["warmup"],
+                )
+        else:
+            self._scheduler_dict = None
+            self._scheduler = scheduler
+        if scheduler is not None:
+            assert issubclass(type(self._scheduler), ScheType)
+
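+    # A hedged construction sketch (dict contents are illustrative; the arch
+    # name must be resolvable by deepclustering's `get_arch` registry):
+    #
+    #   model = Model(
+    #       arch={"name": "enet", "num_classes": 4},
+    #       optimizer={"name": "Adam", "lr": 1e-4},
+    #       scheduler={"name": "StepLR", "step_size": 30, "gamma": 0.1},
+    #   )
+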
+    def parameters(self):
+        return self._torchnet.parameters()
+
+    def __call__(self, *args, **kwargs):
+        force_simplex = kwargs.pop("force_simplex", False)
+        assert isinstance(force_simplex, bool), force_simplex
+        torch_logits = self._torchnet(*args, **kwargs)
+        if force_simplex:
+            if not simplex(torch_logits, 1):
+                return F.softmax(torch_logits, 1)
+        return torch_logits
+
+    @property
+    def training(self):
+        return self._torchnet.training
+
+    def step(self):
+        if self._optimizer is not None and hasattr(self._optimizer, "step"):
+            self._optimizer.step()
+
+    def zero_grad(self) -> None:
+        if self._optimizer is not None and hasattr(self._optimizer, "zero_grad"):
+            self._optimizer.zero_grad()
+
+    def schedulerStep(self, *args, **kwargs):
+        if hasattr(self._scheduler, "step"):
+            self._scheduler.step(*args, **kwargs)
+
+    def set_mode(self, mode):
+        assert mode in (ModelState.TRAIN, ModelState.EVAL) or mode in ("train", "eval")
+        if mode in (ModelState.TRAIN, "train"):
+            self.train()
+        elif mode in (ModelState.EVAL, "eval"):
+            self.eval()
+
+    def train(self):
+        self._torchnet.train()
+
+    def eval(self):
+        self._torchnet.eval()
+
+    def to(self, device: torch.device):
+        self._torchnet.to(device)
+        if self._optimizer is not None:
+            # move any optimizer state tensors to the same device
+            for state in self._optimizer.state.values():  # type: ignore
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor):
+                        state[k] = v.to(device)
+
+    def apply(self, *args, **kwargs) -> None:
+        self._torchnet.apply(*args, **kwargs)
+
+    def get_lr(self):
+        if self._scheduler is not None:
+            return self._scheduler.get_lr()
+        return None
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def scheduler(self):
+        return self._scheduler
+
+    @property
+    def optimizer_params(self):
+        return self._optim_dict
+
+    @property
+    def scheduler_params(self):
+        return self._scheduler_dict
+
+    def __repr__(self):
+        model_descript = (
+            f"================== Model =================\n"
+            f"{self._torchnet.__repr__()}\n"
+        )
+        optimizer_descript = (
+            f"================== Optimizer =============\n"
+            f"{self._optimizer.__repr__()}\n"
+            if self._optimizer is not None
+            else ""
+        )
+        scheduler_descript = (
+            f"================== Scheduler =============\n"
+            f"{self._scheduler.__repr__()}\n"
+            if self._scheduler is not None
+            else ""
+        )
+
+        return model_descript + optimizer_descript + scheduler_descript
+
+    def state_dict(self):
+        return {
+            "arch_dict": self._arch_dict,
+            "optim_dict": self._optim_dict,
+            "scheduler_dict": self._scheduler_dict,
+            "net_state_dict": self._torchnet.state_dict(),
+            "optim_state_dict": self._optimizer.state_dict() if self._optimizer is not None else None,
+            "scheduler_state_dict": self._scheduler.state_dict() if self._scheduler is not None else None,
+        }
+
+    def load_state_dict(self, state_dict: dict):
+        self._arch_dict = state_dict["arch_dict"]
+        self._optim_dict = state_dict["optim_dict"]
+        self._scheduler_dict = state_dict["scheduler_dict"]
+        self._torchnet.load_state_dict(state_dict["net_state_dict"])
+        if self._optimizer is not None and hasattr(self._optimizer, "load_state_dict"):
+            self._optimizer.load_state_dict(state_dict["optim_state_dict"])
+        if self._scheduler is not None and hasattr(self._scheduler, "load_state_dict"):
+            self._scheduler.load_state_dict(state_dict["scheduler_state_dict"])
+
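+    # A hedged save/restore sketch (the path is illustrative; restoring this
+    # way assumes the model was built from config dicts, see the assert below):
+    #
+    #   torch.save(model.state_dict(), "last.pth")
+    #   restored = Model.initialize_from_state_dict(torch.load("last.pth"))
+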
+ optim_dict = state_dict["optim_dict"] + if optim_dict is None: + warnings.warn( + f"optim is ignored as it is not initialized with config, use `load_state_dict` instead.", + RuntimeWarning, + ) + scheduler_dict = state_dict["scheduler_dict"] + if scheduler_dict is None: + warnings.warn( + f"scheduler is ignored as it is not initialized with config, use `load_state_dict` instead.", + RuntimeWarning, + ) + model = cls(arch=arch_dict, optimizer=optim_dict, scheduler=scheduler_dict) + model.load_state_dict(state_dict=state_dict) + model.to(torch.device("cpu")) + return model + + +class DPModule(Model): + def __init__(self, arch: Union[NetType, CType], optimizer: Union[OptimType, CType] = None, + scheduler: Union[ScheType, CType] = None): + self._USEDP = False + super().__init__(arch, optimizer, scheduler) + if torch.cuda.is_available(): + if torch.cuda.device_count() > 1: + self._torchnet = torch.nn.DataParallel(self._torchnet) + self._USEDP = True + + def state_dict(self): + return { + "arch_dict": self._arch_dict, + "optim_dict": self._optim_dict, + "scheduler_dict": self._scheduler_dict, + "net_state_dict": self._torchnet.module.state_dict() if self._USEDP else self._torchnet.state_dict(), + "optim_state_dict": self._optimizer.state_dict() if self._optimizer is not None else None, + "scheduler_state_dict": self._scheduler.state_dict() if self._scheduler is not None else None, + } + + def load_state_dict(self, state_dict: dict): + self._arch_dict = state_dict["arch_dict"] + self._optim_dict = state_dict["optim_dict"] + self._scheduler_dict = state_dict["scheduler_dict"] + if self._USEDP: + self._torchnet.module.load_state_dict(state_dict["net_state_dict"]) + else: + self._torchnet.load_state_dict(state_dict["net_state_dict"]) + if hasattr(self._optimizer, "load_state_dict") and self._optimizer is not None: + self._optimizer.load_state_dict(state_dict["optim_state_dict"]) + if hasattr(self._scheduler, "load_state_dict") and self._scheduler is not None: + self._scheduler.load_state_dict(state_dict["scheduler_state_dict"]) diff --git a/contrastyou/storage/__init__.py b/contrastyou/storage/__init__.py new file mode 100644 index 00000000..2f838374 --- /dev/null +++ b/contrastyou/storage/__init__.py @@ -0,0 +1 @@ +from .storage import Storage diff --git a/contrastyou/storage/_historical_container.py b/contrastyou/storage/_historical_container.py new file mode 100644 index 00000000..3e2fef67 --- /dev/null +++ b/contrastyou/storage/_historical_container.py @@ -0,0 +1,79 @@ +import numbers +from abc import ABCMeta +from collections import OrderedDict +from typing import Dict, OrderedDict as OrderedDict_Type, Any + +import pandas as pd + +_Record_Type = Dict[str, float] +_Save_Type = OrderedDict_Type[int, _Record_Type] + +__all__ = ["HistoricalContainer"] + + +class HistoricalContainer(metaclass=ABCMeta): + """ + Aggregate historical information in a ordered dict. + >>> contrainer = HistoricalContainer() + >>> contrainer.add({"loss":1}, epoch=1) + # the dictionray shold not include other dictionary + >>> contrainer.add({"dice1":0.5, "dice2":0.6}, epoch=1) + """ + + def __init__(self) -> None: + self._record_dict: _Save_Type = OrderedDict() + self._current_epoch: int = 0 + + def add(self, input_dict: _Record_Type, epoch=None) -> None: + # only str-num dict can be added. 
+    def add(self, input_dict: _Record_Type, epoch=None) -> None:
+        # only str -> number entries can be added
+        for v in input_dict.values():
+            assert isinstance(v, numbers.Number), v
+        if epoch is not None:  # `is not None`, so an explicit epoch 0 is honored
+            self._current_epoch = epoch
+        self._record_dict[self._current_epoch] = input_dict
+        self._current_epoch += 1
+
+    def reset(self) -> None:
+        self._record_dict: _Save_Type = OrderedDict()
+        self._current_epoch = 0
+
+    @property
+    def record_dict(self) -> _Save_Type:
+        return self._record_dict
+
+    @property
+    def current_epoch(self) -> int:
+        """ return current epoch
+        """
+        return self._current_epoch
+
+    def summary(self) -> pd.DataFrame:
+        # todo: deal with the case where you have absent epochs
+        try:
+            validated_table = pd.DataFrame(self.record_dict).T
+        except ValueError:
+            validated_table = pd.DataFrame(self.record_dict, index=[""]).T
+        # pad missing epochs with empty rows (pd.concat also works on pandas>=2.0,
+        # where DataFrame.append was removed)
+        if len(self.record_dict) < self.current_epoch:
+            missing_table = pd.DataFrame(
+                index=set(range(self.current_epoch)) - set(self.record_dict.keys())
+            )
+            validated_table = pd.concat([validated_table, missing_table]).sort_index()
+        return validated_table
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Returns the state of the class.
+        """
+        return self.__dict__
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Loads the container state.
+
+        Arguments:
+            state_dict (dict): container state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def __repr__(self):
+        return str(pd.DataFrame(self.record_dict).T)
diff --git a/contrastyou/storage/_utils.py b/contrastyou/storage/_utils.py
new file mode 100644
index 00000000..43e23aa7
--- /dev/null
+++ b/contrastyou/storage/_utils.py
@@ -0,0 +1,6 @@
+import pandas as pd
+
+
+def rename_df_columns(dataframe: pd.DataFrame, name: str):
+    dataframe.columns = list(map(lambda x: name + "_" + x, dataframe.columns))
+    return dataframe
diff --git a/contrastyou/storage/storage.py b/contrastyou/storage/storage.py
new file mode 100644
index 00000000..eeb37515
--- /dev/null
+++ b/contrastyou/storage/storage.py
@@ -0,0 +1,83 @@
+import functools
+from abc import ABCMeta
+from collections import defaultdict
+from typing import DefaultDict, Callable, List, Dict
+
+import pandas as pd
+
+from deepclustering2.utils import path2Path
+from ._historical_container import HistoricalContainer
+from ._utils import rename_df_columns
+
+__all__ = ["Storage"]
+
+
+class _IOMixin:
+    _storage: DefaultDict[str, HistoricalContainer]
+    summary: Callable[[], pd.DataFrame]
+
+    def state_dict(self):
+        return self._storage
+
+    def load_state_dict(self, state_dict):
+        self._storage = state_dict
+
+    def to_csv(self, path, name="storage.csv"):
+        path = path2Path(path)
+        path.mkdir(exist_ok=True, parents=True)  # create the folder before asserting
+        assert path.is_dir(), path
+        self.summary().to_csv(path / name)
+
+
+class Storage(_IOMixin, metaclass=ABCMeta):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._storage = defaultdict(HistoricalContainer)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def __getitem__(self, item):
+        if item not in self._storage:
+            raise KeyError(f"{item} not found in {__class__.__name__}")
+        return self._storage[item]
+
+    def put(self, name: str, value: Dict[str, float], epoch=None, prefix="", postfix=""):
+        self._storage[prefix + name + postfix].add(value, epoch)
+
+    def put_all(self, report_dict, epoch=None):
+        for k, v in report_dict.items():
+            self.put(k, v, epoch)
+
+    def get(self, name, epoch=None):
+        assert name in self._storage, name
+        if epoch is None:
+            return self._storage[name]
+        return self._storage[name][epoch]
+
+    def summary(self) -> pd.DataFrame:
+        """
+        Merge the record dicts of all sub-containers into a single DataFrame.
+        :return:
+        """
+        result_dict = {}
+        for k, v in self._storage.items():
+            result_dict[k] = v.record_dict
+        # flatten the nested dict before building the frame
+        from deepclustering.utils import flatten_dict
+        flatten_result = flatten_dict(result_dict)
+        return pd.DataFrame(flatten_result)
+
+    @property
+    def meter_names(self) -> List[str]:
+        # a property cannot take arguments; callers can sort the result themselves
+        return list(self._storage.keys())
+
+    @property
+    def storage(self):
+        return self._storage
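+
+
+# A hedged usage sketch (keys are illustrative); each top-level key gets its
+# own HistoricalContainer, so values must be flat str -> number dicts:
+#
+#   storage = Storage()
+#   storage.put("lr", {"value": 1e-4}, epoch=0)
+#   storage.put_all({"loss": {"value": 0.7}}, epoch=0)
+#   storage.summary()        # one wide DataFrame over all containers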
diff --git a/contrastyou/trainer/_buffer.py b/contrastyou/trainer/_buffer.py
new file mode 100644
index 00000000..21ee759c
--- /dev/null
+++ b/contrastyou/trainer/_buffer.py
@@ -0,0 +1,113 @@
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Union, TypeVar
+
+import numpy as np
+import torch
+from torch import Tensor
+
+N = TypeVar('N', int, float, Tensor, np.ndarray)
+
+
+class _BufferMixin:
+    """
+    The buffer in Trainer is for automatic loading and saving.
+    """
+
+    def __init__(self) -> None:
+        self._buffers = OrderedDict()
+
+    def register_buffer(self, name: str, value: Union[str, N]):
+        r"""Adds a persistent buffer to the mixin."""
+        if '_buffers' not in self.__dict__:
+            raise AttributeError(
+                "cannot assign buffer before _BufferMixin.__init__() call")
+        elif not isinstance(name, str):
+            raise TypeError("buffer name should be a string. "
+                            "Got {}".format(torch.typename(name)))
+        elif '.' in name:
+            raise KeyError("buffer name can't contain \".\"")
+        elif name == '':
+            raise KeyError("buffer name can't be empty string \"\"")
+        elif hasattr(self, name) and name not in self._buffers:
+            raise KeyError("attribute '{}' already exists".format(name))
+        else:
+            self._buffers[name] = value
+
+    def __getattr__(self, name):
+        if '_buffers' in self.__dict__:
+            _buffers = self.__dict__['_buffers']
+            if name in _buffers:
+                return _buffers[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, name))
+
+    def __setattr__(self, name, value):
+        buffers = self.__dict__.get('_buffers')
+        if buffers is not None and name in buffers:
+            buffers[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
+    def __delattr__(self, name):
+        if name in self._buffers:
+            del self._buffers[name]
+        else:
+            object.__delattr__(self, name)
+
+    def buffer_state_dict(self):
+        destination = OrderedDict()
+        for name, buf in self._buffers.items():
+            value = buf
+            if isinstance(buf, Tensor):
+                value = buf.detach()
+            if isinstance(buf, np.ndarray):
+                value = deepcopy(buf)
+            destination[name] = value
+        return destination
+
+    def _load_buffer_from_state_dict(self, state_dict, prefix, strict,
+                                     missing_keys, unexpected_keys, error_msgs):
+
+        local_name_params = self._buffers.items()
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                with torch.no_grad():
+                    try:
+                        if isinstance(input_param, Tensor):
+                            param.copy_(input_param)
+                        else:
+                            self._buffers[name] = input_param
+                    except Exception as ex:
+                        error_msgs.append('While copying the parameter named "{}", '
+                                          'an exception occurred: {}.'
+                                          .format(key, ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+    def load_buffer_state_dict(self, state_dict):
+        r"""Copy buffer values from `state_dict`, collecting any errors."""
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # copy state_dict so _load_from_state_dict can modify it
+        state_dict = state_dict.copy()
+
+        def load(module, prefix=''):
+            module._load_buffer_from_state_dict(
+                state_dict, prefix, True, missing_keys, unexpected_keys, error_msgs)
+
+        load(self)
+
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
+        return missing_keys, unexpected_keys, error_msgs
diff --git a/contrastyou/trainer/_epoch.py b/contrastyou/trainer/_epoch.py
new file mode 100644
index 00000000..d0e197f1
--- /dev/null
+++ b/contrastyou/trainer/_epoch.py
@@ -0,0 +1,54 @@
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Union, Dict
+
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import _BaseDataLoaderIter
+
+from contrastyou import ModelState
+from contrastyou.meters2 import MeterInterface
+from contrastyou.modules.model import Model
+
+
+class _Epoch:
+
+    @abstractmethod
+    def register_meters(self):
+        self._meters: MeterInterface = MeterInterface()
+
+    def write2tensorboard(self):
+        pass
+
+    def _data_preprocessing(self):
+        pass
+
+    def run_epoch(self):
+        pass
+
+
+class TrainEpoch(_Epoch):
+
+    def __init__(self, model: Model, train_loader: Union[DataLoader, _BaseDataLoaderIter],
+                 num_batches: int = 512) -> None:
+        super().__init__()
+        self._model = model
+        self._loader = train_loader
+        self._num_batches = num_batches
+        self._indicator = range(self._num_batches)
+
+    @contextmanager
+    def register_meters(self):
+        super(TrainEpoch, self).register_meters()
+        # TODO: register concrete meters before yielding, e.g.
+        #   self._meters.register_meter("loss", AverageValueMeter())
+        yield self._meters
+        self._meters.reset()
+
+    def run_epoch(self, mode=ModelState.TRAIN) -> Dict[str, float]:
+        self._model.set_mode(mode)
+
+        with self.register_meters() as meters:
+            pass  # the batch loop over `self._indicator` is still to be written
+
+
+class ValEpoch(_Epoch):
+    pass
+ """ + RUN_PATH = str(Path(PROJECT_PATH) / "runs") + ARCHIVE_PATH = str(Path(PROJECT_PATH) / "archives") + checkpoint_identifier = "last.pth" + + def __init__( + self, + model: Model, + train_loader: Union[DataLoader, _BaseDataLoaderIter], + val_loader: DataLoader, + max_epoch: int = 100, + save_dir: str = "base", + checkpoint: str = None, + device="cpu", + config: dict = None, + ) -> None: + super(_Trainer, self).__init__() + self._model = model + self._train_loader = train_loader + self._val_loader = val_loader + + self.register_buffer("_max_epoch", int(max_epoch)) + self.register_buffer("_best_score", -1.0) + self.register_buffer("_start_epoch", 0) # whether 0 or loaded from the checkpoint. + self.register_buffer("_epoch", None) + + self._save_dir: Path = Path(self.RUN_PATH) / str(save_dir) + self._save_dir.mkdir(exist_ok=True, parents=True) + self._checkpoint = checkpoint + self._device = torch.device(device) + + if config: + self._config = deepcopy(config) + self._config.pop("Config", None) + + self._storage = Storage() + + def to(self, device): + self._model.to(device=device) + + def _start_training(self): + for epoch in range(self._start_epoch, self._max_epoch): + if self._model.get_lr() is not None: + self._meter_interface["lr"].add(self._model.get_lr()[0]) + self.train_loop(train_loader=self._train_loader, epoch=epoch) + with torch.no_grad(): + current_score = self.eval_loop(self._val_loader, epoch) + self._model.schedulerStep() + # save meters and checkpoints + self._meter_interface.step() + self.save_checkpoint(self.state_dict(), epoch, current_score) + self._meter_interface.summary().to_csv(self._save_dir / "wholeMeter.csv") + + def start_training(self): + with SummaryWriter(log_dir=self._save_dir) as self.writer: + return self._start_training() + + @abstractmethod + def _train_loop( + self, + train_loader: Union[DataLoader, _BaseDataLoaderIter] = None, + epoch: int = 0, + mode=None, + *args, + **kwargs, + ): + pass + + def train_loop(self, *args, **kwargs): + return self._train_loop(*args, **kwargs) + + @abstractmethod + def _eval_loop( + self, + val_loader: Union[DataLoader, _BaseDataLoaderIter] = None, + epoch: int = 0, + mode=None, + ) -> float: + pass + + def eval_loop(self, *args, **kwargs): + return self._eval_loop(*args, **kwargs) + + def inference(self, identifier="best.pth", *args, **kwargs): + """ + Inference using the checkpoint, to be override by subclasses. + :param args: + :param kwargs: + :return: + """ + if self._checkpoint is None: + self._checkpoint = self._save_dir + assert Path(self._checkpoint).exists(), Path(self._checkpoint) + assert (Path(self._checkpoint).is_dir() and identifier is not None) or ( + Path(self._checkpoint).is_file() and identifier is None + ) + + state_dict = torch.load( + str(Path(self._checkpoint) / identifier) + if identifier is not None + else self._checkpoint, + map_location=torch.device("cpu"), + ) + self.load_checkpoint(state_dict) + self._model.to(self._device) + # to be added + # probably call self._eval() method. + + def state_dict(self) -> Dict[str, Any]: + """ + return trainer's state dict. The dict is built by considering all the submodules having `state_dict` method. 
+ """ + buffer_state_dict = self.buffer_state_dict() + local_modules = {k: v for k, v in self.__dict__.items() if k != "_buffers"} + + local_state_dict = {} + for module_name, module in local_modules.items(): + if hasattr(module, "state_dict") and callable(getattr(module, "state_dict", None)): + local_state_dict[module_name] = module.state_dict() + destination = {**local_state_dict, **{"_buffers": buffer_state_dict}} + return destination + + def load_state_dict(self, state_dict) -> None: + """ + Load state_dict for submodules having "load_state_dict" method. + :param state_dict: + :return: + """ + missing_keys = [] + unexpected_keys = [] + er_msgs = [] + + for module_name, module in self.__dict__.items(): + if module_name == "_buffers": + self.load_buffer_state_dict(state_dict["_buffers"]) + if hasattr(module, "load_state_dict") and callable(getattr(module, "load_state_dict", None)): + try: + module.load_state_dict(state_dict[module_name]) + except KeyError: + missing_keys.append(module_name) + except Exception as ex: + er_msgs.append( + "while copying {} parameters, " + "error {} occurs".format(module_name, ex) + ) + if len(er_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(er_msgs))) + + def save_checkpoint( + self, state_dict, current_epoch, cur_score, save_dir=None, save_name=None + ): + """ + save checkpoint with adding 'epoch' and 'best_score' attributes + :param state_dict: + :param current_epoch: + :param cur_score: + :return: + """ + save_best: bool = True if float(cur_score) > float(self._best_score) else False + if save_best: + self._best_score = float(cur_score) + save_dir = self._save_dir if save_dir is None else path2Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + if save_name is None: + # regular saving + torch.save(state_dict, str(save_dir / "last.pth")) + if save_best: + torch.save(state_dict, str(save_dir / "best.pth")) + else: + # periodic saving + torch.save(state_dict, str(save_dir / save_name)) + + def load_checkpoint_from_path(self, checkpoint_path): + checkpoint_path = path2Path(checkpoint_path) + assert checkpoint_path.exists(), checkpoint_path + if checkpoint_path.is_dir(): + state_dict = torch.load( + str(Path(checkpoint_path) / self.checkpoint_identifier), + map_location=torch.device("cpu"), + ) + else: + assert checkpoint_path.suffix == ".pth", checkpoint_path + state_dict = torch.load( + str(checkpoint_path), map_location=torch.device("cpu"), + ) + self.load_checkpoint(state_dict) + + def clean_up(self, wait_time=3): + """ + Do not touch + :return: + """ + import shutil + import time + + time.sleep(wait_time) # to prevent that the call_draw function is not ended. 
+    def clean_up(self, wait_time=3):
+        """
+        Move the run directory into the archive folder. Do not touch.
+        :return:
+        """
+        import shutil
+        import time
+
+        time.sleep(wait_time)  # give any pending drawing/writer calls time to finish
+        Path(self.ARCHIVE_PATH).mkdir(exist_ok=True, parents=True)
+        sub_dir = self._save_dir.relative_to(Path(self.RUN_PATH))
+        save_dir = Path(self.ARCHIVE_PATH) / str(sub_dir)
+        if Path(save_dir).exists():
+            shutil.rmtree(save_dir, ignore_errors=True)
+        shutil.move(str(self._save_dir), str(save_dir))
+        shutil.rmtree(str(self._save_dir), ignore_errors=True)
diff --git a/contrastyou/trainer/saver.py b/contrastyou/trainer/saver.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrastyou/writer/__init__.py b/contrastyou/writer/__init__.py
new file mode 100644
index 00000000..51fb49cb
--- /dev/null
+++ b/contrastyou/writer/__init__.py
@@ -0,0 +1 @@
+from .tensorboard import SummaryWriter
diff --git a/contrastyou/writer/tensorboard.py b/contrastyou/writer/tensorboard.py
new file mode 100644
index 00000000..0dab83be
--- /dev/null
+++ b/contrastyou/writer/tensorboard.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+from tensorboardX import SummaryWriter as _SummaryWriter
+
+
+def path2Path(path) -> Path:
+    assert isinstance(path, (str, Path)), path
+    return path if isinstance(path, Path) else Path(path)
+
+
+class SummaryWriter(_SummaryWriter):
+
+    def __init__(self, log_dir=None, comment="", **kwargs):
+        log_dir = path2Path(log_dir)
+        assert log_dir.exists() and log_dir.is_dir(), log_dir
+        super().__init__(str(log_dir / "tensorboard"), comment, **kwargs)
+
+    def add_scalar_with_tag(
+        self, tag, tag_scalar_dict, global_step=None, walltime=None
+    ):
+        """
+        Add a one-level dictionary {A: 1, B: 2} under a main tag
+        :param tag: main tag like `train` or `val`
+        :param tag_scalar_dict: dictionary like {A: 1, B: 2}
+        :param global_step: epoch
+        :param walltime: None
+        :return:
+        """
+        assert global_step is not None
+        for k, v in tag_scalar_dict.items():
+            # self.add_scalars(main_tag=tag, tag_scalar_dict={k: v})
+            self.add_scalar(tag=f"{tag}/{k}", scalar_value=v, global_step=global_step)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # close the writer when the context exits rather than deferring to atexit
+        self.close()
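+
+
+# A hedged usage sketch (`runs/demo` is illustrative and must already exist,
+# since __init__ asserts the directory is present):
+#
+#   with SummaryWriter(log_dir="runs/demo") as writer:
+#       writer.add_scalar_with_tag("train", {"loss": 0.7, "dice": 0.5}, global_step=0)
+#   # scalars land under the tags "train/loss" and "train/dice"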
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..e55446d6
--- /dev/null
+++ b/main.py
@@ -0,0 +1,123 @@
+import copy
+import numbers
+from pathlib import Path
+
+import numpy as np
+import torch
+from deepclustering.augment import SequentialWrapper, pil_augment
+from deepclustering.dataset import PatientSampler
+from deepclustering.manager import ConfigManger
+from torch.utils.data import DataLoader
+
+from contrastyou import CONFIG_PATH, DATA_PATH
+from contrastyou.augment import SequentialWrapperTwice
+from contrastyou.dataloader._seg_datset import ContrastBatchSampler
+from contrastyou.dataloader.acdc_dataset import ACDCDataset
+from contrastyou.modules.model import DPModule as Model
+
+config = ConfigManger(Path(CONFIG_PATH) / "config.yaml", integrality_check=False, verbose=False).config
+
+config2 = copy.deepcopy(config)
+config2["Scheduler"]["gamma"] = 0
+
+model = Model(arch=config["Arch"], optimizer=config["Optim"], scheduler=config["Scheduler"])
+model2 = Model(arch=config2["Arch"], optimizer=config2["Optim"], scheduler=config2["Scheduler"])
+
+train_transform = SequentialWrapper(
+    pil_augment.Compose([
+        pil_augment.RandomRotation(40),
+        pil_augment.RandomHorizontalFlip(),
+        pil_augment.RandomCrop(224),
+        pil_augment.ToTensor()
+    ]),
+    pil_augment.Compose([
+        pil_augment.RandomRotation(40),
+        pil_augment.RandomHorizontalFlip(),
+        pil_augment.RandomCrop(224),
+        pil_augment.ToLabel()
+    ]),
+    (False, True)
+)
+val_transform = SequentialWrapper(
+    pil_augment.Compose([
+        pil_augment.CenterCrop(224),
+        pil_augment.ToTensor()
+    ]),
+    pil_augment.Compose([
+        pil_augment.CenterCrop(224),
+        pil_augment.ToLabel()
+    ]),
+    (False, True)
+)
+dataset = ACDCDataset(root_dir=DATA_PATH, mode="train", transforms=SequentialWrapperTwice(train_transform))
+batch_sampler = ContrastBatchSampler(dataset, group_sample_num=8, partition_sample_num=1)
+train_loader = DataLoader(dataset, batch_sampler=batch_sampler)
+val_dataset = ACDCDataset(root_dir=DATA_PATH, mode="val", transforms=val_transform)
+val_batch_sampler = PatientSampler(val_dataset, grp_regex=val_dataset.dataset_pattern, shuffle=False)
+val_loader = DataLoader(val_dataset, batch_sampler=val_batch_sampler)
+
+
+# class Epoch:
+
+
+class Trainer:
+    def __init__(self, model, train_loader, val_loader, max_epoch, device, config) -> None:
+        self._model = model
+        self._train_loader = train_loader
+        self._val_loader = val_loader
+        self._max_epoch = max_epoch
+        self._device = device
+        self._config = config
+        self._begin_epoch = 0
+        self._best_score = -1
+
+    def state_dict(self):
+        """
+        Return the trainer's state dict, built from every submodule that has a
+        `state_dict` method.
+        """
+        state_dictionary = {}
+        for module_name, module in self.__dict__.items():
+            if hasattr(module, "state_dict"):
+                state_dictionary[module_name] = module.state_dict()
+            elif isinstance(module, (numbers.Number, str, torch.Tensor, np.ndarray)):
+                state_dictionary[module_name] = module
+        return state_dictionary
+
+    def state_dict2(self):
+        return self.__dict__
+
+    def load_state_dict(self, state_dict) -> None:
+        """
+        Load state_dict for submodules having a `load_state_dict` method.
+        :param state_dict:
+        :return:
+        """
+        for module_name, module in self.__dict__.items():
+            if hasattr(module, "load_state_dict"):
+                try:
+                    module.load_state_dict(state_dict[module_name])
+                except KeyError as e:
+                    print(f"Loading checkpoint error for {module_name}, {e}.")
+                except RuntimeError as e:
+                    print(f"Interface changed error for {module_name}, {e}")
+            elif isinstance(module, (numbers.Number, str, torch.Tensor, np.ndarray)):
+                self.__dict__[module_name] = state_dict[module_name]
+
+    def load_state_dict2(self, state_dict):
+        self.__dict__.update(state_dict)
+
+
+trainer = Trainer(model, train_loader, val_loader, 100, "cuda", config)
+trainer._best_score = torch.Tensor([10000000])
+trainer._best_epoch = 1232
+trainer._big = 123
+trainer2 = Trainer(model2, train_loader, val_loader, 200, "cpu", config2)
+state_dict = trainer.state_dict()
+
+from torchvision.models import vgg11_bn
+
+model1 = vgg11_bn()
+model2 = vgg11_bn()
+
+model2.load_state_dict(model1.state_dict())
+
+trainer2.load_state_dict(state_dict)
+## with this method, the rebound attributes of the two trainers end up being
+## the same objects (identical ids).
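+#
+# A small check of that aliasing (illustrative): for plain-value attributes,
+# `load_state_dict` rebinds the attribute to the very object stored in
+# `state_dict`, so no copy is made:
+#
+#   assert trainer2._best_score is trainer._best_score   # same Tensor object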
+print(trainer2.__dict__)
\ No newline at end of file
diff --git a/test/test_arch.py b/test/test_arch.py
new file mode 100644
index 00000000..e3e874cf
--- /dev/null
+++ b/test/test_arch.py
@@ -0,0 +1,20 @@
+import torch
+from torch.nn import Module, Linear
+
+
+class A(Module):
+    def __init__(self, a=1) -> None:
+        super().__init__()
+        self.a = Linear(1, 1)
+        self.register_buffer("b", torch.Tensor([0]))
+        self.b = torch.Tensor([0])
+        self.d = a
+
+
+A1 = A(1)
+A2 = A(2)
+A1.a
+b = A1.a
+c = A2.b
+d = A1.d
+d = A1.d
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
new file mode 100644
index 00000000..0ec0b68d
--- /dev/null
+++ b/test/test_dataloader.py
@@ -0,0 +1,32 @@
+from deepclustering.augment import SequentialWrapper, pil_augment
+from torch.utils.data import DataLoader
+
+from contrastyou import DATA_PATH
+from contrastyou.augment import SequentialWrapperTwice
+from contrastyou.dataloader._seg_datset import ContrastBatchSampler
+from contrastyou.dataloader.acdc_dataset import ACDCDataset
+
+root = DATA_PATH
+
+transform = SequentialWrapper(
+    pil_augment.Compose([
+        # pil_augment.RandomCrop(128),
+        pil_augment.RandomRotation(40),
+        pil_augment.ToTensor()
+    ]),
+    pil_augment.Compose([
+        # pil_augment.RandomCrop(128),
+        pil_augment.RandomRotation(40),
+        pil_augment.ToLabel()
+    ]),
+    if_is_target=[False, True]
+)
+twicetransform = SequentialWrapperTwice(transform)
+
+dataset = ACDCDataset(root_dir=DATA_PATH, mode="train", transforms=transform)
+print(dataset.show_group_set(), dataset.show_parition_set())
+print(dataset[3])
+
+batchsampler = ContrastBatchSampler(dataset, 10, 1)
+dataloader = DataLoader(dataset, batch_sampler=batchsampler, num_workers=0)
diff --git a/test/test_storage.py b/test/test_storage.py
new file mode 100644
index 00000000..1a621e33
--- /dev/null
+++ b/test/test_storage.py
@@ -0,0 +1,16 @@
+import random
+
+from contrastyou.storage import Storage
+
+storage = Storage()
+for epoch in range(0, 10, 2):
+    report_dict = {"loss": 1, "dice_meter": {"dice": random.random(), "dice2": random.random()}}
+    if epoch > 5:
+        report_dict.update({"dice_6": random.random()})
+    storage.put_all(report_dict, epoch=epoch)
+print(storage["loss"].summary())
+print(storage["dice_meter"].summary())
+print(storage["dice_6"].summary())
+print(storage.summary())
diff --git a/test/test_trainer.py b/test/test_trainer.py
new file mode 100644
index 00000000..94f0a968
--- /dev/null
+++ b/test/test_trainer.py
@@ -0,0 +1,16 @@
+import torch
+
+from contrastyou.modules.model import Model
+from contrastyou.trainer._trainer import _Trainer
+
+model = Model({"name": "enet"})
+trainer = _Trainer(model, None, None, 100, save_dir="tmp", checkpoint="23", device="cuda", config={"1": 2})
+state_dict = trainer.state_dict()
+print("1")
+model2 = Model({"name": "enet"})
+trainer2 = _Trainer(model2, None, None, 200, save_dir="tmp2", checkpoint="234", device="cpu", config={"1": 4})
+trainer2.load_state_dict(state_dict)
+assert id(trainer._model._torchnet.parameters().__next__()) != id(trainer2._model._torchnet.parameters().__next__())
+assert torch.allclose(trainer._model._torchnet.parameters().__next__(), trainer2._model._torchnet.parameters().__next__())