pdl.py
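Implementation of Projected Distribution Loss (PDL, arXiv:2012.09289): features are projected onto random unit directions and the sorted 1D marginals of the two inputs are compared, giving a sliced, 1D Wasserstein-style distance.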
import torch
import torch.nn.functional as F


def projectedDistributionLoss(x, y, num_projections=1000):
    '''Projected Distribution Loss (https://arxiv.org/abs/2012.09289)

    x, y: tensors of shape (B, N, ...); the trailing dimensions are
    flattened into a feature axis of size M, giving (B, N, M).
    Each of the N samples is projected onto num_projections random
    directions in the M-dimensional feature space, and the sorted 1D
    marginals of x and y are compared with an L1 loss.
    '''
    def rand_projections(dim, device=torch.device('cpu'), num_projections=1000):
        projections = torch.randn((dim, num_projections), device=device)
        # normalize each column (projection direction) to unit length
        projections = projections / torch.sqrt(torch.sum(projections ** 2, dim=0, keepdim=True))
        return projections
    x = x.reshape(x.shape[0], x.shape[1], -1)  # B, N, M
    y = y.reshape(y.shape[0], y.shape[1], -1)
    W = rand_projections(x.shape[-1], device=x.device, num_projections=num_projections)  # M, num_projections
    e_x = torch.matmul(x, W)  # B, N, num_projections; W is broadcast over the batch dimension
    e_y = torch.matmul(y, W)
    loss = 0
    for ii in range(e_x.shape[2]):
        # 1D Wasserstein-like distance for this projection:
        # L1 between the sorted projected samples of x and y
        loss = loss + F.l1_loss(torch.sort(e_x[:, :, ii], dim=1)[0],
                                torch.sort(e_y[:, :, ii], dim=1)[0])  # Huber loss is a possible fallback if L1 is unstable
    return loss
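

# Minimal usage sketch, assuming feature tensors of shape (B, N, M):
# B batches of N samples with M-dimensional features. The tensor sizes
# and num_projections below are illustrative assumptions.
if __name__ == '__main__':
    x = torch.randn(4, 256, 64, requires_grad=True)  # e.g. predicted features
    y = torch.randn(4, 256, 64)                      # e.g. target features
    loss = projectedDistributionLoss(x, y, num_projections=32)
    loss.backward()  # the loss is differentiable w.r.t. x
    print(loss.item())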