Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

An issue when testing on my own dataset #13

Open
wangbaoyuanGUET opened this issue Jun 18, 2024 · 0 comments
Open

An issue when testing on my own dataset #13

wangbaoyuanGUET opened this issue Jun 18, 2024 · 0 comments

Comments

@wangbaoyuanGUET
Copy link

wangbaoyuanGUET commented Jun 18, 2024

Hi, dear developers!
Here is my test code. Could you please tell me whether I wrote it correctly?

import torch
import numpy as np
from lib import networks
from lib import models
from lib.data.med_transforms import *
from lib.utils import set_seed, dist_setup, get_conf
from monai.losses import DiceCELoss, DiceLoss
from collections import defaultdict, OrderedDict
from monai.metrics import compute_meandice, compute_hausdorff_distance
from functools import partial
from lib.data.med_datasets import *
from lib.utils import SmoothedValue, concat_all_gather, LayerDecayValueAssigner
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
import nibabel as nib 

class Test:
    """Evaluate a pretrained segmentation model on the validation loader.

    The model architecture, checkpoint path, GPU index, ROI size and overlap
    are all read from ``args``. ``evaluate`` prints a per-case Dice score for
    label 1 and writes every prediction to disk as a NIfTI file.
    """

    # Original hard-coded output location, kept as the default so existing
    # invocations behave the same; pass ``output_dir`` to ``evaluate`` to override.
    DEFAULT_OUTPUT_DIR = '/home/lzb/wby/3D_Project/SelfMedMAEv2.0/Test_Output'

    def __init__(self, args):
        self.args = args
        self.model_name = args.proj_name
        # Neither attribute is used by ``evaluate``; both are kept so external
        # code that reads them keeps working.
        self.scaler = torch.cuda.amp.GradScaler()
        self.metric_funcs = OrderedDict([
            ('Dice', compute_meandice),
            ('HD', partial(compute_hausdorff_distance, percentile=95)),
        ])

    def build_model(self):
        """Create the loss, post-transforms and model, then load ``args.pretrain``.

        The checkpoint is expected to contain a ``'state_dict'`` entry.
        """
        # Use the instance's config. The original read a module-level global
        # ``args`` that only exists when this file is run as a script.
        args = self.args
        print(f"=> creating model {self.model_name}")

        self.loss_fn = DiceCELoss(to_onehot_y=True,
                                  softmax=True,
                                  squared_pred=True,
                                  smooth_nr=args.smooth_nr,
                                  smooth_dr=args.smooth_dr)
        self.post_pred, self.post_label = get_post_transforms(args)
        self.model = getattr(models, self.model_name)(
            encoder=getattr(networks, args.enc_arch),
            decoder=getattr(networks, args.dec_arch),
            args=args,
        )

        print("=> loading checkpoint")
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        # strict=False ignores missing/unexpected keys. Inspect the printed
        # message: a key-prefix mismatch would leave the model uninitialised.
        msg = self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(f"Loading messages: \n {msg}")
        print(f"=> Finish loading pretrained weights from {args.pretrain}")
        self.model.eval()
        self.model.cuda(args.gpu)

    def build_dataloader(self):
        """Build ``self.val_dataloader`` using the V2 test transforms."""
        print("=> creating test dataloader")
        args = self.args
        test_transform = get_testV2_transforms(args)
        self.val_dataloader = get_val_loader(args, args.batch_size, args.workers, test_transform)

    @torch.no_grad()
    def evaluate(self, output_dir=None):
        """Run sliding-window inference on every validation case.

        Args:
            output_dir: directory for the predicted NIfTI masks. Defaults to
                ``DEFAULT_OUTPUT_DIR``. It is created if it does not exist.

        Returns:
            The list of per-case Dice scores for label 1, in loader order.
        """
        import os  # local import: ``os`` is not imported explicitly at file level

        args = self.args
        if output_dir is None:
            output_dir = self.DEFAULT_OUTPUT_DIR
        self.build_dataloader()
        self.build_model()
        os.makedirs(output_dir, exist_ok=True)  # nib.save fails on a missing dir

        model = self.model
        # NOTE(review): roi_size is None when spatial_dim != 3. Confirm that
        # sliding_window_inference accepts that in the 2-D setup.
        roi_size = (args.roi_x, args.roi_y, args.roi_z) if args.spatial_dim == 3 else None
        dice_list_case = []
        print("=> Start Evaluating")

        for batch_data in self.val_dataloader:
            image = batch_data['image'].to(args.gpu, non_blocking=True)
            target = batch_data['label'].to(args.gpu, non_blocking=True)
            original_affine = batch_data["label_meta_dict"]["affine"][0].numpy()
            _, _, h, w, d = target.shape
            target_shape = (h, w, d)
            img_name = os.path.basename(batch_data["image_meta_dict"]["filename_or_obj"][0])

            # Only the network forward pass needs mixed precision.
            with torch.cuda.amp.autocast():
                logits = sliding_window_inference(image, roi_size=roi_size, sw_batch_size=4,
                                                  predictor=model, overlap=args.infer_overlap)
            # argmax over logits equals argmax over softmax(logits). Reducing on
            # the device avoids copying the full probability volume to the CPU.
            val_output = torch.argmax(logits, dim=1)[0].cpu().numpy().astype(np.uint8)
            target_np = target.cpu().numpy()[0, 0, :, :, :]
            val_output = resample_3d(img=val_output, target_size=target_shape)
            print(f'val_output shape is {val_output.shape} | target shape is {target_shape}')

            # Dice of the label-1 foreground only.
            mean_dice = dice(val_output == 1, target_np == 1)
            print(f"=>Evaluating on {img_name}, Mean Dice: {mean_dice}")
            dice_list_case.append(mean_dice)
            nib.save(
                nib.Nifti1Image(val_output.astype(np.uint8), original_affine),
                os.path.join(output_dir, img_name),
            )

        if dice_list_case:
            print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))
        else:
            print("Overall Mean Dice: no cases evaluated")
        return dice_list_case

def resample_3d(img, target_size):
    """Resize a 3-D array to ``target_size`` using nearest-neighbour sampling.

    Args:
        img: array with exactly three dimensions.
        target_size: the desired output shape ``(x, y, z)``.

    Returns:
        The array produced by ``scipy.ndimage.zoom`` with ``order=0``. Output
        voxels therefore take existing input values and are never interpolated.
    """
    import scipy.ndimage as ndimage

    src_shape = tuple(img.shape)
    src_x, src_y, src_z = src_shape
    dst_x, dst_y, dst_z = target_size
    pairs = ((dst_x, src_x), (dst_y, src_y), (dst_z, src_z))
    factors = tuple(float(dst) / float(src) for dst, src in pairs)
    return ndimage.zoom(img, factors, order=0, prefilter=False)


def dice(x, y):
    """Return the Dice coefficient ``2|x∩y| / (|x| + |y|)`` of two masks.

    Args:
        x: predicted mask (typically boolean) of any shape.
        y: reference mask with the same shape as ``x``.

    Returns:
        The Dice score as a float. If the reference ``y`` is empty, ``0.0`` is
        returned regardless of ``x``. Note that an empty prediction on an empty
        reference also scores 0.0 under this convention.
    """
    # A single np.sum already reduces over every axis; the original nested
    # np.sum(np.sum(np.sum(...))) calls were redundant.
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    intersect = np.sum(x * y)
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)

def compute_avg_metric(metric, meters, metric_name, batch_size, args):
    """Average a 2-D metric array and record the result in ``meters``.

    The array is first NaN-averaged over axis 0, then averaged over the
    remaining axis. Presumably ``metric`` is (samples, classes); confirm this
    against the caller.

    Args:
        metric: 2-D array-like of metric values; NaNs are ignored.
        meters: mapping from name to an object with ``update(value=..., n=...)``.
        metric_name: key under which the averaged value is recorded.
        batch_size: weight passed to ``update`` as ``n``.
        args: config object; only ``args.dataset`` is read.

    Raises:
        ValueError: if ``metric`` is not 2-D. This replaces an ``assert``,
            which ``python -O`` strips.
    """
    if len(metric.shape) != 2:
        raise ValueError(f"metric must be 2-D, got shape {tuple(metric.shape)}")

    per_column = np.nanmean(metric, axis=0)
    if args.dataset == 'btcv':
        # For BTCV, columns whose mean is NaN or inf are masked out before the
        # final average.
        cls_avg_metric = np.mean(np.ma.masked_invalid(per_column))
    else:
        cls_avg_metric = np.nanmean(per_column)
    meters[metric_name].update(value=cls_avg_metric, n=batch_size)

if __name__ == '__main__':
    # Script entry point: load the config, force test mode with a binary
    # (background + foreground) label space, then run evaluation.
    # The name ``args`` is kept at module level on purpose; other code in this
    # file may read it as a global.
    args = get_conf()
    args.test, args.num_classes = True, 2
    Test(args).evaluate()
    
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant