
Adding NTXENT loss #204

Merged
merged 11 commits into main on Dec 23, 2024

Conversation

@edyoshikun (Contributor)

  • Adding the NT-Xent loss alongside the existing triplet margin loss

@ziw-liu linked an issue Nov 13, 2024 that may be closed by this pull request
@ziw-liu added the enhancement (New feature or request) and representation (Representation learning (SSL)) labels Nov 13, 2024
@ziw-liu added this to the v0.4.0 milestone Nov 13, 2024
@ziw-liu (Collaborator) commented Nov 13, 2024

The correctness of the loss function can be checked against community implementations (e.g. this one from PML). Alternatively we can add it as a dependency for lower maintenance overhead.
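For reference, the quantity being cross-checked can be sketched in a few lines of NumPy (a minimal, CPU-only illustration of NT-Xent, not the PR's implementation; the function name and temperature default are placeholders):

```python
import numpy as np

def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """Minimal NT-Xent sketch: row i of z1 and row i of z2 form a positive pair;
    all other rows in the 2N-sample batch act as implicit negatives."""
    z = np.concatenate([z1, z2], axis=0)                      # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)          # cosine sim via dot product
    sim = z @ z.T / temperature
    n = len(z1)
    np.fill_diagonal(sim, -np.inf)                            # drop self-similarity
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)]) # each row's positive index
    # row-wise log-softmax, evaluated at the positive index
    logits = sim - sim.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(2 * n), pos].mean())
```

Feeding two identical views yields a lower loss than feeding unrelated views, which is the kind of sanity check that a comparison against the PML implementation would confirm.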

@edyoshikun (Contributor Author)

I double-checked our previous implementation against the one from pytorch-metric-learning; they produce the same results. Adding this package as a dependency.

@edyoshikun marked this pull request as ready for review November 14, 2024 22:36
@mattersoflight (Member)

@edyoshikun good call to use pytorch-metric-learning

@mattersoflight (Member) left a comment

The changes look good to me! Please merge after comparing training logs and embeddings learned by triplet loss and NTXent loss.

@edyoshikun (Contributor Author)

@alishbaimran this is the PR with the NT-Xent loss

```diff
@@ -602,7 +602,8 @@ def _train_transform(self) -> list[Callable]:
         else:
             self.augmentations = []
         if z_scale_range is not None:
-            if isinstance(z_scale_range, float):
+            if isinstance(z_scale_range, (float, int)):
                 z_scale_range = float(z_scale_range)
```
A Collaborator commented:

Should this change be merged with this PR?

@edyoshikun (Contributor Author) commented Dec 10, 2024

Do you see any issue with it? I was running into errors when the value was not a float.
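The widened check can be exercised in isolation (a standalone sketch; the helper name is hypothetical). YAML and CLI configs typically parse `1` as an int, which the original float-only `isinstance` would skip, leaving an int where downstream code expects a float:

```python
def coerce_z_scale_range(z_scale_range):
    # Mirrors the widened check from the diff: accept a bare int or float
    # and normalize it to float so downstream code sees a single type.
    if isinstance(z_scale_range, (float, int)):
        z_scale_range = float(z_scale_range)
    return z_scale_range

coerce_z_scale_range(1)      # int from a config file is promoted to float
coerce_z_scale_range(0.5)    # float passes through unchanged
```

Sequences (e.g. an explicit `(low, high)` pair) fall through the check untouched, so only scalar values are coerced.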

viscy/data/triplet.py (outdated; resolved)
@Soorya19Pradeep (Contributor) commented Dec 10, 2024

@edyoshikun, I am receiving the following error with classical contrastive learning:

```
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
ValueError: num_samples should be a positive integer value, but got num_samples=0
```

The tracks zarr I am using has a unique track id for each instance. Do you know what could be causing the issue?

@Soorya19Pradeep (Contributor)

@edyoshikun, you can find the whole log here: /hpc/projects/intracellular_dashboard/organelle_dynamics/organelle_phenotyping/models/classical/SEC61/rev1_2chan_NTXent/slurm-17231174.out

@edyoshikun (Contributor Author)

@Soorya19Pradeep can you post the whole error? I assume this is coming from the sample-detaching method, but I am not sure. This might be a dataloader issue. Can you check whether your dataloader with the classical contrastive setup is returning the right samples?

@Soorya19Pradeep (Contributor)

@edyoshikun, here is the error:

```
[rank3]: Traceback (most recent call last):
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/bin/viscy", line 8, in <module>
[rank3]:     sys.exit(main())
[rank3]:   File "/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy/viscy/cli.py", line 55, in main
[rank3]:     _ = VisCyCLI(
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 394, in __init__
[rank3]:     self._run_subcommand(self.subcommand)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 701, in _run_subcommand
[rank3]:     fn(**fn_kwargs)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank3]:     call._call_and_handle_interrupt(
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank3]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank3]:     return function(*args, **kwargs)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank3]:     self._run(model, ckpt_path=ckpt_path)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank3]:     results = self._run_stage()
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank3]:     self.fit_loop.run()
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 197, in run
[rank3]:     self.setup_data()
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 225, in setup_data
[rank3]:     train_dataloader = _request_dataloader(source)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 325, in _request_dataloader
[rank3]:     return data_source.dataloader()
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 292, in dataloader
[rank3]:     return call._call_lightning_datamodule_hook(self.instance.trainer, self.name)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 189, in _call_lightning_datamodule_hook
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy/viscy/data/hcs.py", line 519, in train_dataloader
[rank3]:     return DataLoader(
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/lightning/fabric/utilities/data.py", line 324, in wrapper
[rank3]:     init(obj, *args, **kwargs)
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 351, in __init__
[rank3]:     sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
[rank3]:   File "/hpc/mydata/soorya.pradeep/envs/dynaclr/lib/python3.10/site-packages/torch/utils/data/sampler.py", line 144, in __init__
[rank3]:     raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
[rank3]: ValueError: num_samples should be a positive integer value, but got num_samples=0
```

How can I check the samples returned from the dataloader?

@ziw-liu (Collaborator) commented Dec 11, 2024

This could happen if your dataset has a length of 0, i.e. no valid patch is available.
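The traceback above is reproducible with a toy empty dataset (a minimal sketch using plain torch, independent of the VisCy datamodule): `shuffle=True` makes the DataLoader build a `RandomSampler`, whose constructor rejects a zero-length dataset at construction time.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Zero samples, as when no valid patch survives filtering in the datamodule.
empty = TensorDataset(torch.empty(0, 3))
try:
    DataLoader(empty, shuffle=True)  # shuffle=True -> RandomSampler over len(dataset) == 0
except ValueError as err:
    print(err)  # num_samples should be a positive integer value, but got num_samples=0
```

So checking `len(dataset)` (or logging it in `setup()`) is a quick way to confirm whether the dataloader ever had samples to return.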

@edyoshikun (Contributor Author)

@Soorya19Pradeep have you recently used this branch with the latest commits?

@mattersoflight (Member) left a comment

Hi @edyoshikun, I am realizing that the nearest-neighbor metric that @ziw-liu contributed and the ALFI_displacement.py that @alishbaimran contributed are already in this branch. Multiple files in this PR were already merged into main (e.g., the gradio example and the nearest-neighbor distance metric).

But the diff relative to main shows these as new files. Once you resolve this, I think we should merge this branch into main and work from there. We can surface more specific bugs and add features related to organelle phenotyping after the merge.

A Member left a review comment:

This file was already merged into main in the previous PR. It may be better to merge main into this branch.

Comment on lines 36 to 37:

```python
def cross_dissimilarity(features: ArrayLike, metric: str = "cosine") -> NDArray:
    """Dissimilarity/distance between each pair of samples in the features.
```
A Member left a review comment:

Perhaps distance_matrix or similarity_matrix would be a more intuitive name for the method: the distance is measured by a metric, and the metric may be cosine, euclidean, ...
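As an illustration of the suggested naming (a hypothetical sketch, not the code under review), a `distance_matrix` parameterized by the metric might look like:

```python
import numpy as np

def distance_matrix(features: np.ndarray, metric: str = "cosine") -> np.ndarray:
    """Pairwise distance between rows of `features` under the chosen metric."""
    if metric == "cosine":
        normed = features / np.linalg.norm(features, axis=1, keepdims=True)
        return 1.0 - normed @ normed.T  # cosine distance = 1 - cosine similarity
    if metric == "euclidean":
        diff = features[:, None, :] - features[None, :, :]
        return np.sqrt((diff ** 2).sum(axis=-1))
    raise ValueError(f"unsupported metric: {metric!r}")
```

Either way, the result is a symmetric (N, N) matrix with a zero diagonal, which the name makes explicit.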

A Member left a review comment:

The contributions in this file are already on main: https://github.com/mehta-lab/VisCy/blob/main/viscy/representation/evaluation/clustering.py. Why is the diff showing this as a new file?

@edyoshikun (Contributor Author) commented Dec 23, 2024

Probably need to rebase this branch onto main.

ziw-liu and others added 4 commits December 22, 2024 17:51
* translation: fix validation loss aggregation (#202)

* exposing prefetch and persistent worker (#203)

* metrics for dynamic, smoothness and docstrings

* updated metrics and plots for distance

* fixed CI test cases

* nexnt loss prototype

* fix bug with z_scale_range in hcs datamodule. If the value is an int this does not work.

* exclude the negative pair from dataloader and forward pass

* adding option using pytorch-metric-learning implementation
and modifying previous to match same input args

* removing our implementation of NTXentLoss and using pytorch metric

* ruff

* prototype for phate and umap plot

* - proofreading the calculations
- removing unecessary calls to ALFI script
- simplifying code to re-use functions

* methods to rank nearest neighbors in embeddings

* example script to plot state change of a single track

* test using scaled features

* phate embeddings

* removing dataframe from the compute_phate
adding docstring

* adding phate to the prediction writer and moving it as dependency.

* changing the phate defaults in the prediction writer.

* ruff

* fixing bug in phate in predict writer

* adding code for measuring the smoothness

* cleanup to run on triplet and ntxent

* fix plots for smoothnes

* nexnt loss prototype

* exclude the negative pair from dataloader and forward pass

* adding option using pytorch-metric-learning implementation
and modifying previous to match same input args

* removing our implementation of NTXentLoss and using pytorch metric

* ruff

* remove blank line diff

* remove blank line diff

* simplying the engine

* explicit target shape argument in the HCS data module

* Revert "explicit target shape argument in the HCS data module"

This reverts commit 464d4c9.

* Explicit target shape argument in the HCS data module (#212)

* explicit target shape argument in the HCS data module

* update docstring

* update test cases

* Gradio example (#158)

* initial demo

* using the predict_step

* modifying paths to chkpt and example pngs

* updating gradio as the one on Huggingface

* adding configurable phate arguments via config

* script to recompute phate and overwrite the previous phate data

* ruff

* solving redundancies

* modularizing the smoothness

* removing redundant _fit_phate()

* ruff

---------

Co-authored-by: Ziwen Liu <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
@mattersoflight merged commit 316deee into main Dec 23, 2024
4 checks passed
@ziw-liu deleted the ntxent_loss branch December 23, 2024 22:07
Labels: enhancement (New feature or request), representation (Representation learning (SSL))
Projects: None yet
Development: Successfully merging this pull request may close: Implement NT-Xent loss (implicit negative samples)
4 participants