Issue on Tutorial Annotation - UMAP is not from Learned Cell Embeddings #259

Open
ManuelSokolov opened this issue Sep 24, 2024 · 2 comments

Comments

@ManuelSokolov

ManuelSokolov commented Sep 24, 2024

Dear scGPT creators,

In the cell-type annotation tutorial, there is an inconsistency at the UMAP plotting step. The text reads:

In the cell-type annotation task, the fine-tuned scGPT predicts cell-type labels for the query set as inference. The model performance is evaluated on standard classification metrics. Here we visualize the predicted labels over the scGPT cell embeddings, and present the confusion matrix for detailed classification performance on the cell-group level.

However, the UMAP is computed from the single-cell gene expression values (the standard UMAP embeddings), not from the learned scGPT cell embeddings.

To obtain the cell embeddings, the evaluate function can be modified as follows:

import numpy as np
import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader

# Assumes the globals defined earlier in the tutorial notebook: device, vocab, pad_token,
# config, CLS, DAB, INPUT_BATCH_LABELS, do_sample_in_train, criterion_cls, criterion_dab,
# dab_weight and epoch.


def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False):
    """
    Evaluate the model on the evaluation data.

    If return_raw is True, return the predicted labels together with the scGPT
    cell embeddings (one row per cell, in the same order as the predictions).
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []
    output_cell_embeddings = []  # one (batch_size, embedding_dim) array per batch
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)
            celltype_labels = batch_data["celltype_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,  # evaluation does not need CLS or CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                    #generative_training = False,
                )
                output_values = output_dict["cls_output"]  # cell-type logits
                loss = criterion_cls(output_values, celltype_labels)

                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            # Keep the learned cell embeddings; .float() undoes a possible fp16 cast from autocast
            output_cell_embeddings.append(output_dict["cell_emb"].float().cpu().numpy())

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == celltype_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)
            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/err": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + dab_weight * total_dab) / total_num,
            "epoch": epoch,
        },
    )

    if return_raw:
        # Embedding rows line up with the predictions because both are appended per batch
        return (
            np.concatenate(predictions, axis=0),
            np.concatenate(output_cell_embeddings, axis=0),
        )

    return total_loss / total_num, total_error / total_num
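What matters in the modified evaluate is that the embeddings are collected in the same order as the predictions, so row i of the embedding matrix belongs to the cell whose predicted label is predictions[i]. A minimal NumPy sketch of that collection pattern follows; random arrays stand in for output_dict["cls_output"] and output_dict["cell_emb"], and the batch sizes, class count and 512-dimensional embedding width are illustrative assumptions rather than values taken from scGPT's configuration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, emb_dim = 5, 512  # illustrative sizes, not scGPT's actual config
batch_sizes = [64, 64, 23]   # last batch is smaller, as with a DataLoader without drop_last

predictions, cell_embeddings = [], []
for bs in batch_sizes:
    # Stand-ins for the model outputs of one batch
    cls_output = rng.normal(size=(bs, n_classes))
    cell_emb = rng.normal(size=(bs, emb_dim)).astype(np.float16)  # autocast may yield fp16

    predictions.append(cls_output.argmax(axis=1))
    # Append the whole batch at once; cast to float32 for a uniform dtype downstream
    cell_embeddings.append(cell_emb.astype(np.float32))

# One concatenate per list keeps rows aligned: row i of each array is the same cell
predictions = np.concatenate(predictions, axis=0)
cell_embeddings = np.concatenate(cell_embeddings, axis=0)

print(predictions.shape, cell_embeddings.shape)  # (151,) (151, 512)
```

Whether rows are appended one at a time or a batch at a time, concatenating once at the end yields a single (n_cells, emb_dim) array that can be assigned directly to an AnnData obsm slot.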

Then adapt the test function so that it also returns the cell embedding for each prediction:

predictions, embeddings = test(best_model, adata_test)

This way, the UMAP can be computed from the cell embeddings rather than from the single-cell gene expression values. The paper includes an illustration of the embedding clustering; I believe it was generated in the same way, since the embeddings can provide a better representation of the cells than the normalized gene expression per cell.
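UMAP is built on a k-nearest-neighbor graph, so the matrix used to compute neighbors determines what the plot shows. In scanpy that choice is made through the use_rep argument of sc.pp.neighbors: assuming the embeddings are stored under a hypothetical key such as adata_test.obsm["X_scGPT"], one would call sc.pp.neighbors(adata_test, use_rep="X_scGPT") and then sc.tl.umap(adata_test). The dependency-free sketch below uses hypothetical helpers (knn_indices, neighbor_purity) to compare how well two representations group cells of the same type in their neighbor graphs. The synthetic arrays are arbitrary and only make the two inputs differ visibly; they say nothing about how real scGPT embeddings compare to expression values.

```python
import numpy as np


def knn_indices(X: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest neighbors (Euclidean) of each row, excluding the row itself."""
    sq = (X ** 2).sum(axis=1)
    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def neighbor_purity(X: np.ndarray, labels: np.ndarray, k: int = 15) -> float:
    """Fraction of each cell's k neighbors that share its label, averaged over cells."""
    neighbors = knn_indices(X, k)
    return float((labels[neighbors] == labels[:, None]).mean())


rng = np.random.default_rng(0)
n_cells = 200
labels = np.repeat([0, 1], n_cells // 2)  # two synthetic cell types

# Synthetic stand-ins: a weak type signal spread over many "genes" ...
expression = rng.normal(size=(n_cells, 2000)) + 0.05 * labels[:, None]
# ... and a strong signal in a low-dimensional "embedding"
embedding = rng.normal(size=(n_cells, 32)) + 3.0 * labels[:, None]

print(f"expression neighbor purity: {neighbor_purity(expression, labels):.2f}")
print(f"embedding neighbor purity:  {neighbor_purity(embedding, labels):.2f}")
```

On real data, the same check could be run on adata_test.X (densified if sparse) and on the returned embeddings, using the annotated cell types as labels.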

Best Regards.

Manuel

ManuelSokolov changed the title from "Issue on Tutorial Annotation" to "Issue on Tutorial Annotation - UMAP is not from Learned Cell Embeddings" on Sep 25, 2024
@subercui
Member

subercui commented Oct 2, 2024

Hi, thanks! I am trying to understand your comment.

However, the UMAP is computed from the single-cell gene expression values (the standard UMAP embeddings), not from the learned scGPT cell embeddings.

I think that is indeed the expected behavior of the notebook, since plotting a UMAP of the cell embeddings is not the focus of the cell annotation tutorial. That said, are you suggesting we could also use the cell embeddings to plot a UMAP? If so, I certainly agree and appreciate your code suggestions.

@ManuelSokolov
Author

Yes, exactly. I just opened a pull request with the code for the change :)
