Skip to content

Commit

Permalink
Merge pull request #20 from boostcampaitech2/EDA
Browse files Browse the repository at this point in the history
[FEAT] add img_diff
  • Loading branch information
JiyouSeo authored Oct 21, 2021
2 parents aff5b4f + 79fd775 commit 8a6c306
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 49 deletions.
72 changes: 36 additions & 36 deletions EDA/EDA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,6 @@
"# collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e075742b",
"metadata": {},
"outputs": [],
"source": [
"for i, (imgs, masks, image_infos) in enumerate(train_loader):\n",
" #print(len(imgs))\n",
" if i == 0:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -102,28 +89,39 @@
},
"outputs": [],
"source": [
"category_and_rgb = [[category, (r,g,b)] for idx, (category, r, g, b) in enumerate(class_colormap.values)]\n",
"legend_elements = [Patch(facecolor=webcolors.rgb_to_hex(rgb), \n",
" edgecolor=webcolors.rgb_to_hex(rgb), \n",
" label=category) for category, rgb in category_and_rgb]\n",
"select_batch = 3\n",
"\n",
"fig, ax = plt.subplots(nrows=16, ncols=4, figsize=(20, 100))\n",
"# train_loader의 output 결과(image 및 mask) 확인\n",
"for i in range(len(imgs)):\n",
" temp_image_infos = image_infos[i]\n",
" temp_images = imgs[i]\n",
" temp_masks = masks[i]\n",
" \n",
" ax[int(i//2),0 + (i%2)*2].imshow(temp_images.permute([1,2,0]))\n",
" ax[int(i//2),0 + (i%2)*2].grid(False)\n",
" ax[int(i//2),0 + (i%2)*2].set_title(\"{}\".format([category_names[int(i)] for i in list(np.unique(temp_masks))]), fontsize = 10)\n",
" ax[int(i//2),0 + (i%2)*2].set_xlabel(temp_image_infos['file_name'])\n",
"if not (0 <= select_batch < len(train_loader)):\n",
" raise Exception(\"select_batch index error\")\n",
"\n",
"for i, (imgs, masks, image_infos) in enumerate(train_loader):\n",
" #print(len(imgs))\n",
" if i < select_batch:\n",
" continue\n",
"\n",
" category_and_rgb = [[category, (r,g,b)] for idx, (category, r, g, b) in enumerate(class_colormap.values)]\n",
" legend_elements = [Patch(facecolor=webcolors.rgb_to_hex(rgb), \n",
" edgecolor=webcolors.rgb_to_hex(rgb), \n",
" label=category) for category, rgb in category_and_rgb]\n",
"\n",
" ax[int(i//2),1 + (i%2)*2].imshow(label_to_color_image(temp_masks.detach().cpu().numpy()))\n",
" ax[int(i//2),1 + (i%2)*2].grid(False)\n",
" #ax[int(i//2),1 + (i%2)*2].set_title(\"{}\".format([{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks))], fontsize = 5))\n",
" if (i%2)*2 == 2:\n",
" ax[int(i//2),1 + (i%2)*2].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)\n",
" fig, ax = plt.subplots(nrows=16, ncols=4, figsize=(20, 100))\n",
" # train_loader의 output 결과(image 및 mask) 확인\n",
" for i in range(len(imgs)):\n",
" temp_image_infos = image_infos[i]\n",
" temp_images = imgs[i]\n",
" temp_masks = masks[i]\n",
" \n",
" ax[int(i//2),0 + (i%2)*2].imshow(temp_images.permute([1,2,0]))\n",
" ax[int(i//2),0 + (i%2)*2].grid(False)\n",
" ax[int(i//2),0 + (i%2)*2].set_title(\"{}\".format([category_names[int(i)] for i in list(np.unique(temp_masks))]), fontsize = 10)\n",
" ax[int(i//2),0 + (i%2)*2].set_xlabel(temp_image_infos['file_name'])\n",
"\n",
" ax[int(i//2),1 + (i%2)*2].imshow(label_to_color_image(temp_masks.detach().cpu().numpy()))\n",
" ax[int(i//2),1 + (i%2)*2].grid(False)\n",
" #ax[int(i//2),1 + (i%2)*2].set_title(\"{}\".format([{int(i),category_names[int(i)]} for i in list(np.unique(temp_masks))], fontsize = 5))\n",
" if (i%2)*2 == 2:\n",
" ax[int(i//2),1 + (i%2)*2].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)\n",
" break\n",
" \n",
"plt.show()"
]
Expand All @@ -138,10 +136,12 @@
}
],
"metadata": {
"interpreter": {
"hash": "d36e052b391be8c28b05838ade06426769a29575d5fe21a7bc69c7dec0c04c06"
},
"kernelspec": {
"display_name": "segmentation",
"language": "python",
"name": "segmentation"
"display_name": "Python 3.7.11 64-bit ('segmentation': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
13 changes: 0 additions & 13 deletions EDA/class_dict.csv

This file was deleted.

222 changes: 222 additions & 0 deletions one_off/img_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import os
import random
import time
import json
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# 전처리를 위한 라이브러리
from pycocotools.coco import COCO
import torch
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2
# 시각화를 위한 라이브러리
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from matplotlib.patches import Patch
import webcolors

import albumentations as A
from albumentations.pytorch import ToTensorV2

import sys
sys.path.insert(1, '/opt/ml/semantic-segmentation-level2-cv-06/')
# from utils import label_accuracy_score, add_hist

plt.rcParams['axes.grid'] = False

dataset_path = '../input/data'
anns_file_path = dataset_path + '/' + 'train_all.json'

# Read annotations
with open(anns_file_path, 'r') as f:
dataset = json.loads(f.read())


categories = dataset['categories']
category_names = []
for cat in categories:
category_names.append(cat['name'])
category_names.insert(0,'Background')
# category_names

class_colormap = pd.read_csv("./class_dict.csv")
# class_colormap

def get_classname(classID, cats):
for i in range(len(cats)):
if cats[i]['id']==classID:
return cats[i]['name']
return "None"


def create_trash_label_colormap():
"""Creates a label colormap used in Trash segmentation.
Returns:
A colormap for visualizing segmentation results.
"""
colormap = np.zeros((11, 3), dtype=np.uint8)
for inex, (_, r, g, b) in enumerate(class_colormap.values):
colormap[inex] = [r, g, b]

return colormap

def label_to_color_image(label):
"""Adds color defined by the dataset colormap to the label.
Args:
label: A 2D array with integer type, storing the segmentation label.
Returns:
result: A 2D array with floating type. The element of the array
is the color indexed by the corresponding element in the input label
to the trash color map.
Raises:
ValueError: If label is not of rank 2 or its value is larger than color
map maximum entry.
"""
if label.ndim != 2:
raise ValueError('Expect 2-D input label')

colormap = create_trash_label_colormap()

if np.max(label) >= len(colormap):
raise ValueError('label value too large.')

return colormap[label]

# train.json / validation.json / test.json 디렉토리 설정
train_path = dataset_path + '/train.json'
val_path = dataset_path + '/val.json'
test_path = dataset_path + '/test.json'
batch_size = 32
# collate_fn needs for batch
def collate_fn(batch):
return tuple(zip(*batch))

train_transform = A.Compose([ToTensorV2()])

class CustomDataLoader2(Dataset):
"""COCO format"""
def __init__(self, data_dir, mode = 'train', transform = None):
super().__init__()
self.mode = mode
self.transform = transform
self.coco = COCO(data_dir)

def __getitem__(self, index: int):
# dataset이 index되어 list처럼 동작
image_id = self.coco.getImgIds(imgIds=index)
image_infos = self.coco.loadImgs(image_id)[0]

# cv2 를 활용하여 image 불러오기
images = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))
images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
images /= 255.0

if (self.mode in ('train', 'val')):
ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
anns = self.coco.loadAnns(ann_ids)

# Load the categories in a variable
cat_ids = self.coco.getCatIds()
cats = self.coco.loadCats(cat_ids)

# masks : size가 (height x width)인 2D
# 각각의 pixel 값에는 "category id" 할당
# Background = 0
masks = np.zeros((image_infos["height"], image_infos["width"]))
# General trash = 1, ... , Cigarette = 10
# anns = sorted(anns, key=lambda idx : idx['area'], reverse=True)
# anns = sorted(anns, key=lambda idx: len(idx['segmentation'][0]), reverse=True)
for i in range(len(anns)):
className = get_classname(anns[i]['category_id'], cats)
pixel_value = category_names.index(className)
masks[self.coco.annToMask(anns[i]) == 1] = pixel_value
masks = masks.astype(np.int8)

masks2 = np.zeros((image_infos["height"], image_infos["width"]))
# General trash = 1, ... , Cigarette = 10
anns = sorted(anns, key=lambda idx: len(idx['segmentation'][0]), reverse=False)
for i in range(len(anns)):
className = get_classname(anns[i]['category_id'], cats)
pixel_value = category_names.index(className)
masks2[self.coco.annToMask(anns[i]) == 1] = pixel_value
masks2 = masks2.astype(np.int8)

# transform -> albumentations 라이브러리 활용
origin_image = images
if self.transform is not None:
transformed = self.transform(image=origin_image, mask=masks)
images = transformed["image"]
masks = transformed["mask"]
transformed = self.transform(image=origin_image, mask=masks2)
masks2 = transformed["mask"]

return images, masks, image_infos, masks2

if self.mode == 'test':
# transform -> albumentations 라이브러리 활용
if self.transform is not None:
transformed = self.transform(image=images)
images = transformed["image"]
return images, image_infos

def __len__(self) -> int:
# 전체 dataset의 size를 return
return len(self.coco.getImgIds())

train_dataset2 = CustomDataLoader2(data_dir=train_path, mode='train', transform=train_transform)
train_loader2 = torch.utils.data.DataLoader(dataset=train_dataset2,
batch_size=1,
shuffle=False,
num_workers=4,
collate_fn=collate_fn)


# n = 3
# i=0
figsize=8

for idx, (imgs, masks, image_infos, masks2) in enumerate(tqdm(train_loader2)):
if torch.any(torch.ne(masks[0], masks2[0])).item() == True:
fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(figsize*4, figsize))

draw_mask = torch.ne(masks[0], masks2[0]).type(torch.int8)

ax[0].imshow(imgs[0].permute([1,2,0]))
ax[0].grid(False)
# ax[i,0].set_xlabel(image_infos[0]['file_name'])

ax[1].imshow(label_to_color_image(draw_mask.detach().cpu().numpy()))
ax[1].grid(False)

ax[2].imshow(label_to_color_image(masks[0].detach().cpu().numpy()))
ax[2].grid(False)

ax[3].imshow(label_to_color_image(masks2[0].detach().cpu().numpy()))
ax[3].grid(False)
# i += 1

# print(f'{image_infos[0]["file_name"][:-4]}.png')
plt.savefig(f'{image_infos[0]["file_name"][:-4]}')
# break
# plt.savefig(f'{image_infos[0]["file_name"]}')

# if i >= n:
# break

# plt.show()

0 comments on commit 8a6c306

Please sign in to comment.