Skip to content

Commit

Permalink
Merge pull request #19 from boostcampaitech2/EDA
Browse files Browse the repository at this point in the history
EDA branch pr 드립니다.
확인했습니다.
  • Loading branch information
ppskj178 authored Oct 21, 2021
2 parents 3d8c878 + 82777ef commit aff5b4f
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 8 deletions.
161 changes: 161 additions & 0 deletions EDA/EDA.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "aa608dff",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"# 시각화를 위한 라이브러리\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"from matplotlib.patches import Patch\n",
"import webcolors\n",
"\n",
"import sys\n",
"sys.path.append('/opt/ml/segmentation/semantic-segmentation-level2-cv-06/')\n",
"from utils import label_accuracy_score, add_hist, class_colormap, label_to_color_image\n",
"from dataset import CustomDataLoader, collate_fn,\\\n",
" train_transform, val_transform, test_transform,\\\n",
" category_names"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f74eb9b",
"metadata": {},
"outputs": [],
"source": [
"plt.rcParams['axes.grid'] = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc293266",
"metadata": {},
"outputs": [],
"source": [
"root_path = '../input/data'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79fe6a23",
"metadata": {},
"outputs": [],
"source": [
"# train.json / validation.json / test.json 디렉토리 설정\n",
"train_path = root_path + '/train.json'\n",
"val_path = root_path + '/val.json'\n",
"test_path = root_path + '/test.json'\n",
"batch_size = 32\n",
"\n",
"train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)\n",
"#val_dataset = CustomDataLoader(data_dir=val_path, mode='val', transform=val_transform)\n",
"#test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)\n",
"\n",
"# DataLoader\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" num_workers=4,\n",
" collate_fn=collate_fn)\n",
"\n",
"# val_loader = torch.utils.data.DataLoader(dataset=val_dataset, \n",
"# batch_size=batch_size,\n",
"# shuffle=False,\n",
"# num_workers=4,\n",
"# collate_fn=collate_fn)\n",
"\n",
"# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
"# batch_size=batch_size,\n",
"# num_workers=4,\n",
"# collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e075742b",
"metadata": {},
"outputs": [],
"source": [
"for i, (imgs, masks, image_infos) in enumerate(train_loader):\n",
" #print(len(imgs))\n",
" if i == 0:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a8a5316",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"category_and_rgb = [[category, (r,g,b)] for idx, (category, r, g, b) in enumerate(class_colormap.values)]\n",
"legend_elements = [Patch(facecolor=webcolors.rgb_to_hex(rgb), \n",
" edgecolor=webcolors.rgb_to_hex(rgb), \n",
" label=category) for category, rgb in category_and_rgb]\n",
"\n",
"fig, ax = plt.subplots(nrows=16, ncols=4, figsize=(20, 100))\n",
"# train_loader의 output 결과(image 및 mask) 확인\n",
"for i in range(len(imgs)):\n",
" temp_image_infos = image_infos[i]\n",
" temp_images = imgs[i]\n",
" temp_masks = masks[i]\n",
" \n",
" ax[int(i//2),0 + (i%2)*2].imshow(temp_images.permute([1,2,0]))\n",
" ax[int(i//2),0 + (i%2)*2].grid(False)\n",
" ax[int(i//2),0 + (i%2)*2].set_title(\"{}\".format([category_names[int(i)] for i in list(np.unique(temp_masks))]), fontsize = 10)\n",
" ax[int(i//2),0 + (i%2)*2].set_xlabel(temp_image_infos['file_name'])\n",
"\n",
" ax[int(i//2),1 + (i%2)*2].imshow(label_to_color_image(temp_masks.detach().cpu().numpy()))\n",
" ax[int(i//2),1 + (i%2)*2].grid(False)\n",
" #ax[int(i//2),1 + (i%2)*2].set_title(\"{}\".format([{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks))], fontsize = 5))\n",
" if (i%2)*2 == 2:\n",
" ax[int(i//2),1 + (i%2)*2].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)\n",
" \n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d535a07",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "segmentation",
"language": "python",
"name": "segmentation"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
13 changes: 13 additions & 0 deletions EDA/class_dict.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name,r,g,b
Background,0,0,0
General trash,250,0,50
Paper,0,255,0
Paper pack,0,180,80
Metal,185,185,185
Glass,100,100,100
Plastic,200,50,150
Styrofoam,50,150,200
Plastic bag,50,200,150
Battery,200,200,200
Clothing,255,255,255

13 changes: 13 additions & 0 deletions class_dict.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name,r,g,b
Background,0,0,0
General trash,250,0,50
Paper,0,255,0
Paper pack,0,180,80
Metal,185,185,185
Glass,100,100,100
Plastic,200,50,150
Styrofoam,50,150,200
Plastic bag,50,200,150
Battery,200,200,200
Clothing,255,255,255

7 changes: 3 additions & 4 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2


dataset_path = './input/data'
# category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass',
# 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']
dataset_path = '/opt/ml/segmentation/semantic-segmentation-level2-cv-06/input/data/'
category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass',
'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']


def get_classname(classID, cats):
Expand Down
42 changes: 38 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import random
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd

# Class-name / RGB table, one row per segmentation label. The CSV is resolved
# relative to this module so the import works wherever the repository is
# checked out, instead of only under one hard-coded absolute path.
class_colormap = pd.read_csv(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "class_dict.csv")
)

def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
Expand Down Expand Up @@ -64,19 +66,52 @@ def grid_image(images, masks, preds, n=4, shuffle=False):
ax[idx*3].grid(False)

ax[idx*3+1] = figure.add_subplot(gs[idx, 1])
ax[idx*3+1].imshow(mask)
ax[idx*3+1].imshow(label_to_color_image(mask))
ax[idx*3+1].grid(False)

ax[idx*3+2] = figure.add_subplot(gs[idx, 2])
ax[idx*3+2].imshow(pred)
ax[idx*3+2].imshow(label_to_color_image(pred))
ax[idx*3+2].grid(False)
# 나중에 확률 값으로 얼마나 틀렸는지 시각화 해주는 열을 추가하면 더 좋을듯?

figure.suptitle('image / GT / pred', fontsize=16)

return figure

def create_trash_label_colormap(class_table=None):
    """Build the label-to-RGB colormap used for trash segmentation.

    Args:
        class_table: Optional pandas DataFrame with one row per class in
            label order, whose first column is the class name and whose next
            three columns are the r, g, b components. Defaults to the
            module-level ``class_colormap`` table.

    Returns:
        A ``(num_classes, 3)`` uint8 array; row ``i`` is the RGB color of
        label ``i``.
    """
    if class_table is None:
        # Fall back to the table loaded at import time (reading a module
        # global needs no ``global`` statement).
        class_table = class_colormap

    # Take the row count from the table instead of hard-coding 11 classes, so
    # adding or removing a CSV row can neither overflow the array nor leave
    # stale all-black rows. Columns 1..3 are r, g, b; column 0 is the name.
    return class_table.iloc[:, 1:4].to_numpy(dtype=np.uint8, copy=True)

def label_to_color_image(label, colormap=None):
    """Map a 2-D label mask to a color image via a label colormap.

    Args:
        label: A 2-D NumPy array of integer class indices.
        colormap: Optional ``(num_classes, 3)`` array-like of RGB rows indexed
            by label. Defaults to ``create_trash_label_colormap()``.

    Returns:
        An array of shape ``label.shape + (3,)`` in which every pixel holds
        the colormap row of the corresponding label (uint8 for the default
        colormap).

    Raises:
        ValueError: If ``label`` is not of rank 2, or if it contains a value
            that is negative or not smaller than the number of colormap rows.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    if colormap is None:
        colormap = create_trash_label_colormap()
    colormap = np.asarray(colormap)

    # Only range-check non-empty masks: np.min/np.max raise on empty input,
    # while indexing with an empty integer mask is well defined.
    if label.size:
        if np.min(label) < 0:
            # A negative label would silently wrap around to the last classes
            # through NumPy's negative indexing instead of failing.
            raise ValueError('label value must be non-negative.')
        if np.max(label) >= len(colormap):
            raise ValueError('label value too large.')

    return colormap[label]

def remove_dot_underbar(root):
for path in os.listdir(root):
Expand All @@ -91,7 +126,6 @@ def remove_dot_underbar(root):
# root = './'
# remove_dot_underbar(root)


# def label_accuracy_score(label_trues, label_preds, n_class):
# """Returns accuracy score evaluation result.
# - overall accuracy
Expand Down

0 comments on commit aff5b4f

Please sign in to comment.