-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from boostcampaitech5/develop
Develop
- Loading branch information
Showing
1,582 changed files
with
108,900 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/checkpoints | ||
/predictions | ||
checkpoints/ | ||
predictions/ | ||
.git | ||
/wandb | ||
wandb/ | ||
/__* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
<p align="center"> | ||
<picture> | ||
<img src="imgs/boostcampAITechlogo.png"> | ||
</picture> | ||
<div align="center"> | ||
<img src="https://img.shields.io/badge/Python-FFD43B?style=for-the-badge&logo=python&logoColor=blue"> | ||
<img src="https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white"> | ||
</div> | ||
</p> | ||
|
||
# ✨ 팀 소개 | ||
|
||
Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): | ||
|
||
<div align="center"> | ||
<table> | ||
<tr> | ||
<td align="center"><a href="https://github.com/seungki1011"><img src="https://avatars.githubusercontent.com/u/120040458?v=4?s=100" width="100px;" alt=""/><br /><sub><b>김승기</b></sub><br /> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=seungki1011" title="Code">💻</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/upsampling" title="Data">🔣</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits/main" title="Maintenance">🚧</a> | ||
</td> | ||
<td align="center"><a href="https://github.com/jjjuuuun"><img src="https://avatars.githubusercontent.com/u/86290308?v=4?s=100" width="100px;" alt=""/><br /><sub><b>김준영</b></sub></a><br /> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=jjjuuuun" title="Code">💻</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits/main" title="Maintenance">🚧</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="projectManagement">📆</a> | ||
</td> | ||
<td align="center"><a href="https://github.com/helpmeIamnewbie"><img src="https://avatars.githubusercontent.com/u/102274521?v=4?s=100" width="100px;" alt=""/><br /><sub><b>전형우</b></sub></a><br /> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=helpmeIamnewbie" title="Code">💻</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="Ideas & Planning">🤔</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Reviewed Pull Requests">👀</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=helpmeIamnewbie" title="Tests">⚠️</a> | ||
</td> | ||
<td align="center"><a href="https://github.com/CheonJiEun"><img src="https://avatars.githubusercontent.com/u/53997172?v=4?s=100" width="100px;" alt=""/><br /><sub><b>천지은</b></sub></a><br /> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=CheonJiEun" title="Code">💻</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/upsampling" title="Data">🔣</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Examples">💡</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Research">🔬</a> | ||
</td> | ||
<td align="center"><a href="https://github.com/Eyecaramba"><img src="https://avatars.githubusercontent.com/u/86091292?v=4?s=100" width="100px;" alt=""/><br /><sub><b>신우진</b></sub></a><br /> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/commits?author=Eyecaramba" title="Code">💻</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03" title="Ideas & Planning">🤔</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/tree/main/mmdetection/configs/_teamconfig_" title="Infrastructure">🚇</a> | ||
<a href="https://github.com/boostcampaitech5/level2_objectdetection-cv-03/pulls?q=" title="Reviewed Pull Requests">👀</a> | ||
</td> | ||
</tr> | ||
</table> | ||
</div> | ||
|
||
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! | ||
|
||
# 💀 프로젝트 소개 | ||
|
||
<p align="center"> | ||
<picture> | ||
<img src="imgs/handbone_segmentation.png"> | ||
</picture> | ||
</p> | ||
|
||
뼈는 우리 몸의 구조와 기능에 중요한 영향을 미치기 때문에, 정확한 뼈 분할은 의료 진단 및 치료 계획을 개발하는 데 필수적입니다. Bone segmentation은 인공지능 분야에서 중요한 응용 분야 중 하나로, 특히, 딥러닝 기술을 이용한 뼈 segmentation은 많은 연구가 이루어지고 있으며, 다양한 목적으로 도움을 줄 수 있습니다. | ||
1. 질병 진단의 목적으로 뼈의 형태나 위치가 변형되거나 부러지거나 골절 등이 있을 경우, 그 부위에서 발생하는 문제를 정확하게 파악하여 적절한 치료를 시행할 수 있습니다. | ||
2. 수술 계획을 세우는데 도움이 됩니다. 의사들은 뼈 구조를 분석하여 어떤 종류의 수술이 필요한지, 어떤 종류의 재료가 사용될 수 있는지 등을 결정할 수 있습니다. | ||
3. 의료장비 제작에 필요한 정보를 제공합니다. 예를 들어, 인공 관절이나 치아 임플란트를 제작할 때 뼈 구조를 분석하여 적절한 크기와 모양을 결정할 수 있습니다. | ||
4. 의료 교육에서도 활용될 수 있습니다. 의사들은 병태 및 부상에 대한 이해를 높이고 수술 계획을 개발하는 데 필요한 기술을 연습할 수 있습니다. | ||
|
||
이번 프로젝트는 `Boostcamp AI Tech` CV 트랙내에서 진행된 대회이며 mean dice coefficient으로 최종평가를 진행하게 됩니다. | ||
|
||
# 📆 프로젝트 일정 | ||
|
||
프로젝트 전체 일정 | ||
|
||
- 2023.06.05 ~ 2023.06.22 | ||
|
||
프로젝트 세부 일정 | ||
|
||
- 2023.06.05 ~ 2023.06.09 : Semantic Semgmentation에 대해 알아보기 | ||
- 2023.06.08 ~ 2023.06.08 : Baseline Model 실험 | ||
- 2023.06.08 ~ 2023.06.09 : EDA | ||
- 2023.06.09 ~ 2023.06.11 : Augmentation 실험 | ||
- 2023.06.12 ~ 2023.06.14 : Loss, Optimizer, Image size 실험 | ||
- 2023.06.14 ~ 2023.06.14 : MMSegmentation 구현 | ||
- 2023.06.14 ~ 2023.06.22 : Model, Image size, Offline Augmentation 실험 | ||
- 2023.06.21 ~ 2023.06.22 : Ensemble | ||
|
||
# 👨💻 프로젝트 수행 | ||
|
||
1. [EDA](https://calico-dance-4bf.notion.site/EDA-db11b32576644efa9dc836a9135b55f0)✔️ | ||
2. [Augmentation](https://calico-dance-4bf.notion.site/Augmentation-5767f538c8ee4cf88462fe1bf2526a96)✔️ | ||
3. [Model](https://calico-dance-4bf.notion.site/Model-c8ddb0c1ddbf41abb5c0a2937da16b61)✔️ | ||
4. [중간정리](https://calico-dance-4bf.notion.site/09699a2814c04e83bb391627ab965c01)⭐ | ||
5. [발표자료](https://calico-dance-4bf.notion.site/f0407bed529a4bbbae93d5d6c520ec4f)⭐ | ||
|
||
# 🗒️ 프로젝트 결과 | ||
|
||
#### Public | ||
<img align="center" src="imgs/public.png" width="600" height="80"> | ||
|
||
#### Private | ||
<img align="center" src="imgs/private.png" width="600" height="80"> | ||
|
||
# 🔄️ Directory | ||
|
||
```bash | ||
├── .gitignore | ||
├── TTA.py | ||
├── dataset.py | ||
├── inference.py | ||
├── main.py | ||
├── metric.py | ||
├── test.py | ||
├── train.py | ||
├── pre-commit-config.yaml | ||
├── gitcommit_template.txt | ||
├── README.md | ||
├── imgs | ||
├── utils | ||
└── mmsegmentation | ||
├── _teamconfigs_ | ||
│ └── [test]ExpName | ||
│ ├── config.py | ||
│ ├── dataset.py | ||
│ ├── default_runtime.py | ||
│ ├── schedule.py | ||
│ └── segformer_mit-b0.py | ||
│ └── [test]MMSeg_AMP_GA | ||
│ | ||
├── train.py | ||
└── test.py | ||
``` | ||
|
||
# ⚙️ 설치 | ||
|
||
#### Baseline Code | ||
```pip install -r requirements.txt ``` | ||
|
||
#### MMSegmentation | ||
Link ➡️ | ||
1. [MMSegmentation for our Project](https://calico-dance-4bf.notion.site/MMSegmentation-71f191822d5042129ccbcf7b9384f211)✔️ | ||
2. [Official GitHub](https://github.com/open-mmlab/mmsegmentation)😀 | ||
|
||
# ⚡️ 빠른 시작 | ||
|
||
#### Train | ||
``` python train.py --exp-name {실험명} ``` | ||
#### Evaluation | ||
``` python test.py --exp-name {실험명} ``` | ||
|
||
# 🤔 Wrap-Up Report | ||
|
||
[Wrap-Up Report](https://file.notion.so/f/s/08bebda1-a3bb-4e83-93a9-706133868688/Semantic_Segmentation(%EA%B3%B5%EA%B0%9C%EC%9A%A9).pdf?id=59e780e8-9756-4381-8fc8-69e6383a4c16&table=block&spaceId=34e15efc-e2be-46ba-ae66-9fe65d825d78&expirationTimestamp=1687613606086&signature=pj_ir0N1PZm9B9ZTOH79znr-gxer3yajNNW3qO4-mFU&downloadName=Semantic+Segmentation%28%EA%B3%B5%EA%B0%9C%EC%9A%A9%29.pdf)⭐ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
import os | ||
|
||
from tqdm import tqdm | ||
import numpy as np | ||
import pandas as pd | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
import random | ||
from argparse import ArgumentParser | ||
|
||
from dataset import XRayDataset, XRayInferenceDataset | ||
from torch.utils.data import Dataset | ||
|
||
import albumentations as A | ||
import wandb | ||
import ttach as tta | ||
from torch.utils.data import DataLoader | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torchvision import models | ||
|
||
import cv2 | ||
import ttach as tta | ||
|
||
CLASSES = [ | ||
"finger-1", | ||
"finger-2", | ||
"finger-3", | ||
"finger-4", | ||
"finger-5", | ||
"finger-6", | ||
"finger-7", | ||
"finger-8", | ||
"finger-9", | ||
"finger-10", | ||
"finger-11", | ||
"finger-12", | ||
"finger-13", | ||
"finger-14", | ||
"finger-15", | ||
"finger-16", | ||
"finger-17", | ||
"finger-18", | ||
"finger-19", | ||
"Trapezium", | ||
"Trapezoid", | ||
"Capitate", | ||
"Hamate", | ||
"Scaphoid", | ||
"Lunate", | ||
"Triquetrum", | ||
"Pisiform", | ||
"Radius", | ||
"Ulna", | ||
] | ||
# https://github.com/qubvel/ttach/blob/master/ttach/transforms.py -- 참고 | ||
test_transform = tta.Compose([tta.Resize(sizes=(1024, 1024), original_size=(2048,2048), interpolation='bilinear'), tta.HorizontalFlip()]) | ||
|
||
class CustomModel(torch.nn.Module): | ||
def __init__(self, model): | ||
super(CustomModel, self).__init__() | ||
self.model = model | ||
|
||
def forward(self, x): | ||
output = self.model(x) | ||
|
||
if isinstance(output, dict): | ||
output = output["out"] | ||
return output | ||
|
||
class XRayInferenceDataset_TTA(Dataset): | ||
def __init__(self, data_root, transforms: A = None): | ||
""" | ||
Args: | ||
data_root : csv 파일 위치 | ||
""" | ||
self.df = pd.read_csv(os.path.join(data_root, f"test.csv")) | ||
self.data_root = data_root | ||
self.transforms = transforms | ||
|
||
def __len__(self): | ||
return len(self.df) | ||
|
||
def __getitem__(self, idx): | ||
row = self.df.iloc[idx] | ||
image_path = row["filenames"] | ||
image_name = os.path.join(image_path.split("/")[-2], image_path.split("/")[-1]) | ||
|
||
image = cv2.imread(image_path) | ||
|
||
image = image / 255.0 | ||
image = image.transpose(2, 0, 1) | ||
|
||
image = torch.from_numpy(image).float() | ||
|
||
return image, image_name | ||
|
||
def encode_mask_to_rle(mask): | ||
""" | ||
mask: numpy array binary mask | ||
1 - mask | ||
0 - background | ||
Returns encoded run length | ||
""" | ||
pixels = mask.flatten() | ||
pixels = np.concatenate([[0], pixels, [0]]) | ||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 | ||
runs[1::2] -= runs[::2] | ||
return " ".join(str(x) for x in runs) | ||
|
||
|
||
def decode_rle_to_mask(rle, height, width): | ||
s = rle.split() | ||
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] | ||
starts -= 1 | ||
ends = starts + lengths | ||
img = np.zeros(height * width, dtype=np.uint8) | ||
|
||
for lo, hi in zip(starts, ends): | ||
img[lo:hi] = 1 | ||
|
||
return img.reshape(height, width) | ||
|
||
|
||
def test(data_loader, classes, best_model_dir, save_dir, is_csv=True, thr=0.5): | ||
print("Start inference ...") | ||
idx2class = {i: v for i, v in enumerate(classes)} | ||
|
||
model = torch.load(os.path.join(best_model_dir, "best_model.pt"))["model"] | ||
model = CustomModel(model) | ||
model = tta.SegmentationTTAWrapper(model, test_transform, merge_mode='mean') | ||
model.cuda() | ||
model.eval() | ||
|
||
rles = [] | ||
filename_and_class = [] | ||
with torch.no_grad(): | ||
for step, (images, image_names) in tqdm( | ||
enumerate(data_loader), total=len(data_loader) | ||
): | ||
images = images.cuda() | ||
# outputs = model(images)["out"] | ||
outputs = model(images) | ||
# restore original size | ||
outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear") | ||
outputs = torch.sigmoid(outputs) | ||
outputs = (outputs > thr).detach().cpu().numpy() | ||
for output, image_name in zip(outputs, image_names): | ||
for c, segm in enumerate(output): | ||
rle = encode_mask_to_rle(segm) | ||
rles.append(rle) | ||
filename_and_class.append(f"{idx2class[c]}_{image_name}") | ||
|
||
if is_csv: | ||
classes, filename = zip(*[x.split("_") for x in filename_and_class]) | ||
image_name = [os.path.basename(f) for f in filename] | ||
df = pd.DataFrame( | ||
{ | ||
"image_name": image_name, | ||
"class": classes, | ||
"rle": rles, | ||
} | ||
) | ||
|
||
df.to_csv(os.path.join(save_dir, "submission.csv"), index=False) | ||
print("CSV file creation successful") | ||
else: | ||
return rles, filename_and_class | ||
|
||
|
||
def main(args): | ||
|
||
save_csv = os.path.join(args.save_csv, args.exp_name) | ||
save_checkpoint = os.path.join(args.save_checkpoint, args.exp_name) | ||
|
||
test_dataset = XRayInferenceDataset_TTA(args.data_root, transforms=test_transform) | ||
|
||
test_loader = DataLoader( | ||
dataset=test_dataset, | ||
batch_size=2, | ||
shuffle=False, | ||
num_workers=2, | ||
drop_last=False, | ||
) | ||
|
||
# Inference | ||
test(test_loader, CLASSES, save_checkpoint, save_csv, args.make_csv) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
|
||
# Path | ||
parser.add_argument( | ||
"--data-root", | ||
type=str, | ||
default="../data", | ||
) | ||
parser.add_argument( | ||
"--save-checkpoint", | ||
type=str, | ||
default="./checkpoints", | ||
) | ||
parser.add_argument( | ||
"--save-csv", | ||
type=str, | ||
default="./predictions", | ||
) | ||
|
||
parser.add_argument("--exp-name", type=str, default="[test]ExpName") | ||
|
||
# Inference | ||
parser.add_argument("--make-csv", type=bool, default=True) | ||
|
||
args = parser.parse_args() | ||
|
||
main(args) |
Oops, something went wrong.