Adding NTXENT loss #204
edyoshikun commented on Nov 13, 2024
- Adding the NT-Xent loss alongside the triplet margin loss
The correctness of the loss function can be checked against community implementations (e.g., this one from PML). Alternatively, we can add it as a dependency to lower the maintenance overhead.
I double-checked our previous implementation against the one from
@edyoshikun, good call to use
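To make such a cross-check concrete, here is a minimal pure-Python sketch of the SimCLR-style NT-Xent loss, written only as a reference to compare numbers against. It is not VisCy's code or PML's implementation. It assumes cosine similarity, an illustrative temperature of 0.5, and that anchor `i` and positive `i` form the positive pair, with every other sample in the combined batch of 2N serving as a negative. When each label appears exactly twice, PML's `NTXentLoss` (given the same embeddings, labels `[0..N-1, 0..N-1]`, and the same `temperature`) should, we assume, produce a matching value; that is worth confirming numerically rather than taking on faith.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent(anchors, positives, temperature=0.5):
    """SimCLR-style NT-Xent loss for N (anchor, positive) pairs.

    Sample i in `anchors` and sample i in `positives` form a positive pair;
    every other sample in the combined batch of 2N acts as a negative, so no
    explicit negative sample has to be loaded.
    """
    if len(anchors) != len(positives) or not anchors:
        raise ValueError("need the same, non-zero number of anchors and positives")
    z = list(anchors) + list(positives)
    n = len(anchors)
    total = 0.0
    for i in range(2 * n):
        j = i + n if i < n else i - n  # index of sample i's positive partner
        # Scaled similarities from sample i to every other sample (self excluded).
        logits = [
            cosine_similarity(z[i], z[k]) / temperature
            for k in range(2 * n)
            if k != i
        ]
        # Numerically stable log-sum-exp over the denominator terms.
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        positive_logit = cosine_similarity(z[i], z[j]) / temperature
        total += log_denominator - positive_logit
    return total / (2 * n)
```

With orthogonal unit vectors the loss is lower when each anchor is paired with its matching positive than when the pairs are swapped, which is the behavior the objective is meant to reward.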
The changes look good to me! Please merge after comparing the training logs and the embeddings learned with the triplet loss and the NT-Xent loss.
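For that comparison it helps to have the baseline objective written out as well. The sketch below is a minimal pure-Python triplet margin loss, assuming Euclidean distance and a margin of 1.0 (the defaults of `torch.nn.TripletMarginLoss`, whose small `eps` term is omitted here). It is illustrative, not VisCy's implementation. The key design difference from NT-Xent is visible in the signature: every anchor needs an explicitly sampled negative, whereas NT-Xent draws its negatives from the rest of the batch.

```python
import math


def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchors, positives, negatives, margin=1.0):
    """Mean hinge loss max(d(a, p) - d(a, n) + margin, 0) over a batch."""
    if not (len(anchors) == len(positives) == len(negatives)) or not anchors:
        raise ValueError("need the same, non-zero number of anchors, positives and negatives")
    losses = [
        # Zero once the negative is at least `margin` farther than the positive.
        max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)
```

A well-separated negative contributes nothing, while a negative inside the margin is penalized in proportion to how far inside it sits.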
@alishbaimran, this is the PR with the NT-Xent loss.
@@ -602,7 +602,8 @@ def _train_transform(self) -> list[Callable]:
         else:
             self.augmentations = []
         if z_scale_range is not None:
-            if isinstance(z_scale_range, float):
+            if isinstance(z_scale_range, (float, int)):
+                z_scale_range = float(z_scale_range)
Should this change be merged with this PR?
Do you see any issue? I was running into errors when the value was an int rather than a float.
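The root cause is that Python's `isinstance(1, float)` is `False`, so an integer such as `1` coming from a YAML or CLI config silently skips a float-only branch. A minimal sketch of the normalization the diff performs follows. The helper name `normalize_z_scale_range` is hypothetical, and both the explicit `bool` rejection and the handling of sequence values are assumptions added for illustration, not part of the PR.

```python
def normalize_z_scale_range(z_scale_range):
    """Hypothetical helper: coerce a scalar z_scale_range to a float.

    isinstance(1, float) is False, so an int from a config would skip a
    float-only check. Accepting (float, int) and casting, as the diff in this
    PR does, makes 1 and 1.0 behave the same.
    """
    if z_scale_range is None:
        return None
    # bool is a subclass of int; reject it so True is not read as 1.0.
    if isinstance(z_scale_range, bool):
        raise TypeError("z_scale_range must be a number or a sequence of numbers")
    if isinstance(z_scale_range, (float, int)):
        return float(z_scale_range)
    # Assumption: non-scalar values are (min, max)-style sequences; cast each entry.
    return tuple(float(v) for v in z_scale_range)
```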
@edyoshikun, I am receiving the following error with classical contrastive learning:
The tracks zarr I am using has a unique track ID for each instance. Do you know what could be causing the issue?
@edyoshikun, you can find the whole log here:
@Soorya19Pradeep, can you post the whole error? I assume this is coming from the sample detaching method, but I am not sure. This might be a dataloader issue. Can you check that your dataloader for classical contrastive learning returns the right samples?
@edyoshikun, here is the error:
How can I check the samples returned from the dataloader?
This could happen if your dataset has a length of 0, i.e., no valid patch is available.
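One way to answer the question above is to pull a batch by hand before training. The sketch below assumes the data module follows the Lightning `LightningDataModule` interface (`setup(stage)` and `train_dataloader()`) and that the loader, like a PyTorch `DataLoader`, exposes the wrapped dataset as `.dataset`. The helper name `check_datamodule` is hypothetical, and no particular batch keys are assumed.

```python
def check_datamodule(dm, stage="fit", max_batches=1):
    """Hypothetical sanity check for a LightningDataModule-style object.

    Fails early if the training set is empty (e.g. no valid patches), which
    would otherwise surface later as a less obvious error, and returns a
    short summary (keys and shapes) of the first batch(es).
    """
    dm.setup(stage)
    loader = dm.train_dataloader()
    dataset = getattr(loader, "dataset", None)
    if dataset is not None and hasattr(dataset, "__len__") and len(dataset) == 0:
        raise ValueError("training dataset is empty: no valid patches were found")
    summaries = []
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        if isinstance(batch, dict):
            # Tensors report their shape; anything else reports its type name.
            summaries.append(
                {k: getattr(v, "shape", type(v).__name__) for k, v in batch.items()}
            )
        else:
            summaries.append(type(batch).__name__)
    if not summaries:
        raise ValueError("train_dataloader() yielded no batches")
    return summaries
```

Printing the returned summary for the triplet and classical contrastive configurations side by side should show quickly whether one of them yields no samples or differently shaped ones.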
@Soorya19Pradeep, have you recently used this branch with the latest commits?
Branch updated from a190be5 to 2b4a359.
Hi @edyoshikun, I realized that the nearest-neighbor metric that @ziw-liu contributed and the ALFI_displacement.py that @alishbaimran contributed are already in this branch. Multiple files in this PR were already merged into main (e.g., the Gradio example and the nearest-neighbor distance metric), but the diff relative to main shows them as new files. Once you resolve this, I think we should merge this branch into main and work from there. We can surface more specific bugs and add features related to organelle phenotyping after we merge this branch.
examples/gradio/demo_gradio.py (outdated)
This file was merged into main with the previous PR. It may be better to merge main into this branch.
def cross_dissimilarity(features: ArrayLike, metric: str = "cosine") -> NDArray:
    """Dissimilarity/distance between each pair of samples in the features.
Perhaps distance_matrix or similarity_matrix is a more intuitive name for the method. The distance is measured by a metric, which may be cosine, euclidean, ...
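A minimal sketch of what such a pairwise distance matrix computes, using the `distance_matrix` name suggested in the review. The name, the pure-Python list-of-lists representation, and the two supported metrics are assumptions for illustration; the function under review takes an `ArrayLike` and returns an `NDArray`, and a vectorized version would more likely rely on `scipy.spatial.distance.cdist`, which accepts `metric="cosine"` or `metric="euclidean"`.

```python
import math


def distance_matrix(features, metric="cosine"):
    """Pairwise distances between rows of `features` (a list of vectors).

    metric="cosine" returns 1 - cosine similarity; metric="euclidean" returns
    the L2 distance. The result is a symmetric N x N list of lists with zeros
    on the diagonal.
    """

    def cosine(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        return 1.0 - dot / norms

    def euclidean(u, v):
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

    metrics = {"cosine": cosine, "euclidean": euclidean}
    if metric not in metrics:
        raise ValueError(f"unsupported metric {metric!r}; choose from {sorted(metrics)}")
    fn = metrics[metric]
    n = len(features)
    out = [[0.0] * n for _ in range(n)]
    # Fill the upper triangle and mirror it, since both metrics are symmetric.
    for i in range(n):
        for j in range(i + 1, n):
            out[i][j] = out[j][i] = fn(features[i], features[j])
    return out
```

Note that cosine distance ignores vector length (`[1, 0]` and `[2, 0]` are at distance 0) while Euclidean distance does not, which is one reason the metric should stay an explicit argument.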
The contributions in this file are already on main: https://github.com/mehta-lab/VisCy/blob/main/viscy/representation/evaluation/clustering.py. Why is the diff showing this as a new file?
Probably need to rebase this onto main.
* translation: fix validation loss aggregation (#202)
* exposing prefetch and persistent worker (#203)
* metrics for dynamic, smoothness and docstrings
* updated metrics and plots for distance
* fixed CI test cases
* nexnt loss prototype
* fix bug with z_scale_range in hcs datamodule. If the value is an int this does not work.
* exclude the negative pair from dataloader and forward pass
* adding option using pytorch-metric-learning implementation and modifying previous to match same input args
* removing our implementation of NTXentLoss and using pytorch metric
* ruff
* prototype for phate and umap plot
* - proofreading the calculations
  - removing unecessary calls to ALFI script
  - simplifying code to re-use functions
* methods to rank nearest neighbors in embeddings
* example script to plot state change of a single track
* test using scaled features
* phate embeddings
* removing dataframe from the compute_phate adding docstring
* adding phate to the prediction writer and moving it as dependency.
* changing the phate defaults in the prediction writer.
* ruff
* fixing bug in phate in predict writer
* adding code for measuring the smoothness
* cleanup to run on triplet and ntxent
* fix plots for smoothnes
* nexnt loss prototype
* exclude the negative pair from dataloader and forward pass
* adding option using pytorch-metric-learning implementation and modifying previous to match same input args
* removing our implementation of NTXentLoss and using pytorch metric
* ruff
* remove blank line diff
* remove blank line diff
* simplying the engine
* explicit target shape argument in the HCS data module
* Revert "explicit target shape argument in the HCS data module"
  This reverts commit 464d4c9.
* Explicit target shape argument in the HCS data module (#212)
  * explicit target shape argument in the HCS data module
  * update docstring
  * update test cases
* Gradio example (#158)
  * initial demo
  * using the predict_step
  * modifying paths to chkpt and example pngs
  * updating gradio as the one on Huggingface
* adding configurable phate arguments via config
* script to recompute phate and overwrite the previous phate data
* ruff
* solving redundancies
* modularizing the smoothness
* removing redundant _fit_phate()
* ruff

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
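The commit log records two related design steps: the negative sample was excluded from the dataloader and forward pass, and the in-house NTXentLoss was replaced with the pytorch-metric-learning one. Label-based losses in that library take `(embeddings, labels)` and treat samples sharing a label as positives, so anchor/positive pairs only need a shared label and the other samples in the batch act as negatives. In PyTorch this bookkeeping would presumably be `torch.cat` over the two projections plus `torch.arange(n)` repeated twice; the pure-Python helper below is a hypothetical sketch of that idea, not VisCy's engine code.

```python
def ntxent_inputs(anchor_embeddings, positive_embeddings):
    """Hypothetical helper: stack anchor/positive embeddings with pair labels.

    Label i is shared by anchor i and positive i, so a label-based contrastive
    loss treats each (anchor, positive) as the positive pair and every other
    sample in the batch as a negative; no separate negative is loaded.
    """
    n = len(anchor_embeddings)
    if n == 0 or len(positive_embeddings) != n:
        raise ValueError("need the same, non-zero number of anchor and positive embeddings")
    embeddings = list(anchor_embeddings) + list(positive_embeddings)
    labels = list(range(n)) * 2  # [0, 1, ..., n-1, 0, 1, ..., n-1]
    return embeddings, labels
```

Because every label appears exactly twice, each sample has exactly one positive partner, which is the setting in which a label-based NT-Xent reduces to the SimCLR formulation.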
Branch updated from 3475223 to fcc1bd2.