`RetrievalRecall` and `RetrievalMRR` seem to function differently for `world_size > 1` and `top_k=1` #2852

Comments
Hi! Thanks for your contribution, great first issue!
Changed the title from "`RetrievalRecall` and `RetrievalMRR` seem to function differently for a world size greater than 1" to "`RetrievalRecall` and `RetrievalMRR` seem to function differently for `world_size > 1` and `top_k=1`".
Is it with the latest version?
I am using this version of the package:
Hi @AdrianM0, thanks for reporting this issue. I tried to reproduce it with the following script:

```python
import contextlib
import os

import torch
from torch import tensor
from torch.multiprocessing import Pool, set_sharing_strategy, set_start_method

from torchmetrics.retrieval import RetrievalMRR, RetrievalRecall

with contextlib.suppress(RuntimeError):
    set_start_method("spawn")
    set_sharing_strategy("file_system")

NUM_PROCESSES = 2
TOP_K = 1

indexes = tensor([0, 0, 0, 1, 1, 1, 1])
preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = tensor([False, False, True, False, True, False, True])


def setup_ddp(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(8088)
    if torch.distributed.group.WORLD is not None:
        torch.distributed.destroy_process_group()
    if torch.distributed.is_available():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def function_to_run(rank: int, world_size: int) -> None:
    setup_ddp(rank, world_size)
    print(f"Rank {rank} is running")
    metric1 = RetrievalRecall(top_k=TOP_K)
    metric2 = RetrievalMRR(top_k=TOP_K)
    # each process only updates on its own shard of the data
    metric1.update(preds[rank::NUM_PROCESSES], target[rank::NUM_PROCESSES], indexes[rank::NUM_PROCESSES])
    metric2.update(preds[rank::NUM_PROCESSES], target[rank::NUM_PROCESSES], indexes[rank::NUM_PROCESSES])
    res1 = metric1.compute()
    res2 = metric2.compute()
    print(f"rank {rank} says result is {res1}, {res2}")


if __name__ == "__main__":
    # distributed run: two processes, each seeing half of the data
    pool = Pool(NUM_PROCESSES)
    pool.starmap(function_to_run, [(i, NUM_PROCESSES) for i in range(NUM_PROCESSES)])

    # single-process reference on the full data
    metric1 = RetrievalRecall(top_k=TOP_K)
    metric2 = RetrievalMRR(top_k=TOP_K)
    metric1.update(preds, target, indexes)
    metric2.update(preds, target, indexes)
    res1 = metric1.compute()
    res2 = metric2.compute()
print(f"result is {res1}, {res2}") however, I am unable to reproduce the issue. Regardless of what the So I am fairly confident that the issue is not in torchmetrics, but how stuff is logged in lightning. I try to debug further. |
Thanks for debugging this @niberger!
@AdrianM0 is there a reason you are explicitly setting the batch size here:

```python
for k_val in k_list:
    for metric, metric_name in zip(metrics, metric_names):
        metric_to_log = metric(top_k=k_val)
        metric_to_log.update(flatten_cos_sim, target, indexes=indexes)
        self.log(
            f"{prefix}_{central_modality}_{other_modality}_{metric_name}_top_{k_val}",
            metric_to_log.compute(),
            batch_size=self.per_device_batch_size * self.world_size,
            sync_dist=self.world_size > 1,
        )
```

I am asking because Lightning internally will sync the batch size between processes, so this should probably just be `batch_size=self.per_device_batch_size` (or remove it if you do not need it; Lightning will automatically try to extract it for you if possible).
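A sketch of what that suggestion would look like in the snippet above (not from the thread; `prefix`, `metric_to_log`, `per_device_batch_size`, and `world_size` are names taken from that snippet, not a defined API):

```python
self.log(
    f"{prefix}_{central_modality}_{other_modality}_{metric_name}_top_{k_val}",
    metric_to_log.compute(),
    # pass the per-device value only; per the comment above, Lightning syncs
    # the batch size between processes, so it is not multiplied by world_size
    batch_size=self.per_device_batch_size,
    sync_dist=self.world_size > 1,
)
```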
@SkafteNicki the reason is mostly because I am aggregating embeddings across GPUs, so the tensors going into the metrics do not have a …
🐛 Bug

I am using `RetrievalRecall` and `RetrievalMRR` in the same way in the snippet provided below. There seems to be a behaviour issue, since `RetrievalRecall` will return `RetrievalMRR` divided by `world_size` for a multi-GPU setup, while `RetrievalMRR` will behave as expected. Can someone suggest what the issue might be? For `top_k=1`, these two functions should return the same result. For `top_k > 1`, both behave as expected.
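As an illustration of that last claim (not part of the original report), here is a minimal single-process check using the tensors from the reproduction script above; it assumes a torchmetrics version in which `RetrievalMRR` accepts `top_k`, as that script does:

```python
from torch import tensor
from torchmetrics.retrieval import RetrievalMRR, RetrievalRecall

indexes = tensor([0, 0, 0, 1, 1, 1, 1])
preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = tensor([False, False, True, False, True, False, True])

recall_at_1 = RetrievalRecall(top_k=1)
mrr_at_1 = RetrievalMRR(top_k=1)

# Query 0: its single relevant document is ranked first   -> recall@1 = 1, RR@1 = 1
# Query 1: neither of its two relevant documents is first -> recall@1 = 0, RR@1 = 0
print(recall_at_1(preds, target, indexes=indexes))  # tensor(0.5000)
print(mrr_at_1(preds, target, indexes=indexes))     # tensor(0.5000)
```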