Making the ot.sinkhorn2 function compatible with functorch.vmap.
Motivation
I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground-truth value. Until now I have been using a for loop:
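import ot
for i in range(len(P_batch)):
    # Sinkhorn loss between the i-th prediction and its ground truth, summed over the batch
    if i == 0:
        loss = ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
    else:
        loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)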
but this is far too slow for my application. From reading through the functorch documentation, it seemed I should be able to vectorize this with vmap instead:
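from functorch import vmap
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P_batch, Q_batch, C, epsilon)  # C and epsilon are shared, not batched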
But after wrapping my function in vmap, I get this weird error:
File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502     v = b / KtransposeU
    503     u = 1. / nx.dot(Kp, v)
--> 505     if (nx.any(KtransposeU == 0)
    506             or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507             or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508         # we have reached the machine precision
    509         # come back to previous solution and quit loop
    510         warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511         u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
Pitch
Apparently, the data-dependent if statement in sinkhorn_knopp needs to be replaced with something vmap can trace. Any help is appreciated.
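One way to remove the data-dependent branch is to turn the numerical-error condition into a boolean tensor and choose between the new and previous iterates with torch.where. The sketch below assumes PyTorch 2.x, where functorch.vmap is available as torch.func.vmap. sinkhorn_step and sinkhorn_loss are hypothetical helpers that reimplement a plain Sinkhorn-Knopp loop; they are not POT functions. Because vmap cannot stop individual samples early, the early `break` is replaced by a fixed iteration count.

```python
import torch
from torch.func import vmap  # PyTorch >= 2.0; older releases expose this as functorch.vmap


def sinkhorn_step(a, b, K, u_prev, v_prev):
    """One Sinkhorn-Knopp update for a single pair of histograms (hypothetical helper)."""
    KtransposeU = K.T @ u_prev
    v = b / KtransposeU
    u = a / (K @ v)
    # Boolean 0-d tensor: did this step hit a zero division, NaN or Inf?
    bad = (
        (KtransposeU == 0).any()
        | torch.isnan(u).any() | torch.isnan(v).any()
        | torch.isinf(u).any() | torch.isinf(v).any()
    )
    # Fall back to the previous iterate without a Python `if` on tensor values
    u = torch.where(bad, u_prev, u)
    v = torch.where(bad, v_prev, v)
    return u, v


def sinkhorn_loss(a, b, C, reg, n_iter=200):
    """Entropic OT cost <plan, C> for one pair, with a fixed number of iterations (no early stop)."""
    K = torch.exp(-C / reg)
    u = torch.ones_like(a) / a.shape[0]
    v = torch.ones_like(b) / b.shape[0]
    for _ in range(n_iter):  # static loop bound, so vmap can trace it
        u, v = sinkhorn_step(a, b, K, u, v)
    plan = u[:, None] * K * v[None, :]
    return (plan * C).sum()


# Usage: batch over the histograms, share the cost matrix C and the regularization
torch.manual_seed(0)
P_batch = torch.rand(8, 5)
P_batch = P_batch / P_batch.sum(dim=1, keepdim=True)
Q_batch = torch.rand(8, 5)
Q_batch = Q_batch / Q_batch.sum(dim=1, keepdim=True)
C = torch.rand(5, 5)
losses = vmap(sinkhorn_loss, in_dims=(0, 0, None, None))(P_batch, Q_batch, C, 0.1)
```

After a numerical error, a sample recomputes the same failing step from the same iterate on every later iteration, so it stays at its last good iterate. Per sample, this roughly mimics POT's "revert and quit" behaviour.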
That is a good point, but POT is implemented in pure Python on top of a backend abstraction, and getting rid of the conditional control flow is going to be a pain.
Note that for what you want to compute (P Sinkhorn problems in parallel with the same cost C), you do not need a loop or vmap: the Sinkhorns can be implemented with matrix products that are already parallel, with very little change to the sinkhorn_knopp function. We do not provide this in POT (maybe we will one day, but we need to find the proper API), but feel free to reach out to me on the POT Slack if you want some pointers.
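Here is a minimal NumPy sketch of that idea, under a few assumptions. batched_sinkhorn2 is a hypothetical helper, not a POT function. Histograms are stored as columns, so the PyTorch tensors from the issue would be passed as P_batch.T and Q_batch.T. The loop runs a fixed number of iterations, with no stopping threshold and no numerical-error check. The scaling vectors of every sample live in the columns of U and V, so each Sinkhorn iteration costs two matrix products for the whole batch.

```python
import numpy as np


def batched_sinkhorn2(A, B, C, reg, n_iter=1000):
    """Entropic OT cost for every column pair (A[:, k], B[:, k]) sharing one cost matrix C.

    A: (n, batch) source histograms, one per column
    B: (m, batch) target histograms, one per column
    C: (n, m) cost matrix shared by the whole batch
    """
    K = np.exp(-C / reg)                    # Gibbs kernel, computed once for the batch
    U = np.full(A.shape, 1.0 / A.shape[0])  # scaling vectors for all samples at once
    V = np.full(B.shape, 1.0 / B.shape[0])
    for _ in range(n_iter):
        V = B / (K.T @ U)                   # (m, n) @ (n, batch): updates every sample together
        U = A / (K @ V)
    # <plan_k, C> with plan_k = diag(U[:, k]) K diag(V[:, k]), for all k in one einsum
    return np.einsum("ib,ij,jb,ij->b", U, K, V, C)


# Usage: 8 pairs of 5-bin histograms (as columns) sharing one cost matrix
rng = np.random.default_rng(0)
n, batch = 5, 8
P = rng.random((n, batch))
P /= P.sum(axis=0)
Q = rng.random((n, batch))
Q /= Q.sum(axis=0)
C = rng.random((n, n))
losses = batched_sinkhorn2(P, Q, C, reg=0.1)
```

The loop has no data-dependent branch. It should port to PyTorch by swapping np.exp and np.full for torch.exp and torch.full (torch.einsum accepts the same subscripts). The results should agree closely with per-sample ot.sinkhorn2 calls but not bit-for-bit, because POT stops on a convergence threshold.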
Thanks @rflamary! I wanted to join the POT Slack, but unfortunately it seems that the workspace invite link hasn't been shared. Could you send me the POT Slack invite? Thanks.