sinkhorn2 and its functorch.vmap compatibility #482

Open
hmdolatabadi opened this issue Jun 1, 2023 · 3 comments

@hmdolatabadi

🚀 Feature

Making the ot.sinkhorn2 function compatible with functorch.vmap.

Motivation

I'm using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and calculates the sinkhorn distance for that sample and its ground-truth value. What I was using before was a for-loop:

loss = 0.0
for i in range(len(P_batch)):
    # accumulate one Sinkhorn loss per sample (each sample counted exactly once)
    loss = loss + ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)

but this is way too slow for my application. I was reading through functorch, and apparently I should have been able to use the vmap functionality.

losses = vmap(ot.sinkhorn2)(P, Q, C, epsilon)

But after wrapping my function in vmap, I get this weird error:

File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502 v = b / KtransposeU
    503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
    506         or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507         or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508     # we have reached the machine precision
    509     # come back to previous solution and quit loop
    510     warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511     u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

Pitch

Apparently, the data-dependent if-statement needs to be replaced with other alternatives. Any help is appreciated.
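
One common way to remove this kind of branch is to compute the update unconditionally and then select between the new and the previous iterate with an elementwise `where`. The sketch below is a NumPy illustration of that idea only; it is not POT code, the function name `guarded_sinkhorn_step` is hypothetical, and it has not been run under `functorch.vmap`. With PyTorch tensors the analogous calls would be `torch.any` and `torch.where`. The early exit from the loop would also have to go, for instance by running a fixed number of iterations.

```python
import numpy as np


def guarded_sinkhorn_step(K, a, b, u_prev, v_prev):
    """One Sinkhorn update that falls back to the previous iterate without an `if`.

    K: (n, m) Gibbs kernel, a: (n,) and b: (m,) marginals,
    u_prev: (n,) and v_prev: (m,) current scaling vectors.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        Ktu = K.T @ u_prev
        v = b / Ktu
        u = a / (K @ v)
    # Same conditions as the guard in the traceback, reduced to a single flag.
    bad = (
        np.any(Ktu == 0)
        | np.any(~np.isfinite(u))
        | np.any(~np.isfinite(v))
    )
    # Select instead of branching: keep the previous iterate if anything went wrong.
    u = np.where(bad, u_prev, u)
    v = np.where(bad, v_prev, v)
    return u, v, bad
```

The flag `bad` is returned so that a caller can still report numerical problems after the loop, outside of any vectorised region.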

@rflamary
Collaborator

rflamary commented Jun 2, 2023

That is a good point, but POT is implemented in pure Python with backends, and getting rid of conditional flows is going to be a pain.

Note that for what you want to compute (P Sinkhorn problems in parallel with the same cost C) one does not need a loop or vmap: the Sinkhorn iterations can be implemented with matrix products that are already parallel, with very little change to the sinkhorn_knopp function. We do not provide it in POT (maybe we will one day, but we need to find the proper API), but feel free to reach me on the POT Slack if you want some pointers.
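
A minimal sketch of what such a batched version might look like is given below. It is not part of POT and is not the maintainers' implementation: the function name `batched_sinkhorn2`, its signature, and the fixed iteration count are assumptions made for illustration. It stacks the histograms as rows, so each Sinkhorn update becomes one matrix product against the shared kernel, and it returns the transport cost of each plan, which should correspond to the quantity `ot.sinkhorn2` reports. There is no convergence check or numerical guard, so it contains no data-dependent branch.

```python
import numpy as np


def batched_sinkhorn2(A, B, C, reg, num_iter=200):
    """Entropic OT losses for a batch of problems sharing one cost matrix.

    A: (batch, n) source histograms, B: (batch, m) target histograms,
    C: (n, m) shared cost matrix, reg: entropic regularisation strength.
    Returns an array of shape (batch,) holding <P_i, C> for each plan P_i.
    """
    K = np.exp(-C / reg)          # (n, m) Gibbs kernel, shared by all problems
    U = np.ones_like(A)           # (batch, n) scaling vectors, one row per problem
    V = np.ones_like(B)           # (batch, m)
    for _ in range(num_iter):
        V = B / (U @ K)           # row i is b_i / (K^T u_i)
        U = A / (V @ K.T)         # row i is a_i / (K v_i)
    # <P_i, C> with P_i = diag(u_i) K diag(v_i), evaluated for the whole batch.
    return np.einsum("bn,nm,bm->b", U, K * C, V)
```

The same operations exist for PyTorch tensors (`torch.exp`, `@`, `torch.einsum`), so a port that stays differentiable for use as a loss seems straightforward, although small values of `reg` would likely call for a log-domain variant.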

@hmdolatabadi
Author

Thanks @rflamary! I wanted to join the POT Slack, but unfortunately it seems that the workspace invite link hasn't been shared. Could you send me the POT Slack invite? Thanks.

@rflamary
Collaborator

rflamary commented Jun 2, 2023
