-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMaskedReduce.py
35 lines (30 loc) · 951 Bytes
/
MaskedReduce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch import Tensor, BoolTensor
# Shape convention used throughout this module (from the original header note):
#   x    : (B, N, d)
#   mask : (B, N, 1) -- broadcast against x; True marks entries to exclude.


def maskedSum(x: Tensor, mask: BoolTensor, dim: int):
    """Sum ``x`` along ``dim``, skipping entries where ``mask`` is True.

    Masked positions are replaced with zero before summing, so they
    contribute nothing to the result. ``mask`` and ``x`` are broadcast
    together by ``torch.where``.
    """
    kept = torch.where(mask, 0, x)
    return kept.sum(dim=dim)
def maskedMean(x: Tensor, mask: BoolTensor, dim: int, gsize: "Tensor | None" = None):
    """Mean of ``x`` along ``dim``, ignoring entries where ``mask`` is True.

    Args:
        x: values to reduce.
        mask: boolean tensor broadcastable with ``x``; True marks entries to
            exclude from both the sum and the count.
        dim: dimension to reduce over.
        gsize: optional divisor supplied by the caller. It is used verbatim
            when given.

    Returns:
        The masked sum divided by ``gsize``. When ``gsize`` is omitted, the
        divisor is the number of unmasked entries along ``dim``, clamped to
        at least 1. A slice whose entries are all masked therefore yields 0
        instead of ``0/0 = NaN``.
    """
    filled = torch.where(mask, 0, x)
    total = torch.sum(filled, dim=dim)
    if gsize is None:
        # Count on the mask broadcast to the shape actually being reduced.
        # This stays correct when mask has size 1 along ``dim`` (e.g. a
        # (B, N, 1) mask reduced over the last axis) or a different rank
        # than x.
        keep = torch.broadcast_to(~mask, filled.shape)
        # Clamp so fully-masked slices divide 0 by 1 rather than by 0.
        gsize = torch.sum(keep, dim=dim).clamp(min=1)
    return total / gsize
def maskedMax(x: Tensor, mask: BoolTensor, dim: int):
    """Maximum of ``x`` along ``dim``, ignoring entries where ``mask`` is True.

    Masked entries are replaced by ``-inf`` before reducing, so a slice whose
    entries are all masked yields ``-inf``. Only the values are returned; the
    indices produced by ``torch.max`` are discarded.
    """
    candidates = torch.where(mask, -torch.inf, x)
    return candidates.max(dim=dim).values
def maskedMin(x: Tensor, mask: BoolTensor, dim: int):
    """Minimum of ``x`` along ``dim``, ignoring entries where ``mask`` is True.

    Masked entries are replaced by ``+inf`` before reducing, so a slice whose
    entries are all masked yields ``+inf``. Only the values are returned; the
    indices produced by ``torch.min`` are discarded.
    """
    candidates = torch.where(mask, torch.inf, x)
    return candidates.min(dim=dim).values
def maskednone(x: Tensor, mask: BoolTensor, dim: int):
    """Identity "reduction": return ``x`` unchanged.

    ``mask`` and ``dim`` are accepted only so this function shares the call
    signature of the other masked reducers; both are ignored.
    """
    unchanged = x
    return unchanged
# Dispatch table from a reduction name to its masked implementation. Every
# entry is called as fn(x, mask, dim); "mean" additionally accepts an optional
# ``gsize`` divisor. "min" was previously defined (maskedMin) but unregistered.
reduce_dict = {
    "sum": maskedSum,
    "mean": maskedMean,
    "max": maskedMax,
    "min": maskedMin,
    "none": maskednone,
}