Skip to content

Commit

Permalink
Add unet results on llgmri
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Nov 19, 2023
1 parent ac27fb7 commit 44330a6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ bash scripts/run_quantization.sh
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) | 69.10 | 69.10 | 1.98 | 61.74 | 68.43 |

### llgmri (val, Dice coefficient)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|---------------|-------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [brain-segmentation-pytorch](https://github.com/mateuszbuda/brain-segmentation-pytorch) | 90.89 | 90.88 | 80.98 | 90.95 | 90.85 |
| unet-bilinear | [brain-segmentation-pytorch](https://github.com/mateuszbuda/brain-segmentation-pytorch) | 91.05 | 91.05 | 77.51 | 91.04 | 91.03 |

### ucf101 (val-split1, top-1 acc)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
Expand Down
15 changes: 10 additions & 5 deletions models/segmentation/lggmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
from skimage.transform import resize
from skimage.io import imread
from torch.utils.data import DataLoader, Dataset
from models.segmentation.utils import apply_conv_transp_approx
from tqdm import tqdm

class BrainModelWrapper(TorchModelWrapper):
def load_model(self, eval=True):
def load_model(self, eval=True, approx_transpose_conv=True):
self.model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
in_channels=3, out_channels=1, init_features=32, pretrained=True)

if approx_transpose_conv:
apply_conv_transp_approx(self.model)

if torch.cuda.is_available():
self.model = self.model.cuda()

Expand All @@ -24,7 +29,7 @@ def load_data(self, batch_size, workers):
# https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

LGGMRI_PATH = os.environ.get("LGGMRI_PATH", os.path.expanduser("~/dataset/lgg-mri-segmentation/kaggle_3m"))

val_dataset = BrainSegmentationDataset(
images_dir=LGGMRI_PATH,
subset="validation",
Expand All @@ -36,7 +41,7 @@ def load_data(self, batch_size, workers):
val_dataset, batch_size=batch_size, drop_last=False, num_workers=workers
)

self.data_loaders['validate'] = val_loader
self.data_loaders['validate'] = val_loader
self.data_loaders['calibrate'] = val_loader # todo: support calibrate

def inference(self, mode='validate'):
Expand Down Expand Up @@ -72,7 +77,7 @@ def inference(self, mode='validate'):
loader.dataset.patients,
)
dsc_dist = dsc_distribution(volumes)
dsc_dist_plot = plot_dsc(dsc_dist)
imsave(args.figure, dsc_dist_plot)
Expand All @@ -89,7 +94,7 @@ def inference(self, mode='validate'):
filepath = os.path.join(args.predictions, filename)
imsave(filepath, image)
'''

mean_dsc = np.mean( dsc_per_volume(pred_list, true_list, loader.dataset.patient_slice_index) )
print("Mean DSC:", mean_dsc)
return mean_dsc
Expand Down

0 comments on commit 44330a6

Please sign in to comment.