Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pairwise distance computation with smaller memory footprint #23

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

toinsson
Copy link
Contributor

@toinsson toinsson commented Jul 1, 2022

This PR aims at mitigating the high memory cost of the pairwise computation by avoiding broadcasting between x and y. This is pointed out in #19.

The computation seems tricky to get right. This post shows that this method can be unstable (understand some rounding errors), but the alternative from torch (doc) does seem to suffer from the same memory problem.

This was tested like so:

In [1]: import torch

In [2]: a = torch.rand((100, 40, 15))

In [3]: b = torch.rand((100, 50, 15))

In [4]: def _euclidean_dist_func(x, y):
   ...:     """
   ...:     Calculates the Euclidean distance between each element in x and y per timestep
   ...:     """
   ...:     n = x.size(1)
   ...:     m = y.size(1)
   ...:     d = x.size(2)
   ...:     x = x.unsqueeze(2).expand(-1, n, m, d)
   ...:     y = y.unsqueeze(1).expand(-1, n, m, d)
   ...:     return torch.pow(x - y, 2).sum(3)
   ...: 

In [5]: def pairwise_l2_squared(x, y):
   ...:     x_norm = (x**2).sum(-1).unsqueeze(-1)
   ...:     y_norm = (y**2).sum(-1).unsqueeze(-2)
   ...:     dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.mT)
   ...:     return torch.clamp(dist, 0.0, torch.inf)
   ...: 

In [6]: res_a = _euclidean_dist_func(a, b)

In [7]: res_b = pairwise_l2_squared(a, b)

In [8]: torch.allclose(res_a, res_b)
Out[8]: True

@Maghoumi
Copy link
Owner

Maghoumi commented Jul 4, 2022

Thanks for creating this PR! The goal is to basically expand (x-y)^2 to x^2 + y^2 - 2xy, right?

I noticed an error when running this change. What is y.mT? At first I thought it was just a transpose operation (.T) but even with that, the dimensions don't make sense to me. Am I missing something?

@toinsson
Copy link
Contributor Author

toinsson commented Jul 6, 2022

The goal is to basically expand (x-y)^2 to x^2 + y^2 - 2xy, right?
Yes, this computation uses the quadratic expansion to save memory.

To be precise, in _euclidean_dist_func:
x = x.unsqueeze(2).expand(-1, n, m, d) and y = y.unsqueeze(1).expand(-1, n, m, d) create matrices of shape (B, M, N, D) which are only reduced after the pow operation by using sum to (B, M, N).
On the other hand, in pairwise_l2_squared:
The sum is used before adding dummy dimensions (with unsqueeze) and using broadcasting. Broadcasting (x_norm + y_norm) has the following effect on shapes: (B, M, 1) + (B, 1, N) => (B, M, N).

So, with _euclidean_dist_func you end up temporarily with matrices of maximum shape (B, M, N, D), whereas with pairwise_l2_squared you directly move up to (B, M, N).

I noticed an error when running this change.
Please paste the error message in the discussion. The output from the ipython prompts show that both functions are equivalent.

What is y.mT? The dimensions don't make sense.
The distance matrix computation is relying on batched input. This is the reason why mT is used instead of T, and also why bmm is used instead of mm.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants