Merge pull request #33 from boostcampaitech5/develop

Develop

seungki1011 authored Jun 23, 2023
2 parents 9307903 + 6ca2704 commit 543497c
Showing 1,582 changed files with 108,900 additions and 24 deletions.
6 changes: 3 additions & 3 deletions .gitignore
@@ -1,5 +1,5 @@
-/checkpoints
-/predictions
+checkpoints/
+predictions/
 .git
-/wandb
+wandb/
 /__*
153 changes: 153 additions & 0 deletions README.md
@@ -0,0 +1,153 @@
<p align="center">
<picture>
<img src="imgs/boostcampAITechlogo.png">
</picture>
<div align="center">
<img src="https://img.shields.io/badge/Python-FFD43B?style=for-the-badge&logo=python&logoColor=blue">
<img src="https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white">
</div>
</p>

# ✨ Team Introduction

Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):

<div align="center">
<table>
<tr>
<td align="center"><a href="https://github.com/seungki1011"><img src="https://avatars.githubusercontent.com/u/120040458?v=4?s=100" width="100px;" alt=""/><br /><sub><b>김승기</b></sub><br />
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=seungki1011" title="Code">💻</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/upsampling" title="Data">🔣</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits/main" title="Maintenance">🚧</a>
</td>
<td align="center"><a href="https://github.com/jjjuuuun"><img src="https://avatars.githubusercontent.com/u/86290308?v=4?s=100" width="100px;" alt=""/><br /><sub><b>김준영</b></sub></a><br />
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=jjjuuuun" title="Code">💻</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits/main" title="Maintenance">🚧</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="projectManagement">📆</a>
</td>
<td align="center"><a href="https://github.com/helpmeIamnewbie"><img src="https://avatars.githubusercontent.com/u/102274521?v=4?s=100" width="100px;" alt=""/><br /><sub><b>전형우</b></sub></a><br />
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=helpmeIamnewbie" title="Code">💻</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="Ideas & Planning">🤔</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Reviewed Pull Requests">👀</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=helpmeIamnewbie" title="Tests">⚠️</a>
</td>
<td align="center"><a href="https://github.com/CheonJiEun"><img src="https://avatars.githubusercontent.com/u/53997172?v=4?s=100" width="100px;" alt=""/><br /><sub><b>천지은</b></sub></a><br />
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=CheonJiEun" title="Code">💻</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/upsampling" title="Data">🔣</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Examples">💡</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Research">🔬</a>
</td>
<td align="center"><a href="https://github.com/Eyecaramba"><img src="https://avatars.githubusercontent.com/u/86091292?v=4?s=100" width="100px;" alt=""/><br /><sub><b>신우진</b></sub></a><br />
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=Eyecaramba" title="Code">💻</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="Ideas & Planning">🤔</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a>
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Reviewed Pull Requests">👀</a>
</td>
</tr>
</table>
</div>

This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!

# 💀 Project Introduction

<p align="center">
<picture>
<img src="imgs/handbone_segmentation.png">
</picture>
</p>

Because bones have a major influence on the structure and function of our bodies, accurate bone segmentation is essential for medical diagnosis and for developing treatment plans. Bone segmentation is one of the important applications of artificial intelligence; deep-learning-based bone segmentation in particular is an active area of research and can help with a variety of purposes.
1. Disease diagnosis: when a bone is deformed, displaced, broken, or fractured, the problem in that region can be identified precisely so that appropriate treatment can be performed.
2. Surgical planning: by analyzing the bone structure, doctors can decide what kind of surgery is needed and what kinds of materials can be used.
3. Manufacturing medical devices: for example, when making an artificial joint or a dental implant, the bone structure can be analyzed to determine an appropriate size and shape.
4. Medical education: doctors can deepen their understanding of pathologies and injuries and practice the skills needed to develop surgical plans.

This project is a competition held within the CV track of `Boostcamp AI Tech`, and the final evaluation is based on the mean Dice coefficient.
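
For reference, below is a minimal sketch of a per-mask Dice score, assuming the standard formula 2|A∩B| / (|A|+|B|). It is illustrative only and not the competition's official scoring code; the `eps` smoothing term and the averaging convention are assumptions.

```python
import numpy as np

def dice_coefficient(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    """Dice score between two binary masks: 2*|A ∩ B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    return (2.0 * intersection + eps) / (pred.sum() + gt.sum() + eps)

# "Mean Dice" then averages such per-mask scores
# (presumably over all classes and images).
```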

# 📆 Project Schedule

Overall project schedule

- 2023.06.05 ~ 2023.06.22

Detailed project schedule

- 2023.06.05 ~ 2023.06.09 : Studying Semantic Segmentation
- 2023.06.08 ~ 2023.06.08 : Baseline model experiments
- 2023.06.08 ~ 2023.06.09 : EDA
- 2023.06.09 ~ 2023.06.11 : Augmentation experiments
- 2023.06.12 ~ 2023.06.14 : Loss, optimizer, and image size experiments
- 2023.06.14 ~ 2023.06.14 : MMSegmentation implementation
- 2023.06.14 ~ 2023.06.22 : Model, image size, and offline augmentation experiments
- 2023.06.21 ~ 2023.06.22 : Ensemble

# 👨‍💻 Project Work

1. [EDA](https://calico-dance-4bf.notion.site/EDA-db11b32576644efa9dc836a9135b55f0)✔️
2. [Augmentation](https://calico-dance-4bf.notion.site/Augmentation-5767f538c8ee4cf88462fe1bf2526a96)✔️
3. [Model](https://calico-dance-4bf.notion.site/Model-c8ddb0c1ddbf41abb5c0a2937da16b61)✔️
4. [Interim Summary](https://calico-dance-4bf.notion.site/09699a2814c04e83bb391627ab965c01)
5. [Presentation Materials](https://calico-dance-4bf.notion.site/f0407bed529a4bbbae93d5d6c520ec4f)

# 🗒️ Project Results

#### Public
<img align="center" src="imgs/public.png" width="600" height="80">

#### Private
<img align="center" src="imgs/private.png" width="600" height="80">

# 🔄️ Directory

```bash
├── .gitignore
├── TTA.py
├── dataset.py
├── inference.py
├── main.py
├── metric.py
├── test.py
├── train.py
├── pre-commit-config.yaml
├── gitcommit_template.txt
├── README.md
├── imgs
├── utils
└── mmsegmentation
    ├── _teamconfigs_
    │   ├── [test]ExpName
    │   │   ├── config.py
    │   │   ├── dataset.py
    │   │   ├── default_runtime.py
    │   │   ├── schedule.py
    │   │   └── segformer_mit-b0.py
    │   └── [test]MMSeg_AMP_GA
    ├── train.py
    └── test.py
```

# ⚙️ Installation

#### Baseline Code
```pip install -r requirements.txt ```

#### MMSegmentation
Link ➡️
1. [MMSegmentation for our Project](https://calico-dance-4bf.notion.site/MMSegmentation-71f191822d5042129ccbcf7b9384f211)✔️
2. [Official GitHub](https://github.com/open-mmlab/mmsegmentation)😀

# ⚡️ Quick Start

#### Train
``` python train.py --exp-name {experiment_name} ```
#### Evaluation
``` python test.py --exp-name {experiment_name} ```

# 🤔 Wrap-Up Report

[Wrap-Up Report](https://file.notion.so/f/s/08bebda1-a3bb-4e83-93a9-706133868688/Semantic_Segmentation(%EA%B3%B5%EA%B0%9C%EC%9A%A9).pdf?id=59e780e8-9756-4381-8fc8-69e6383a4c16&table=block&spaceId=34e15efc-e2be-46ba-ae66-9fe65d825d78&expirationTimestamp=1687613606086&signature=pj_ir0N1PZm9B9ZTOH79znr-gxer3yajNNW3qO4-mFU&downloadName=Semantic+Segmentation%28%EA%B3%B5%EA%B0%9C%EC%9A%A9%29.pdf)

219 changes: 219 additions & 0 deletions TTA.py
@@ -0,0 +1,219 @@
import os

from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

import random
from argparse import ArgumentParser

from dataset import XRayDataset, XRayInferenceDataset
from torch.utils.data import Dataset

import albumentations as A
import wandb
import ttach as tta
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import models

import cv2

CLASSES = [
    "finger-1",
    "finger-2",
    "finger-3",
    "finger-4",
    "finger-5",
    "finger-6",
    "finger-7",
    "finger-8",
    "finger-9",
    "finger-10",
    "finger-11",
    "finger-12",
    "finger-13",
    "finger-14",
    "finger-15",
    "finger-16",
    "finger-17",
    "finger-18",
    "finger-19",
    "Trapezium",
    "Trapezoid",
    "Capitate",
    "Hamate",
    "Scaphoid",
    "Lunate",
    "Triquetrum",
    "Pisiform",
    "Radius",
    "Ulna",
]
# Reference: https://github.com/qubvel/ttach/blob/master/ttach/transforms.py
test_transform = tta.Compose(
    [
        tta.Resize(sizes=(1024, 1024), original_size=(2048, 2048), interpolation='bilinear'),
        tta.HorizontalFlip(),
    ]
)
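
# At inference time, tta.SegmentationTTAWrapper (built in test() below) runs the
# wrapped model on each augmented view produced by this Compose, reverses the
# augmentation on each output, and averages the results (merge_mode='mean').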

class CustomModel(torch.nn.Module):
    """Thin wrapper that unpacks dict outputs (e.g. {"out": tensor}) into a plain tensor."""

    def __init__(self, model):
        super(CustomModel, self).__init__()
        self.model = model

    def forward(self, x):
        output = self.model(x)

        if isinstance(output, dict):
            output = output["out"]
        return output

class XRayInferenceDataset_TTA(Dataset):
    def __init__(self, data_root, transforms: A = None):
        """
        Args:
            data_root : directory containing test.csv
        """
        self.df = pd.read_csv(os.path.join(data_root, "test.csv"))
        self.data_root = data_root
        # Note: __getitem__ does not apply self.transforms; the TTA views are
        # generated later by the tta wrapper in test().
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filenames"]
        image_name = os.path.join(image_path.split("/")[-2], image_path.split("/")[-1])

        image = cv2.imread(image_path)

        image = image / 255.0
        image = image.transpose(2, 0, 1)

        image = torch.from_numpy(image).float()

        return image, image_name

def encode_mask_to_rle(mask):
    """
    mask: numpy array binary mask
        1 - mask
        0 - background
    Returns encoded run length
    """
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return " ".join(str(x) for x in runs)


def decode_rle_to_mask(rle, height, width):
    s = rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.uint8)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1

    return img.reshape(height, width)
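
# Illustrative round trip for the two helpers above: a 2x2 mask with a single
# foreground pixel at (row 0, column 1) flattens to [0, 1, 0, 0], encodes to
# "2 1" (1-based start index, then run length), and
# decode_rle_to_mask("2 1", 2, 2) reconstructs the original mask.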


def test(data_loader, classes, best_model_dir, save_dir, is_csv=True, thr=0.5):
    print("Start inference ...")
    idx2class = {i: v for i, v in enumerate(classes)}

    model = torch.load(os.path.join(best_model_dir, "best_model.pt"))["model"]
    model = CustomModel(model)
    # Run the model on every TTA view, de-augment the outputs, and average them.
    model = tta.SegmentationTTAWrapper(model, test_transform, merge_mode='mean')
    model.cuda()
    model.eval()

    rles = []
    filename_and_class = []
    with torch.no_grad():
        for step, (images, image_names) in tqdm(
            enumerate(data_loader), total=len(data_loader)
        ):
            images = images.cuda()
            # outputs = model(images)["out"]
            outputs = model(images)
            # restore original size
            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu().numpy()
            for output, image_name in zip(outputs, image_names):
                for c, segm in enumerate(output):
                    rle = encode_mask_to_rle(segm)
                    rles.append(rle)
                    filename_and_class.append(f"{idx2class[c]}_{image_name}")

    if is_csv:
        classes, filename = zip(*[x.split("_") for x in filename_and_class])
        image_name = [os.path.basename(f) for f in filename]
        df = pd.DataFrame(
            {
                "image_name": image_name,
                "class": classes,
                "rle": rles,
            }
        )

        df.to_csv(os.path.join(save_dir, "submission.csv"), index=False)
        print("CSV file creation successful")
    else:
        return rles, filename_and_class


def main(args):

    save_csv = os.path.join(args.save_csv, args.exp_name)
    save_checkpoint = os.path.join(args.save_checkpoint, args.exp_name)

    test_dataset = XRayInferenceDataset_TTA(args.data_root, transforms=test_transform)

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

    # Inference
    test(test_loader, CLASSES, save_checkpoint, save_csv, args.make_csv)


if __name__ == "__main__":
    parser = ArgumentParser()

    # Path
    parser.add_argument(
        "--data-root",
        type=str,
        default="../data",
    )
    parser.add_argument(
        "--save-checkpoint",
        type=str,
        default="./checkpoints",
    )
    parser.add_argument(
        "--save-csv",
        type=str,
        default="./predictions",
    )

    parser.add_argument("--exp-name", type=str, default="[test]ExpName")

    # Inference
    # Note: argparse's type=bool treats any non-empty string as True,
    # so --make-csv effectively stays True unless an empty string is passed.
    parser.add_argument("--make-csv", type=bool, default=True)

    args = parser.parse_args()

    main(args)
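
# Example invocation with the default paths above (set --exp-name to the
# experiment whose checkpoint directory holds best_model.pt):
#   python TTA.py --exp-name "[test]ExpName"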