-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
executable file
·100 lines (78 loc) · 3.42 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import matplotlib.pyplot as plt
from skimage.util import img_as_float
from skimage import transform
from skimage import io
from segmentation import *
import os
def visualize_mean_color_image(img, segments):
    """ Display an image with every segment painted in its mean color.

    Args:
        img - The image to visualize. It is converted with img_as_float
            before the per-segment means are computed.
        segments - An array of integer segment labels used to mask img.
            Labels 0 .. max(segments) are visited; label values that do not
            occur in the array are skipped.

    The result is shown in a matplotlib figure; nothing is returned.
    """
    img = img_as_float(img)
    k = np.max(segments) + 1
    mean_color_img = np.zeros(img.shape)

    for i in range(k):
        # Build the mask once and reuse it for both the read and the write.
        mask = segments == i
        if not np.any(mask):
            # Label ids need not be contiguous. Averaging an empty selection
            # would only raise a "Mean of empty slice" warning and produce
            # NaNs that are never written anywhere.
            continue
        mean_color_img[mask] = np.mean(img[mask], axis=0)

    plt.imshow(mean_color_img)
    plt.axis('off')
    plt.show()
def compute_segmentation(img, k,
                         clustering_fn=kmeans_fast,
                         feature_fn=color_position_features,
                         scale=0):
    """ Compute a segmentation for an image.

    First a feature vector is extracted from each pixel of an image. Next a
    clustering algorithm is applied to the set of all feature vectors. Two
    pixels are assigned to the same segment if and only if their feature
    vectors are assigned to the same cluster.

    Args:
        img - An array of shape (H, W, C) to segment. Only the first two
            axes are read here, so an (H, W) array is accepted as long as
            feature_fn can handle it.
        k - The number of segments into which the image should be split.
        clustering_fn - The method to use for clustering. The function should
            take an array of N points and an integer value k as input and
            output an array of N assignments.
        feature_fn - A function used to extract features from the image.
        scale - (OPTIONAL) parameter giving the scale to which the image
            should be resized before clustering, in the range
            0 <= scale <= 1. The default of 0 disables resizing. Setting
            this argument to a smaller positive value will increase the
            speed of the clustering algorithm but will cause computed
            segments to be blockier. This setting is usually not necessary
            for kmeans clustering, but when using HAC clustering this
            parameter will probably need to be set to a value less than 1.

    Returns:
        segments - An array of shape (H, W) holding the segment assigned to
            each pixel of the original image.
    """

    assert scale <= 1 and scale >= 0, \
        'Scale should be in the range between 0 and 1'

    # Only the spatial extent is needed; the channel count is not used.
    H, W = img.shape[:2]

    if scale > 0:
        # Scale down the image for faster computation. An explicit
        # (rows, cols) target is given so that only the spatial axes shrink:
        # a bare rescale() of an (H, W, C) array is treated as a 3-D volume
        # by newer scikit-image releases and would shrink the channel axis
        # as well.
        small_shape = (max(1, int(round(H * scale))),
                       max(1, int(round(W * scale))))
        img = transform.resize(img, small_shape)

    features = feature_fn(img)
    assignments = clustering_fn(features, k)
    segments = assignments.reshape(img.shape[:2])

    if scale > 0:
        # Resize segmentation back to the image's original size.
        # Nearest-neighbor (order=0) is used because labels are categorical:
        # interpolating between, say, labels 0 and 2 would invent a
        # spurious label 1 along the border between the two segments.
        segments = transform.resize(segments, (H, W), order=0,
                                    preserve_range=True)

        # resize returns floating point values; convert back to integers.
        segments = np.rint(segments).astype(int)

    return segments
def load_dataset(data_dir):
    """
    Load every '.jpg' image found in 'data_dir/imgs' along with its ground
    truth segmentation mask from 'data_dir/gt'. The mask belonging to
    'imgs/aaa.jpg' is expected at 'gt/aaa.png'.

    Args:
        data_dir - Directory containing the 'imgs' and 'gt' sub-directories.

    Returns:
        A pair of lists (images, masks) in sorted file-name order. Each mask
        is converted to integers, with 1 wherever the stored mask is nonzero
        and 0 elsewhere.
    """
    img_dir = os.path.join(data_dir, 'imgs')
    images, masks = [], []

    for img_name in sorted(os.listdir(img_dir)):
        # Anything that is not a jpg image is ignored.
        if not img_name.endswith('.jpg'):
            continue

        images.append(io.imread(os.path.join(img_dir, img_name)))

        # Same base name, '.png' extension, under the 'gt' directory.
        mask_name = img_name[:-4] + '.png'
        raw_mask = io.imread(os.path.join(data_dir, 'gt', mask_name))

        # Collapse the mask to binary values (0s and 1s).
        masks.append((raw_mask != 0).astype(int))

    return images, masks