-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
jizong
committed
Jul 7, 2020
1 parent
1b898c2
commit 8a5b6ab
Showing
54 changed files
with
3,117 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Arch: | ||
name: enet | ||
num_classes: 4 | ||
|
||
Optim: | ||
name: Adam | ||
lr: 0.0001 | ||
weight_decay: 0.00005 | ||
|
||
Scheduler: | ||
name: StepLR | ||
step_size: 30 | ||
gamma: 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from enum import Enum | ||
from pathlib import Path | ||
|
||
PROJECT_PATH = str(Path(__file__).parents[1]) | ||
DATA_PATH = str(Path(PROJECT_PATH) / ".data") | ||
Path(DATA_PATH).mkdir(exist_ok=True, parents=True) | ||
CONFIG_PATH= str(Path(PROJECT_PATH, "config")) | ||
|
||
|
||
class ModelState(Enum): | ||
TRAIN = "TRAIN" | ||
TEST = "TEST" | ||
EVAL = "EVAL" | ||
|
||
@staticmethod | ||
def from_str(mode_str): | ||
""" Init from string | ||
:param mode_str: ['train', 'eval', 'predict'] | ||
""" | ||
if mode_str == "train": | ||
return ModelState.TRAIN | ||
elif mode_str == "test": | ||
return ModelState.TEST | ||
elif mode_str == "eval": | ||
return ModelState.EVAL | ||
else: | ||
raise ValueError("Invalid argument mode_str {}".format(mode_str)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def get_arch(*args, **kwargs): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
class SequentialWrapperTwice: | ||
def __init__(self, transform=None) -> None: | ||
self._transform = transform | ||
|
||
def __call__( | ||
self, *imgs, random_seed=None | ||
): | ||
return [ | ||
self._transform.__call__(*imgs, random_seed=random_seed), | ||
self._transform.__call__(*imgs, random_seed=random_seed), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import random | ||
from abc import abstractmethod | ||
from copy import deepcopy as dcp | ||
from typing import Union, List, Set | ||
|
||
from torch.utils.data.sampler import Sampler | ||
|
||
|
||
class ContrastDataset: | ||
""" | ||
each patient has 2 code, the first code is the group_name, which is the patient id | ||
the second code is the partition code, indicating the position of the image slice. | ||
All patients should have the same partition numbers so that they can be aligned. | ||
For ACDC dataset, the ED and ES ventricular volume should be considered further | ||
""" | ||
|
||
@abstractmethod | ||
def _get_partition(self, *args) -> Union[str, int]: | ||
"""get the partition of a 2D slice given its index or filename""" | ||
pass | ||
|
||
@abstractmethod | ||
def _get_group(self, *args) -> Union[str, int]: | ||
"""get the group name of a 2D slice given its index or filename""" | ||
pass | ||
|
||
@abstractmethod | ||
def show_paritions(self) -> List[Union[str, int]]: | ||
"""show all groups of 2D slices in the dataset""" | ||
pass | ||
|
||
def show_parition_set(self) -> Set[Union[str, int]]: | ||
"""show all groups of 2D slices in the dataset""" | ||
return set(self.show_paritions()) | ||
|
||
@abstractmethod | ||
def show_groups(self) -> List[Union[str, int]]: | ||
"""show all groups of 2D slices in the dataset""" | ||
pass | ||
|
||
def show_group_set(self) -> Set[Union[str, int]]: | ||
"""show all groups of 2D slices in the dataset""" | ||
return set(self.show_groups()) | ||
|
||
|
||
class ContrastBatchSampler(Sampler): | ||
""" | ||
This class is going to realize the sampling for different patients and from the same patients | ||
`we form batches by first randomly sampling m < M volumes. Then, for each sampled volume, we sample one image per | ||
partition resulting in S images per volume. Next, we apply a pair of random transformations on each sampled image and | ||
add them to the batch | ||
""" | ||
|
||
class _SamplerIterator: | ||
|
||
def __init__(self, group2index, partion2index, group_sample_num=4, partition_sample_num=1) -> None: | ||
self._group2index, self._partition2index = dcp(group2index), dcp(partion2index) | ||
|
||
assert group_sample_num >= 1 and group_sample_num <= len(self._group2index.keys()), group_sample_num | ||
self._group_sample_num = group_sample_num | ||
self._partition_sample_num = partition_sample_num | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def __next__(self): | ||
batch_index = [] | ||
cur_gsamples = random.sample(self._group2index.keys(), self._group_sample_num) | ||
assert isinstance(cur_gsamples, list), cur_gsamples | ||
# for each gsample, sample at most partition_sample_num slices per partion | ||
for cur_gsample in cur_gsamples: | ||
gavailableslices = self._group2index[cur_gsample] | ||
for savailbleslices in self._partition2index.values(): | ||
sampled_slices = random.sample(sorted(set(gavailableslices) & set(savailbleslices)), | ||
self._partition_sample_num) | ||
batch_index.extend(sampled_slices) | ||
return batch_index | ||
|
||
def __init__(self, dataset: ContrastDataset, group_sample_num=4, partition_sample_num=1) -> None: | ||
self._dataset = dataset | ||
filenames = dcp(list(dataset._filenames.values())[0]) | ||
group2index = {} | ||
partiton2index = {} | ||
for i, filename in enumerate(filenames): | ||
group = dataset._get_group(filename) | ||
if group not in group2index: | ||
group2index[group] = [] | ||
group2index[group].append(i) | ||
partition = dataset._get_partition(filename) | ||
if partition not in partiton2index: | ||
partiton2index[partition] = [] | ||
partiton2index[partition].append(i) | ||
self._group2index = group2index | ||
self._partition2index = partiton2index | ||
self._group_sample_num = group_sample_num | ||
self._partition_sample_num = partition_sample_num | ||
|
||
def __iter__(self): | ||
return self._SamplerIterator(self._group2index, self._partition2index, self._group_sample_num, | ||
self._partition_sample_num) | ||
|
||
def __len__(self) -> int: | ||
return len(self._dataset) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import List, Tuple, Union | ||
|
||
from deepclustering.augment import SequentialWrapper | ||
from deepclustering.dataset import ACDCDataset as _ACDCDataset | ||
from torch import Tensor | ||
|
||
from contrastyou.dataloader._seg_datset import ContrastDataset | ||
|
||
|
||
class ACDCDataset(ContrastDataset, _ACDCDataset): | ||
def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = None, | ||
verbose=True) -> None: | ||
super().__init__(root_dir, mode, ["img", "gt"], transforms, verbose) | ||
|
||
def __getitem__(self, index) -> Tuple[List[Tensor], str, str, str]: | ||
data, filename = super().__getitem__(index) | ||
partition = self._get_partition(filename) | ||
group = self._get_group(filename) | ||
return data, filename, partition, group | ||
|
||
def _get_group(self, filename) -> Union[str, int]: | ||
return self._get_group_name(filename) | ||
|
||
def _get_partition(self, filename) -> Union[str, int]: | ||
return 0 | ||
|
||
def show_paritions(self) -> List[Union[str, int]]: | ||
return [self._get_partition(f) for f in list(self._filenames.values())[0]] | ||
|
||
def show_groups(self) -> List[Union[str, int]]: | ||
return [self._get_group(f) for f in list(self._filenames.values())[0]] | ||
|
||
|
||
|
||
|
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
|
||
from .meter_interface import MeterInterface | ||
from .individual_meters import * | ||
|
||
# todo: improve the stability of each meter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from ._metric import _Metric | ||
# individual package for meters based on | ||
""" | ||
>>>class _Metric(metaclass=ABCMeta): | ||
>>> @abstractmethod | ||
>>> def reset(self): | ||
>>> pass | ||
>>> | ||
>>> @abstractmethod | ||
>>> def add(self, *args, **kwargs): | ||
>>> pass | ||
>>> | ||
>>> @abstractmethod | ||
>>> def log(self): | ||
>>> pass | ||
>>> | ||
>>> @abstractmethod | ||
>>> def summary(self) -> dict: | ||
>>> pass | ||
>>> | ||
>>> @abstractmethod | ||
>>> def detailed_summary(self) -> dict: | ||
>>> pass | ||
""" | ||
|
||
from .averagemeter import AverageValueMeter | ||
from .confusionmatrix import ConfusionMatrix | ||
from .hausdorff import HaussdorffDistance | ||
from .instance import InstanceValue | ||
from .iou import IoU |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from abc import abstractmethod, ABCMeta | ||
|
||
|
||
class _Metric(metaclass=ABCMeta): | ||
"""Base class for all metrics. | ||
record the values within a single epoch | ||
From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py | ||
""" | ||
|
||
@abstractmethod | ||
def reset(self): | ||
pass | ||
|
||
@abstractmethod | ||
def add(self, *args, **kwargs): | ||
pass | ||
|
||
# @abstractmethod | ||
# def log(self): | ||
# pass | ||
|
||
@abstractmethod | ||
def summary(self) -> dict: | ||
pass | ||
|
||
@abstractmethod | ||
def detailed_summary(self) -> dict: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
from ._metric import _Metric | ||
|
||
|
||
class AverageValueMeter(_Metric): | ||
def __init__(self): | ||
super(AverageValueMeter, self).__init__() | ||
self.reset() | ||
self.val = 0 | ||
|
||
def add(self, value, n=1): | ||
self.val = value | ||
self.sum += value | ||
self.var += value * value | ||
self.n += n | ||
|
||
if self.n == 0: | ||
self.mean, self.std = np.nan, np.nan | ||
elif self.n == 1: | ||
self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy | ||
self.std = np.inf | ||
self.mean_old = self.mean | ||
self.m_s = 0.0 | ||
else: | ||
self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) | ||
self.m_s += (value - self.mean_old) * (value - self.mean) | ||
self.mean_old = self.mean | ||
self.std = np.sqrt(self.m_s / (self.n - 1.0)) | ||
|
||
def value(self): | ||
return self.mean, self.std | ||
|
||
def reset(self): | ||
self.n = 0 | ||
self.sum = 0.0 | ||
self.var = 0.0 | ||
self.val = 0.0 | ||
self.mean = np.nan | ||
self.mean_old = 0.0 | ||
self.m_s = 0.0 | ||
self.std = 0.0 | ||
|
||
def summary(self) -> dict: | ||
# this function returns a dict and tends to aggregate the historical results. | ||
return {"mean": self.value()[0]} | ||
|
||
def detailed_summary(self) -> dict: | ||
# this function returns a dict and tends to aggregate the historical results. | ||
return {"mean": self.value()[0], "val": self.value()[1]} | ||
|
||
def __repr__(self): | ||
def _dict2str(value_dict: dict): | ||
return "\t".join([f"{k}:{v}" for k, v in value_dict.items()]) | ||
|
||
return f"{self.__class__.__name__}: n={self.n} \n \t {_dict2str(self.detailed_summary())}" | ||
|
||
def get_plot_names(self) -> List[str]: | ||
return ["mean"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
from numbers import Number | ||
from ._metric import _Metric | ||
import numpy as np | ||
|
||
|
||
class Cache(_Metric): | ||
""" | ||
Cache is a meter to just store the elements in self.log. For statistic propose of use. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.log = [] | ||
|
||
def reset(self): | ||
self.log = [] | ||
|
||
def add(self, input): | ||
self.log.append(input) | ||
|
||
def value(self, **kwargs): | ||
return len(self.log) | ||
|
||
def summary(self) -> dict: | ||
return {"total elements": self.log.__len__()} | ||
|
||
def detailed_summary(self) -> dict: | ||
return self.summary() | ||
|
||
|
||
class AveragewithStd(Cache): | ||
""" | ||
this Meter is going to return the mean and std_lower, std_high for a list of scalar values | ||
""" | ||
|
||
def add(self, input): | ||
assert ( | ||
isinstance(input, Number) | ||
or (isinstance(input, torch.Tensor) and input.shape.__len__() <= 1) | ||
or (isinstance(input, np.ndarray) and input.shape.__len__() <= 1) | ||
) | ||
if torch.is_tensor(input): | ||
input = input.cpu().item() | ||
|
||
super().add(input) | ||
|
||
def value(self, **kwargs): | ||
return torch.Tensor(self.log).mean().item() | ||
|
||
def summary(self) -> dict: | ||
torch_log = torch.Tensor(self.log) | ||
mean = torch_log.mean() | ||
std = torch_log.std() | ||
return { | ||
"mean": mean.item(), | ||
"lstd": mean.item() - std.item(), | ||
"hstd": mean.item() + std.item(), | ||
} |
Oops, something went wrong.