How to change number of classes for already trained retinanet model #28

Open
Monk5088 opened this issue Sep 24, 2022 · 3 comments

@Monk5088

I have trained my RetinaNet model from the object detection library, and now I want to change the number of classes for the next dataset.
I have found a way to do this in PyTorch by replacing the classification head of RetinaNet. Can anyone help me do the same for retinanet.py from object-detection-fastai?

import math

import torch
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetHead, RetinaNetClassificationHead

num_classes = 3  # placeholder: set to the number of classes in the new dataset

weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
model = retinanet_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)

# replace classification layer
out_channels = model.head.classification_head.conv[0].out_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorch code
# assign cls head to model
model.head.classification_head.cls_logits = cls_logits
@ChristianMarzahl
Owner

Dear @Monk5088,

If your dataset changes, and with it the number of classes, the code should adapt to the new number of classes.

In the following example, the parameter n_classes defines the number of classes.

RetinaNet(encoder, n_classes=data.train_ds.c, n_anchors=3, sizes=[32], chs=8, final_bias=-4., n_conv=3)

@Monk5088
Author

Monk5088 commented Sep 27, 2022

Dear @ChristianMarzahl ,
I have tried changing the number of classes when initialising RetinaNet, but when I call learner.load(), it throws a weight mismatch error.
CODE:

batch_size = 64

do_flip = True
flip_vert = True 
max_rotate = 90 
max_zoom = 1.1 
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75 
p_lighting = 0.75 

tfms = get_transforms(do_flip=do_flip,
                      flip_vert=flip_vert,
                      max_rotate=max_rotate,
                      max_zoom=max_zoom,
                      max_lighting=max_lighting,
                      max_warp=max_warp,
                      p_affine=p_affine,
                      p_lighting=p_lighting)
train, valid = ObjectItemListSlide(train_images) ,ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate,num_workers=0).normalize() 

Here the train dataset is as follows:
SlideLabelList (100 items)
x: ObjectItemListSlide
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: SlideObjectCategoryList
ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256)
Path: .
train_images is a list of object_detection_fastai.helper.wsi_loader.SlideContainer objects created with the following function:

def create_wsi_container(annotations_df: pd.DataFrame):

    container = []

    for image_name in tqdm(annotations_df["file_name"].unique()):

        image_annos = annotations_df[annotations_df["file_name"] == image_name]

        bboxes = [box   for box   in image_annos["box"]]
        labels = [label for label in image_annos["cat"]]

        container.append(SlideContainer(image_folder/image_name, y=[bboxes, labels], level=res_level,width=patch_size, height=patch_size, sample_func=sample_function))

    return container

CODE FOR LEARNER:

backbone = "ResNet34" #["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]

backbone_model = models.resnet18
if backbone == "ResNet34":
    backbone_model = models.resnet34
if backbone == "ResNet50":
    backbone_model = models.resnet50
if backbone == "ResNet101":
    backbone_model = models.resnet101
if backbone == "ResNet152":  # there is no resnet150; the deepest variant is resnet152
    backbone_model = models.resnet152

pre_trained_on_imagenet = False
encoder = create_body(backbone_model, pre_trained_on_imagenet, -2)


loss_function = "FocalLoss" 

if loss_function == "FocalLoss":
    crit = RetinaNetFocalLoss(anchors)


channels = 128 


final_bias = -4 


n_conv = 3 
model = RetinaNet(encoder, n_classes=3, 
                  n_anchors=len(scales) * len(ratios), 
                  sizes=[size[0] for size in sizes], 
                  chs=channels, # number of channels in the head's conv layers
                  final_bias=final_bias,
                  n_conv=n_conv # Number of hidden layers
                  )
voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
learn = Learner(data, model, loss_func=crit, 
                callback_fns=[BBMetrics,ShowGraph,CSVLogger,partial(GradientClipping, clip=2.0)],metrics=[voc])
learn.load("PATH/to/.pth",strict=False) 

ERROR:

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in load(self, file, device, strict, with_opt, purge, remove_module)
    271             model_state = state['model']
    272             if remove_module: model_state = remove_module_load(model_state)
--> 273             get_model(self.model).load_state_dict(model_state, strict=strict)
    274             if ifnone(with_opt,True):
    275                 if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1500 

RuntimeError: Error(s) in loading state_dict for RetinaNet:
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([2, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]).
    size mismatch for classifier.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
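
A note on the error above: in PyTorch, strict=False only tolerates missing or unexpected keys; a tensor whose name matches but whose shape differs still raises. One common workaround is to drop the mismatched tensors before loading, so the encoder and the hidden head layers are reused and only the final classifier layer stays freshly initialised. The following is a minimal sketch of that idea in plain PyTorch, not the library's own method; load_matching_weights and make_head are hypothetical names, and the toy head only stands in for the real RetinaNet.

```python
import torch
from torch import nn


def load_matching_weights(model: nn.Module, checkpoint_state: dict) -> list:
    """Copy only tensors whose name and shape match the model; return skipped keys."""
    model_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in model_state and tensor.shape == model_state[name].shape
    }
    skipped = [name for name in checkpoint_state if name not in compatible]
    # strict=False tolerates the keys dropped above; the remaining shapes already match.
    model.load_state_dict(compatible, strict=False)
    return skipped


def make_head(n_classes: int) -> nn.Sequential:
    """Toy stand-in for a classifier head (hypothetical, not the library's RetinaNet)."""
    return nn.Sequential(
        nn.Conv2d(128, 128, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(128, n_classes, 3, padding=1),
    )


old_head = make_head(2)  # plays the role of the 2-class checkpoint
new_head = make_head(3)  # plays the role of the 3-class model
skipped = load_matching_weights(new_head, old_head.state_dict())
print(skipped)  # ['2.weight', '2.bias']
```

With a fastai v1 checkpoint, the traceback suggests the weights sit under the 'model' key of the saved file, so that dictionary would be the input here. The skipped layers keep their random initialisation, so they would need training on the new dataset before their predictions mean anything.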

@Monk5088
Author

I trained it on a 2-class dataset, but I only want it to predict on the 3-class dataset. Both datasets share the same first and last class in the databunch, i.e., the first databunch I trained my RetinaNet on has the following classes:
['background', 'mitosis']
The new dataset on which I need predictions contains the following classes:
['background', 'hard negative', 'mitosis']
So, is it possible for my model to predict only mitosis on the new dataset?
Thanks
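
One possible approach, not confirmed in this thread: keep the model at n_classes=2 so the checkpoint loads unchanged, and translate the predicted labels to the new dataset's class list by name. The sketch below assumes the predicted labels are indices into the class list used for training; build_index_map and the predicted list are illustrative, not part of object-detection-fastai.

```python
old_classes = ['background', 'mitosis']                   # classes the model was trained on
new_classes = ['background', 'hard negative', 'mitosis']  # classes of the new dataset


def build_index_map(trained: list, target: list) -> dict:
    """Map each trained class index to the index of the same class name in target."""
    return {i: target.index(name) for i, name in enumerate(trained)}


index_map = build_index_map(old_classes, new_classes)

# Hypothetical predicted labels from the 2-class model (indices into old_classes).
predicted = [1, 0, 1]
remapped = [index_map[p] for p in predicted]

print(index_map)  # {0: 0, 1: 2}
print(remapped)   # [2, 0, 2]
```

Because the model never saw 'hard negative' during training, it cannot output that class; with this mapping every detection lands on 'mitosis' in the new class list.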
