
Issue: Challenges in Using semantic_sam_trainer for Fine-Tuning Semantic Segmentation #847

Open
adityajbir opened this issue Feb 4, 2025 · 4 comments

Comments

@adityajbir

I successfully implemented the instance segmentation function provided in the Micro-SAM repository. However, while using the semantic_sam_trainer function to fine-tune the model for semantic segmentation on custom images, I encountered several issues. Below, I detail the problems, fixes made to the source code, and the remaining unresolved issue.

Issue 1: RuntimeError: "host_softmax" not implemented for 'Bool'

Problem: When computing the Dice Loss using a custom loss function, the following error occurred:

RuntimeError: "host_softmax" not implemented for 'Bool'

This happened because the prediction tensor was treated as a boolean type, but the torch.softmax function requires a floating-point input.

Fix: To address this, I explicitly converted the pred tensor to a floating-point type before applying torch.softmax.

if self.softmax:
    pred = torch.softmax(pred.float(), dim=1)
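A minimal reproduction of the error and the fix (a sketch; the tensor shapes are illustrative, not taken from the actual training run):

```python
import torch

# torch.softmax rejects boolean input with the "host_softmax" error,
# so the prediction tensor must be cast to float first.
pred = torch.zeros(2, 4, 8, 8, dtype=torch.bool)  # stand-in for a bool prediction

try:
    torch.softmax(pred, dim=1)
except RuntimeError as err:
    print(f"softmax on bool fails: {err}")

# Casting to float makes softmax well-defined; each pixel's class
# probabilities now sum to 1.
probs = torch.softmax(pred.float(), dim=1)
assert torch.allclose(probs.sum(dim=1), torch.ones(2, 8, 8))
```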

Issue 2: ValueError: Expected input and target of same shape

Problem: When comparing the input tensor against class indices in _one_hot_encoder, the output shape of the tensor was [B, H*num_classes, W] instead of the expected [B, num_classes, H, W]. This caused a mismatch in shapes during the loss computation.

Fix: I modified the _one_hot_encoder function to unsqueeze the tensor along the channel dimension (axis 1) to align the output with the expected shape.
Modified Code:
Before:

temp_prob = input_tensor == i
tensor_list.append(temp_prob)

After:

temp_prob = (input_tensor == i).unsqueeze(1)  # Shape: [B, 1, H, W]
tensor_list.append(temp_prob)

Final concatenation:

output_tensor = torch.cat(tensor_list, dim=1)  # Shape: [B, num_classes, H, W]
return output_tensor.float()
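Putting the changes together, the patched helper could look roughly like this (a sketch of the modified encoder as a standalone function, not the exact micro-sam source):

```python
import torch

def one_hot_encoder(input_tensor: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Turn a [B, H, W] index mask into a [B, n_classes, H, W] one-hot tensor."""
    tensor_list = []
    for i in range(n_classes):
        # Without the unsqueeze, each [B, H, W] bool slice would be
        # concatenated along dim 1 into [B, H * n_classes, W].
        temp_prob = (input_tensor == i).unsqueeze(1)  # [B, 1, H, W]
        tensor_list.append(temp_prob)
    output_tensor = torch.cat(tensor_list, dim=1)     # [B, n_classes, H, W]
    return output_tensor.float()

labels = torch.tensor([[[0, 1], [2, 3]]])             # [1, 2, 2] index mask
one_hot = one_hot_encoder(labels, n_classes=4)        # [1, 4, 2, 2]
```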

Issue 3: RuntimeError: "host_softmax" not implemented for 'Bool' (Recurrence)

Problem: The softmax error reappeared due to the masks tensor being of boolean type. This issue arose inside the _compute_loss function.

Fix: I converted the masks tensor to a floating-point type within _compute_loss.
Modified Code:

masks = masks.float()

Issue 4: AssertionError: Class number out of range

Problem: While running the code with the assumption of 3 classes, the following error occurred:

Assertion `t >= 0 && t < n_classes` failed.

This indicates that one or more pixels in the target tensor had values outside the valid range [0, num_classes - 1]. The error occurred specifically when the model was configured for 3 classes.

Debugging Steps:

Verified that the ground-truth masks in the dataset contain 4 classes.
Updated the code to handle 4 classes; however, this led to a shape mismatch error in the Dice loss computation.
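One way to catch this earlier is a small range check on the masks before training starts (a sketch; check_label_range is an illustrative helper, not part of micro-sam):

```python
import torch

def check_label_range(mask: torch.Tensor, n_classes: int) -> None:
    """Fail fast if any pixel label falls outside [0, n_classes - 1]."""
    lo, hi = int(mask.min()), int(mask.max())
    if lo < 0 or hi >= n_classes:
        raise ValueError(
            f"labels span [{lo}, {hi}] but the model expects [0, {n_classes - 1}]"
        )

mask = torch.tensor([[0, 1], [2, 3]])  # a 4-class mask (background + 3 cell types)
check_label_range(mask, n_classes=4)   # passes
# check_label_range(mask, n_classes=3) would raise, reproducing Issue 4
```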

Issue 5: ValueError: Expected input and target of same shape

Problem: After resolving the previous issues, a ValueError was raised:

ValueError: Expected input and target of same shape, got: torch.Size([2, 3, 488, 685]), torch.Size([2, 4, 488, 685]).

This occurred because the prediction tensor had 3 channels (from the originally assumed class count), while the one-hot target tensor now had 4 channels.
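One way to guard against this mismatch is to derive the class count from the masks once and reuse it for both the model head and the one-hot targets (a hypothetical sketch, not the trainer's actual setup code):

```python
import torch

# Derive num_classes from the ground-truth masks rather than hard-coding it,
# so the prediction channels and the one-hot targets always agree.
masks = torch.randint(0, 4, (2, 488, 685))      # class ids 0..3
masks[0, 0, :4] = torch.tensor([0, 1, 2, 3])    # guarantee every id occurs
num_classes = int(masks.max()) + 1              # 4 = background + 3 cell types

logits = torch.randn(2, num_classes, 488, 685)  # model output channels
targets = torch.nn.functional.one_hot(masks, num_classes).permute(0, 3, 1, 2)
assert logits.shape == targets.shape            # no 3-vs-4 channel mismatch
```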

@anwai98
Copy link
Contributor

anwai98 commented Feb 5, 2025

Hi @adityajbir,

Thanks for your interest in micro-sam.

Before I look into the details and try to reproduce the issues, could you elaborate on the problem statement?
i.e. what are your input images, what are the corresponding labels, and what is the expected outcome?

This would be a good starting point for us to discuss further details!

@adityajbir
Author

adityajbir commented Feb 6, 2025

Hi @anwai98 ,

For the input images, I am using PNG images of cells. For the masks, I annotate those images with LabelMe (an annotation tool), which saves the annotations as JSON files; I then convert the JSON files into mask PNGs. The labels should contain 4 classes: background plus 3 cell types. As for the expected outcome, I am currently trying to get the training phase working, where I pass all the necessary parameters to the SemanticSamTrainer. Once training is finished, I will evaluate the model's performance against a test set and visualize the results.
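One quick way to verify the converted mask PNGs is to check which class ids they actually contain (a sketch; mask_class_ids and the filename are illustrative, not part of the described pipeline):

```python
import numpy as np
from PIL import Image

def mask_class_ids(path: str) -> list[int]:
    """Report the distinct class ids stored in an exported mask PNG."""
    mask = np.array(Image.open(path))
    return sorted(int(v) for v in np.unique(mask))

# For this dataset the expectation would be background + 3 cell types, i.e.
#     mask_class_ids("some_mask.png") == [0, 1, 2, 3]
```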

@anwai98
Contributor

anwai98 commented Feb 17, 2025

Hi @adityajbir,

I am sorry for the late response. I got engaged with a few things here and there.

For us to get started with a proper discussion, I created example scripts for semantic segmentation (from the context, I am assuming the task is segmenting cells in 2d images?): https://github.com/computational-cell-analytics/medico-sam/blob/master/examples/semantic_segmentation/train_semantic_segmentation_2d.py

(I think you are probably using the same structure; if not, let me know, and we can identify and discuss the differences.)

Re: task: Okay. I understand the context now. Thanks for sharing!

Coming to a few pointers from the problems you encountered:

  • Can you ensure that your targets (i.e. the ground-truth annotations) are float tensors? (Keeping an eye on passing the inputs properly is important. In the future, I plan to add sanity checks so that users are more aware of the batched inputs the dataloader provides for training.)
  • The size mismatch between targets and predictions is strange to me as well. If you look at the get_dataloaders function under the link above, its docstring briefly explains how the inputs (both the images and the corresponding labels) are expected to look. Can you verify whether that already takes care of things for you?

My hunch says that if these two things are taken care of, most of the issues you reported should be fixed. Let me know if I missed anything.
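A sanity check along those lines could look roughly like this (a sketch with illustrative names, not an existing micro-sam API):

```python
import torch

def check_batch(images: torch.Tensor, targets: torch.Tensor, n_classes: int) -> None:
    """Lightweight sanity check on one batch from the training dataloader."""
    assert images.dtype.is_floating_point, f"images should be float, got {images.dtype}"
    assert targets.dtype.is_floating_point, f"targets should be float, got {targets.dtype}"
    assert images.shape[-2:] == targets.shape[-2:], "spatial sizes must match"
    assert int(targets.max()) <= n_classes - 1, "target has an out-of-range class id"

# Example batch: float images and float index masks for 4 classes.
check_batch(
    images=torch.rand(2, 1, 488, 685),
    targets=torch.randint(0, 4, (2, 488, 685)).float(),
    n_classes=4,
)
```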

(And I sincerely apologize once again for the super late response)

@adityajbir
Author

Hi @anwai98 ,

Thank you for the response. I will try the suggested changes and get back to you with my findings in a few days.
