Skip to content

Commit

Permalink
adding main blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Jul 7, 2020
1 parent 1b898c2 commit 8a5b6ab
Show file tree
Hide file tree
Showing 54 changed files with 3,117 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Created by .ignore support plugin (hsz.mobi)
### Example user template

.data
runs
# IntelliJ project files
.idea
*.iml
Expand Down
13 changes: 13 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Arch:
name: enet
num_classes: 4

Optim:
name: Adam
lr: 0.0001
weight_decay: 0.00005

Scheduler:
name: StepLR
step_size: 30
gamma: 0.1
27 changes: 27 additions & 0 deletions contrastyou/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum
from pathlib import Path

PROJECT_PATH = str(Path(__file__).parents[1])
DATA_PATH = str(Path(PROJECT_PATH) / ".data")
Path(DATA_PATH).mkdir(exist_ok=True, parents=True)
CONFIG_PATH= str(Path(PROJECT_PATH, "config"))


class ModelState(Enum):
TRAIN = "TRAIN"
TEST = "TEST"
EVAL = "EVAL"

@staticmethod
def from_str(mode_str):
""" Init from string
:param mode_str: ['train', 'eval', 'predict']
"""
if mode_str == "train":
return ModelState.TRAIN
elif mode_str == "test":
return ModelState.TEST
elif mode_str == "eval":
return ModelState.EVAL
else:
raise ValueError("Invalid argument mode_str {}".format(mode_str))
2 changes: 2 additions & 0 deletions contrastyou/arch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_arch(*args, **kwargs):
pass
11 changes: 11 additions & 0 deletions contrastyou/augment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class SequentialWrapperTwice:
def __init__(self, transform=None) -> None:
self._transform = transform

def __call__(
self, *imgs, random_seed=None
):
return [
self._transform.__call__(*imgs, random_seed=random_seed),
self._transform.__call__(*imgs, random_seed=random_seed),
]
103 changes: 103 additions & 0 deletions contrastyou/dataloader/_seg_datset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import random
from abc import abstractmethod
from copy import deepcopy as dcp
from typing import Union, List, Set

from torch.utils.data.sampler import Sampler


class ContrastDataset:
"""
each patient has 2 code, the first code is the group_name, which is the patient id
the second code is the partition code, indicating the position of the image slice.
All patients should have the same partition numbers so that they can be aligned.
For ACDC dataset, the ED and ES ventricular volume should be considered further
"""

@abstractmethod
def _get_partition(self, *args) -> Union[str, int]:
"""get the partition of a 2D slice given its index or filename"""
pass

@abstractmethod
def _get_group(self, *args) -> Union[str, int]:
"""get the group name of a 2D slice given its index or filename"""
pass

@abstractmethod
def show_paritions(self) -> List[Union[str, int]]:
"""show all groups of 2D slices in the dataset"""
pass

def show_parition_set(self) -> Set[Union[str, int]]:
"""show all groups of 2D slices in the dataset"""
return set(self.show_paritions())

@abstractmethod
def show_groups(self) -> List[Union[str, int]]:
"""show all groups of 2D slices in the dataset"""
pass

def show_group_set(self) -> Set[Union[str, int]]:
"""show all groups of 2D slices in the dataset"""
return set(self.show_groups())


class ContrastBatchSampler(Sampler):
"""
This class is going to realize the sampling for different patients and from the same patients
`we form batches by first randomly sampling m < M volumes. Then, for each sampled volume, we sample one image per
partition resulting in S images per volume. Next, we apply a pair of random transformations on each sampled image and
add them to the batch
"""

class _SamplerIterator:

def __init__(self, group2index, partion2index, group_sample_num=4, partition_sample_num=1) -> None:
self._group2index, self._partition2index = dcp(group2index), dcp(partion2index)

assert group_sample_num >= 1 and group_sample_num <= len(self._group2index.keys()), group_sample_num
self._group_sample_num = group_sample_num
self._partition_sample_num = partition_sample_num

def __iter__(self):
return self

def __next__(self):
batch_index = []
cur_gsamples = random.sample(self._group2index.keys(), self._group_sample_num)
assert isinstance(cur_gsamples, list), cur_gsamples
# for each gsample, sample at most partition_sample_num slices per partion
for cur_gsample in cur_gsamples:
gavailableslices = self._group2index[cur_gsample]
for savailbleslices in self._partition2index.values():
sampled_slices = random.sample(sorted(set(gavailableslices) & set(savailbleslices)),
self._partition_sample_num)
batch_index.extend(sampled_slices)
return batch_index

def __init__(self, dataset: ContrastDataset, group_sample_num=4, partition_sample_num=1) -> None:
self._dataset = dataset
filenames = dcp(list(dataset._filenames.values())[0])
group2index = {}
partiton2index = {}
for i, filename in enumerate(filenames):
group = dataset._get_group(filename)
if group not in group2index:
group2index[group] = []
group2index[group].append(i)
partition = dataset._get_partition(filename)
if partition not in partiton2index:
partiton2index[partition] = []
partiton2index[partition].append(i)
self._group2index = group2index
self._partition2index = partiton2index
self._group_sample_num = group_sample_num
self._partition_sample_num = partition_sample_num

def __iter__(self):
return self._SamplerIterator(self._group2index, self._partition2index, self._group_sample_num,
self._partition_sample_num)

def __len__(self) -> int:
return len(self._dataset) # type: ignore
35 changes: 35 additions & 0 deletions contrastyou/dataloader/acdc_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import List, Tuple, Union

from deepclustering.augment import SequentialWrapper
from deepclustering.dataset import ACDCDataset as _ACDCDataset
from torch import Tensor

from contrastyou.dataloader._seg_datset import ContrastDataset


class ACDCDataset(ContrastDataset, _ACDCDataset):
def __init__(self, root_dir: str, mode: str, transforms: SequentialWrapper = None,
verbose=True) -> None:
super().__init__(root_dir, mode, ["img", "gt"], transforms, verbose)

def __getitem__(self, index) -> Tuple[List[Tensor], str, str, str]:
data, filename = super().__getitem__(index)
partition = self._get_partition(filename)
group = self._get_group(filename)
return data, filename, partition, group

def _get_group(self, filename) -> Union[str, int]:
return self._get_group_name(filename)

def _get_partition(self, filename) -> Union[str, int]:
return 0

def show_paritions(self) -> List[Union[str, int]]:
return [self._get_partition(f) for f in list(self._filenames.values())[0]]

def show_groups(self) -> List[Union[str, int]]:
return [self._get_group(f) for f in list(self._filenames.values())[0]]




Empty file.
Empty file added contrastyou/helper/__init__.py
Empty file.
Empty file added contrastyou/losses/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions contrastyou/meters2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

from .meter_interface import MeterInterface
from .individual_meters import *

# todo: improve the stability of each meter
30 changes: 30 additions & 0 deletions contrastyou/meters2/individual_meters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from ._metric import _Metric
# individual package for meters based on
"""
>>>class _Metric(metaclass=ABCMeta):
>>> @abstractmethod
>>> def reset(self):
>>> pass
>>>
>>> @abstractmethod
>>> def add(self, *args, **kwargs):
>>> pass
>>>
>>> @abstractmethod
>>> def log(self):
>>> pass
>>>
>>> @abstractmethod
>>> def summary(self) -> dict:
>>> pass
>>>
>>> @abstractmethod
>>> def detailed_summary(self) -> dict:
>>> pass
"""

from .averagemeter import AverageValueMeter
from .confusionmatrix import ConfusionMatrix
from .hausdorff import HaussdorffDistance
from .instance import InstanceValue
from .iou import IoU
28 changes: 28 additions & 0 deletions contrastyou/meters2/individual_meters/_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import abstractmethod, ABCMeta


class _Metric(metaclass=ABCMeta):
"""Base class for all metrics.
record the values within a single epoch
From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
"""

@abstractmethod
def reset(self):
pass

@abstractmethod
def add(self, *args, **kwargs):
pass

# @abstractmethod
# def log(self):
# pass

@abstractmethod
def summary(self) -> dict:
pass

@abstractmethod
def detailed_summary(self) -> dict:
pass
61 changes: 61 additions & 0 deletions contrastyou/meters2/individual_meters/averagemeter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import List

import numpy as np

from ._metric import _Metric


class AverageValueMeter(_Metric):
def __init__(self):
super(AverageValueMeter, self).__init__()
self.reset()
self.val = 0

def add(self, value, n=1):
self.val = value
self.sum += value
self.var += value * value
self.n += n

if self.n == 0:
self.mean, self.std = np.nan, np.nan
elif self.n == 1:
self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy
self.std = np.inf
self.mean_old = self.mean
self.m_s = 0.0
else:
self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
self.m_s += (value - self.mean_old) * (value - self.mean)
self.mean_old = self.mean
self.std = np.sqrt(self.m_s / (self.n - 1.0))

def value(self):
return self.mean, self.std

def reset(self):
self.n = 0
self.sum = 0.0
self.var = 0.0
self.val = 0.0
self.mean = np.nan
self.mean_old = 0.0
self.m_s = 0.0
self.std = 0.0

def summary(self) -> dict:
# this function returns a dict and tends to aggregate the historical results.
return {"mean": self.value()[0]}

def detailed_summary(self) -> dict:
# this function returns a dict and tends to aggregate the historical results.
return {"mean": self.value()[0], "val": self.value()[1]}

def __repr__(self):
def _dict2str(value_dict: dict):
return "\t".join([f"{k}:{v}" for k, v in value_dict.items()])

return f"{self.__class__.__name__}: n={self.n} \n \t {_dict2str(self.detailed_summary())}"

def get_plot_names(self) -> List[str]:
return ["mean"]
59 changes: 59 additions & 0 deletions contrastyou/meters2/individual_meters/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
from numbers import Number
from ._metric import _Metric
import numpy as np


class Cache(_Metric):
"""
Cache is a meter to just store the elements in self.log. For statistic propose of use.
"""

def __init__(self) -> None:
super().__init__()
self.log = []

def reset(self):
self.log = []

def add(self, input):
self.log.append(input)

def value(self, **kwargs):
return len(self.log)

def summary(self) -> dict:
return {"total elements": self.log.__len__()}

def detailed_summary(self) -> dict:
return self.summary()


class AveragewithStd(Cache):
"""
this Meter is going to return the mean and std_lower, std_high for a list of scalar values
"""

def add(self, input):
assert (
isinstance(input, Number)
or (isinstance(input, torch.Tensor) and input.shape.__len__() <= 1)
or (isinstance(input, np.ndarray) and input.shape.__len__() <= 1)
)
if torch.is_tensor(input):
input = input.cpu().item()

super().add(input)

def value(self, **kwargs):
return torch.Tensor(self.log).mean().item()

def summary(self) -> dict:
torch_log = torch.Tensor(self.log)
mean = torch_log.mean()
std = torch_log.std()
return {
"mean": mean.item(),
"lstd": mean.item() - std.item(),
"hstd": mean.item() + std.item(),
}
Loading

0 comments on commit 8a5b6ab

Please sign in to comment.