RetrievalRecall and RetrievalMRR seem to function differently for world_size > 1 and top_k=1 #2852

Open
AdrianM0 opened this issue Nov 29, 2024 · 7 comments
Labels: bug / fix, help wanted, v1.4.x

Comments

@AdrianM0

AdrianM0 commented Nov 29, 2024

🐛 Bug

I am using RetrievalRecall and RetrievalMRR in the same way, as shown in the snippet below. There seems to be a behaviour issue: in a multi-GPU setup, RetrievalRecall returns the RetrievalMRR value divided by world_size, while RetrievalMRR behaves as expected. Can someone suggest what the issue might be? For top_k=1 these two metrics should return the same value (see the minimal sketch after the snippet); for top_k > 1 both behave as expected.

metrics = [
    RetrievalMRR,
    RetrievalRecall,
]
metric_names = [metric.__name__ for metric in metrics]
if self.world_size > 1:
    # both all gather calls return tensors of shape (World_Size, Batch_Size, Embedding_Size)
    all_embeddings_central_mod = self.all_gather(embeddings_central_mod, sync_grads=True)
    all_embeddings_other_mod = self.all_gather(embeddings_other_mod, sync_grads=True)
    all_embeddings_central_mod = all_embeddings_central_mod.flatten(0, 1)
    all_embeddings_other_mod = all_embeddings_other_mod.flatten(0, 1)
else:
    all_embeddings_central_mod = embeddings_central_mod.detach().clone()
    all_embeddings_other_mod = embeddings_other_mod.detach().clone()
device = select_device()

# reference: https://medium.com/@dhruvbird/all-pairs-cosine-similarity-in-pytorch-867e722c8572
# adding a third dim allows to compute pairwise cosine sim.
cos_sim = cosine_similarity(
    all_embeddings_central_mod.unsqueeze(1),
    all_embeddings_other_mod.unsqueeze(0),
    dim=2,
).to(device)
# preds, target, indexes
flatten_cos_sim = cos_sim.flatten().to(device)  # (Batch Size*Batch Size)

# the metric calculations are grouped by indexes and then averaged
# repeat interleave creates tensors of the form [0, 0, 1, 1, 2, 2]
indexes = (
    torch.arange(all_embeddings_central_mod.shape[0]).repeat_interleave(all_embeddings_other_mod.shape[0]).to(device)
)
# Diagonal elements are the true queries, the rest are false queries
target = torch.eye(all_embeddings_central_mod.shape[0], dtype=torch.long).flatten().to(device)
assert target.sum() == all_embeddings_central_mod.shape[0]
for k_val in k_list:
    for metric, metric_name in zip(metrics, metric_names):
        metric_to_log = metric(top_k=k_val)
        metric_to_log.update(flatten_cos_sim, target, indexes=indexes)
        self.log(
            f"{prefix}_{central_modality}_{other_modality}_{metric_name}_top_{k_val}",
            metric_to_log.compute(),
            batch_size=self.per_device_batch_size * self.world_size,
            sync_dist=self.world_size > 1,
        )
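
A minimal single-process sketch with toy data (not taken from the actual setup) confirming the expectation that RetrievalRecall(top_k=1) and RetrievalMRR(top_k=1) agree when every query has exactly one relevant item:

import torch
from torchmetrics.retrieval import RetrievalMRR, RetrievalRecall

preds = torch.tensor([0.9, 0.2, 0.1, 0.8, 0.4, 0.3])
target = torch.tensor([True, False, False, False, True, False])
indexes = torch.tensor([0, 0, 0, 1, 1, 1])

recall = RetrievalRecall(top_k=1)
mrr = RetrievalMRR(top_k=1)
recall.update(preds, target, indexes=indexes)
mrr.update(preds, target, indexes=indexes)
# query 0 ranks its relevant item first, query 1 does not,
# so both metrics are expected to report 0.5 here
print(recall.compute(), mrr.compute())
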
@AdrianM0 added the bug / fix and help wanted labels on Nov 29, 2024

Hi! Thanks for your contribution, great first issue!

@AdrianM0 changed the title from "RetrievalRecall and RetrievalMRR seem to function differently for a world size greater than 1" to "RetrievalRecall and RetrievalMRR seem to function differently for world_size > 1 and top_k=1" on Nov 29, 2024
@Borda
Member

Borda commented Nov 29, 2024

is it with the latest version?

@AdrianM0
Author

AdrianM0 commented Nov 29, 2024

is it with the latest version?

I am using this version of the package:
torchmetrics=1.4.1

@Borda added the v1.4.x label on Nov 29, 2024
@SkafteNicki
Member

Hi @AdrianM0, thanks for reporting this issue.
I tried to debug what is happening, so I created a small script:

import contextlib
import os

import torch
from torch import tensor
from torch.multiprocessing import Pool, set_sharing_strategy, set_start_method

from torchmetrics.retrieval import RetrievalMRR, RetrievalRecall

with contextlib.suppress(RuntimeError):
    set_start_method("spawn")
    set_sharing_strategy("file_system")

NUM_PROCESSES = 2
TOP_K = 1

indexes = tensor([0, 0, 0, 1, 1, 1, 1])
preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = tensor([False, False, True, False, True, False, True])

def setup_ddp(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(8088)

    if torch.distributed.group.WORLD is not None:
        torch.distributed.destroy_process_group()

    if torch.distributed.is_available():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

def function_to_run(rank: int, world_size: int) -> None:
    setup_ddp(rank, world_size)
    print(f"Rank {rank} is running")
    metric1 = RetrievalRecall(top_k=TOP_K)
    metric2 = RetrievalMRR(top_k=TOP_K)
    metric1.update(preds[rank::NUM_PROCESSES], target[rank::NUM_PROCESSES], indexes[rank::NUM_PROCESSES])
    metric2.update(preds[rank::NUM_PROCESSES], target[rank::NUM_PROCESSES], indexes[rank::NUM_PROCESSES])
    res1 = metric1.compute()
    res2 = metric2.compute()
    print(f"rank {rank} says result is {res1}, {res2}")


if __name__ == "__main__":
    pool = Pool(NUM_PROCESSES)
    pool.starmap(function_to_run, [(i, NUM_PROCESSES) for i in range(NUM_PROCESSES)])

    metric1 = RetrievalRecall(top_k=TOP_K)
    metric2 = RetrievalMRR(top_k=TOP_K)
    metric1.update(preds, target, indexes)
    metric2.update(preds, target, indexes)
    res1 = metric1.compute()
    res2 = metric2.compute()
    print(f"result is {res1}, {res2}")

However, I am unable to reproduce the issue. Regardless of what the top_k argument is set to, I get the same result whether I run in a distributed or a non-distributed setting.

So I am fairly confident that the issue is not in torchmetrics but in how things are logged in Lightning. I will try to debug further.

@AdrianM0
Author

AdrianM0 commented Dec 4, 2024

Hi @AdrianM0, thanks for reporting this issue. I tried to debug what is happening, so I created a small script: […] However, I am unable to reproduce the issue. […] So I am fairly confident that the issue is not in torchmetrics but in how things are logged in Lightning. I will try to debug further.

Thanks for debugging this @niberger

@SkafteNicki
Member

@AdrianM0 is there a reason you are explicitly setting the batch size here:

for k_val in k_list:
    for metric, metric_name in zip(metrics, metric_names):
        metric_to_log = metric(top_k=k_val)
        metric_to_log.update(flatten_cos_sim, target, indexes=indexes)
        self.log(
            f"{prefix}_{central_modality}_{other_modality}_{metric_name}_top_{k_val}",
            metric_to_log.compute(),
            batch_size=self.per_device_batch_size * self.world_size,
            sync_dist=self.world_size > 1,
        )

I am asking because Lightning internally syncs the batch size between processes:
https://github.com/Lightning-AI/pytorch-lightning/blob/be608fa355b835b9b0727df2f5476f0a1d90bc59/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L251
which means that if you have a real batch size of 64 per device and each process gets one batch, the accumulated batch size used when calculating the average value is 128. But if you then explicitly say that the batch size is 64*2, you end up with an accumulated batch size of 256.
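
A toy arithmetic sketch of the numbers above (illustration only, not Lightning's actual aggregation code):

per_device_batch_size = 64
world_size = 2

# Lightning sums the batch size reported by each process when it accumulates values.
accumulated_correct = per_device_batch_size * world_size                  # 64 reported per process -> 128
accumulated_inflated = per_device_batch_size * world_size * world_size    # 128 reported per process -> 256
print(accumulated_correct, accumulated_inflated)
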
So try setting:

batch_size=self.per_device_batch_size

(or remove it if you do not need it; Lightning will automatically try to extract it for you if possible)

@AdrianM0
Author

AdrianM0 commented Dec 4, 2024

@SkafteNicki the reason is mostly that I am aggregating embeddings across GPUs, so the tensors going into the metrics do not have a self.per_device_batch_size shape but a self.per_device_batch_size * self.world_size one.
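
For reference, a shape sketch with illustrative numbers (assuming per_device_batch_size=64, world_size=2, embedding size 32) of why the gathered tensors end up with per_device_batch_size * world_size rows after the all_gather and flatten in the original snippet:

import torch

world_size, per_device_batch_size, embedding_size = 2, 64, 32
# all_gather returns a tensor of shape (world_size, per_device_batch_size, embedding_size)
gathered = torch.randn(world_size, per_device_batch_size, embedding_size)
flattened = gathered.flatten(0, 1)
print(flattened.shape)  # torch.Size([128, 32]) -> per_device_batch_size * world_size rows
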
