
Commit

Merge remote-tracking branch 'origin/main' into multi-sim-loss
thorstenwagner committed Jun 18, 2024
2 parents 8712c61 + ada76c7 commit 5e1f92c
Showing 3 changed files with 16 additions and 14 deletions.
docs/installation.rst (4 changes: 3 additions & 1 deletion)
@@ -22,6 +22,7 @@ Next you can create the TomoTwin environment:
 .. prompt:: bash $
 
     mamba env create -n tomotwin -f https://raw.githubusercontent.com/MPI-Dortmund/tomotwin-cryoet/main/conda_env_tomotwin.yml
+    conda activate tomotwin
     pip install tomotwin-cryoet
 
 2. Install Napari
@@ -32,6 +33,7 @@ Here we assume that you don't have napari installed. Please do:
 .. prompt:: bash $
 
     mamba env create -n napari-tomotwin -f https://raw.githubusercontent.com/MPI-Dortmund/napari-tomotwin/main/conda_env.yml
+    conda activate napari-tomotwin
     pip install napari-tomotwin
 
 3. Link Napari
@@ -59,7 +61,7 @@ To update an existing TomoTwin installation just do:
     mamba env update -n tomotwin -f https://raw.githubusercontent.com/MPI-Dortmund/tomotwin-cryoet/main/conda_env_tomotwin.yml --prune
     conda activate tomotwin
     pip install tomotwin-cryoet
-    mamba env update -n napari-tomotwin -f https://raw.githubusercontent.com/MPI-Dortmund/tomotwin-cryoet/main/conda_env_napari.yml --prune
+    mamba env update -n napari-tomotwin -f https://raw.githubusercontent.com/MPI-Dortmund/napari-tomotwin/main/conda_env.yml --prune
 
 Download latest model
 ^^^^^^^^^^^^^^^^^^^^^
tomotwin/modules/common/utils.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@ def check_for_updates():
         print("Latest version:\t\t", latest_version)
         print(
             "More information here:\n",
-            "https://tomotwin-cryoet.readthedocs.io/en/stable/changes.html",
+            "https://github.com/MPI-Dortmund/tomotwin-cryoet/releases",
         )
         print("###############################################")
     else:
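
For context, a minimal sketch of what such an update check could look like against the GitHub releases API. This is a hypothetical illustration, not TomoTwin's actual implementation: only the printed release link comes from the diff above, while the helper's signature and the use of requests/packaging are assumptions.

# Hypothetical sketch -- not TomoTwin's actual check_for_updates().
import requests
from packaging import version

def check_for_updates(installed_version: str) -> None:
    # Ask GitHub for the newest published release of tomotwin-cryoet.
    url = "https://api.github.com/repos/MPI-Dortmund/tomotwin-cryoet/releases/latest"
    latest_version = requests.get(url, timeout=5).json()["tag_name"].lstrip("v")
    if version.parse(latest_version) > version.parse(installed_version):
        print("###############################################")
        print("Installed version:\t", installed_version)
        print("Latest version:\t\t", latest_version)
        print(
            "More information here:\n",
            "https://github.com/MPI-Dortmund/tomotwin-cryoet/releases",
        )
        print("###############################################")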
tomotwin/modules/training/torchtrainer.py (24 changes: 12 additions & 12 deletions)
@@ -220,12 +220,14 @@ def classification_f1_score(self, test_loader: DataLoader) -> float:
             anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
             positive_vol = batch["positive"].to(self.device, non_blocking=True)
             negative_vol = batch["negative"].to(self.device, non_blocking=True)
+            full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
             filenames = batch["filenames"]
             with autocast():
-                # TODO: Probably concat anchor, positive and vol into one batch and run only one forward pass is enough.
-                anchor_out = self.model.forward(anchor_vol)
-                positive_out = self.model.forward(positive_vol)
-                negative_out = self.model.forward(negative_vol)
+                out = self.model.forward(full_input)
+                out = torch.split(out, anchor_vol.shape[0], dim=0)
+                anchor_out = out[0]
+                positive_out = out[1]
+                negative_out = out[2]
 
             anchor_out_np = anchor_out.cpu().detach().numpy()
             for i, anchor_filename in enumerate(filenames[0]):
@@ -258,16 +260,14 @@ def run_batch(self, batch: Dict):
         anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
         positive_vol = batch["positive"].to(self.device, non_blocking=True)
         negative_vol = batch["negative"].to(self.device, non_blocking=True)
+        full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
         with autocast():
-            # TODO: Probably concat anchor, positive and vol into one batch and run only on forward pass is enough.
-            anchor_out = self.model.forward(anchor_vol)
-            positive_out = self.model.forward(positive_vol)
-            negative_out = self.model.forward(negative_vol)
-
+            out = self.model.forward(full_input)
+            out = torch.split(out, anchor_vol.shape[0], dim=0)
             loss = self.criterion(
-                anchor_out,
-                positive_out,
-                negative_out,
+                out[0],
+                out[1],
+                out[2],
                 label_anchor=batch["label_anchor"],
                 label_positive=batch["label_positive"],
                 label_negative=batch["label_negative"],
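
The torchtrainer.py change resolves the old TODO in both methods: instead of three separate forward passes for the anchor, positive, and negative volumes, the three sub-batches are concatenated, pushed through the network once, and the embeddings are split back apart. A minimal self-contained sketch of the pattern (the Linear model and tensor shapes are toy stand-ins, not TomoTwin's network):

# Sketch of the single-forward-pass pattern introduced in the diff above.
import torch
import torch.nn as nn

batch_size, n_features, n_embed = 8, 64, 32
model = nn.Linear(n_features, n_embed)  # toy stand-in for the embedding network

anchor = torch.randn(batch_size, n_features)
positive = torch.randn(batch_size, n_features)
negative = torch.randn(batch_size, n_features)

# One forward pass over the concatenated triplet batch ...
full_input = torch.cat((anchor, positive, negative), dim=0)  # (3*B, n_features)
out = model(full_input)                                      # (3*B, n_embed)

# ... then split the embeddings back into the three sub-batches.
anchor_out, positive_out, negative_out = torch.split(out, batch_size, dim=0)

# Matches three separate passes for batch-independent layers.
assert torch.allclose(anchor_out, model(anchor), atol=1e-6)

The single concatenated pass typically gives better GPU utilization than three small ones. Exact equivalence assumes batch-independent layers: anything like BatchNorm computes statistics over the whole batch, so concatenating the triplet can shift results slightly.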
