Skip to content

Commit

Permalink
[Feature]add CLAHE transform (#229)
Browse files Browse the repository at this point in the history
* add CLAHE transform

* fix syntax error

* fix syntax error

* restore

* add a test

* modify cv2 to mmcv

* add docstring

* modify

* restore

* fix mmcv.clahe error

* change mmcv version to 1.3.0

* fix bugs

* add all data transformers to __init__

* fix __init__

* fix test_transform
  • Loading branch information
yamengxi authored Dec 2, 2020
1 parent 4dc809a commit 0066ce8
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 6 deletions.
8 changes: 4 additions & 4 deletions mmseg/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip,
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)

__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'Rerange', 'RGB2Gray'
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
]
48 changes: 46 additions & 2 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random

from ..builder import PIPELINES
Expand Down Expand Up @@ -415,7 +415,6 @@ def __call__(self, results):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reranged results.
"""
Expand All @@ -439,6 +438,51 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class CLAHE(object):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
"""

def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
assert isinstance(clip_limit, (float, int))
self.clip_limit = clip_limit
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size

def __call__(self, results):
"""Call function to Use CLAHE method process images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Processed results.
"""

for i in range(results['img'].shape[2]):
results['img'][:, :, i] = mmcv.clahe(
np.array(results['img'][:, :, i], dtype=np.uint8),
self.clip_limit, self.tile_grid_size)

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(clip_limit={self.clip_limit}, '\
f'tile_grid_size={self.tile_grid_size})'
return repr_str


@PIPELINES.register_module()
class RandomCrop(object):
"""Random crop the image & seg.
Expand Down
40 changes: 40 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,46 @@ def test_rerange():
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'


def test_CLAHE():
# test assertion if clip_limit is None
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', clip_limit=None)
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
build_from_cfg(transform, PIPELINES)

# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
build_from_cfg(transform, PIPELINES)

transform = dict(type='CLAHE', clip_limit=2)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)

converted_img = np.empty(original_img.shape)
for i in range(original_img.shape[2]):
converted_img[:, :, i] = mmcv.clahe(
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))

assert np.allclose(results['img'], converted_img)
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'


def test_seg_rescale():
results = dict()
seg = np.array(
Expand Down

0 comments on commit 0066ce8

Please sign in to comment.