diff --git a/EDA/EDA.ipynb b/EDA/EDA.ipynb new file mode 100644 index 0000000..3040b74 --- /dev/null +++ b/EDA/EDA.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "aa608dff", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "# 시각화를 위한 라이브러리\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "from matplotlib.patches import Patch\n", + "import webcolors\n", + "\n", + "import sys\n", + "sys.path.append('/opt/ml/segmentation/semantic-segmentation-level2-cv-06/')\n", + "from utils import label_accuracy_score, add_hist, class_colormap, label_to_color_image\n", + "from dataset import CustomDataLoader, collate_fn,\\\n", + " train_transform, val_transform, test_transform,\\\n", + " category_names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f74eb9b", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams['axes.grid'] = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc293266", + "metadata": {}, + "outputs": [], + "source": [ + "root_path = '../input/data'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79fe6a23", + "metadata": {}, + "outputs": [], + "source": [ + "# train.json / validation.json / test.json 디렉토리 설정\n", + "train_path = root_path + '/train.json'\n", + "val_path = root_path + '/val.json'\n", + "test_path = root_path + '/test.json'\n", + "batch_size = 32\n", + "\n", + "train_dataset = CustomDataLoader(data_dir=train_path, mode='train', transform=train_transform)\n", + "#val_dataset = CustomDataLoader(data_dir=val_path, mode='val', transform=val_transform)\n", + "#test_dataset = CustomDataLoader(data_dir=test_path, mode='test', transform=test_transform)\n", + "\n", + "# DataLoader\n", + "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \n", + " batch_size=batch_size,\n", + " shuffle=False,\n", + " num_workers=4,\n", + " collate_fn=collate_fn)\n", + "\n", + "# val_loader = torch.utils.data.DataLoader(dataset=val_dataset, \n", + "# batch_size=batch_size,\n", + "# shuffle=False,\n", + "# num_workers=4,\n", + "# collate_fn=collate_fn)\n", + "\n", + "# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n", + "# batch_size=batch_size,\n", + "# num_workers=4,\n", + "# collate_fn=collate_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e075742b", + "metadata": {}, + "outputs": [], + "source": [ + "for i, (imgs, masks, image_infos) in enumerate(train_loader):\n", + " #print(len(imgs))\n", + " if i == 0:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a8a5316", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "category_and_rgb = [[category, (r,g,b)] for idx, (category, r, g, b) in enumerate(class_colormap.values)]\n", + "legend_elements = [Patch(facecolor=webcolors.rgb_to_hex(rgb), \n", + " edgecolor=webcolors.rgb_to_hex(rgb), \n", + " label=category) for category, rgb in category_and_rgb]\n", + "\n", + "fig, ax = plt.subplots(nrows=16, ncols=4, figsize=(20, 100))\n", + "# train_loader의 output 결과(image 및 mask) 확인\n", + "for i in range(len(imgs)):\n", + " temp_image_infos = image_infos[i]\n", + " temp_images = imgs[i]\n", + " temp_masks = masks[i]\n", + " \n", + " ax[int(i//2),0 + (i%2)*2].imshow(temp_images.permute([1,2,0]))\n", + " ax[int(i//2),0 + (i%2)*2].grid(False)\n", + " ax[int(i//2),0 + (i%2)*2].set_title(\"{}\".format([category_names[int(i)] for i in list(np.unique(temp_masks))]), fontsize = 10)\n", + " ax[int(i//2),0 + (i%2)*2].set_xlabel(temp_image_infos['file_name'])\n", + "\n", + " ax[int(i//2),1 + (i%2)*2].imshow(label_to_color_image(temp_masks.detach().cpu().numpy()))\n", + " ax[int(i//2),1 + (i%2)*2].grid(False)\n", + " #ax[int(i//2),1 + (i%2)*2].set_title(\"{}\".format([{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks))], fontsize = 5))\n", + " if (i%2)*2 == 2:\n", + " ax[int(i//2),1 + (i%2)*2].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)\n", + " \n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d535a07", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "segmentation", + "language": "python", + "name": "segmentation" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/EDA/class_dict.csv b/EDA/class_dict.csv new file mode 100755 index 0000000..34a30ee --- /dev/null +++ b/EDA/class_dict.csv @@ -0,0 +1,13 @@ +name,r,g,b +Backgroud,0,0,0 +General trash,250,0,50 +Paper,0,255,0 +Paper pack,0,180,80 +Metal,185,185,185 +Glass,100,100,100 +Plastic,200,50,150 +Styrofoam,50,150,200 +Plastic bag,50,200,150 +Battery,200,200,200 +Clothing,255,255,255 + diff --git a/class_dict.csv b/class_dict.csv new file mode 100755 index 0000000..34a30ee --- /dev/null +++ b/class_dict.csv @@ -0,0 +1,13 @@ +name,r,g,b +Backgroud,0,0,0 +General trash,250,0,50 +Paper,0,255,0 +Paper pack,0,180,80 +Metal,185,185,185 +Glass,100,100,100 +Plastic,200,50,150 +Styrofoam,50,150,200 +Plastic bag,50,200,150 +Battery,200,200,200 +Clothing,255,255,255 + diff --git a/dataset.py b/dataset.py index f6ce1bb..46fc09c 100644 --- a/dataset.py +++ b/dataset.py @@ -7,10 +7,9 @@ import albumentations as A from albumentations.pytorch import ToTensorV2 - -dataset_path = './input/data' -# category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', -# 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing'] +dataset_path = '/opt/ml/segmentation/semantic-segmentation-level2-cv-06/input/data/' +category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', + 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing'] def get_classname(classID, cats): diff --git a/utils.py b/utils.py index cdfe482..bead005 100644 --- a/utils.py +++ b/utils.py @@ -3,7 +3,9 @@ import random import numpy as np import matplotlib.pyplot as plt -import os +import pandas as pd + +class_colormap = pd.read_csv("/opt/ml/segmentation/semantic-segmentation-level2-cv-06/class_dict.csv") def _fast_hist(label_true, label_pred, n_class): mask = (label_true >= 0) & (label_true < n_class) @@ -64,11 +66,11 @@ def grid_image(images, masks, preds, n=4, shuffle=False): ax[idx*3].grid(False) ax[idx*3+1] = figure.add_subplot(gs[idx, 1]) - ax[idx*3+1].imshow(mask) + ax[idx*3+1].imshow(label_to_color_image(mask)) ax[idx*3+1].grid(False) ax[idx*3+2] = figure.add_subplot(gs[idx, 2]) - ax[idx*3+2].imshow(pred) + ax[idx*3+2].imshow(label_to_color_image(pred)) ax[idx*3+2].grid(False) # 나중에 확률 값으로 얼마나 틀렸는지 시각화 해주는 열을 추가하면 더 좋을듯? @@ -76,7 +78,40 @@ def grid_image(images, masks, preds, n=4, shuffle=False): return figure +def create_trash_label_colormap(): + """Creates a label colormap used in Trash segmentation. + Returns: + A colormap for visualizing segmentation results. + """ + global class_colormap + + colormap = np.zeros((11, 3), dtype=np.uint8) + for inex, (_, r, g, b) in enumerate(class_colormap.values): + colormap[inex] = [r, g, b] + + return colormap + +def label_to_color_image(label): + """Adds color defined by the dataset colormap to the label. + Args: + label: A 2D array with integer type, storing the segmentation label. + Returns: + result: A 2D array with floating type. The element of the array + is the color indexed by the corresponding element in the input label + to the trash color map. + Raises: + ValueError: If label is not of rank 2 or its value is larger than color + map maximum entry. + """ + if label.ndim != 2: + raise ValueError('Expect 2-D input label') + colormap = create_trash_label_colormap() + + if np.max(label) >= len(colormap): + raise ValueError('label value too large.') + + return colormap[label] def remove_dot_underbar(root): for path in os.listdir(root): @@ -91,7 +126,6 @@ def remove_dot_underbar(root): # root = './' # remove_dot_underbar(root) - # def label_accuracy_score(label_trues, label_preds, n_class): # """Returns accuracy score evaluation result. # - overall accuracy