robust_inference.py
r"""Inference components such as estimators, training losses and MCMC samplers."""
# __all__ = [
# 'RNPE',
# ]
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import islice
from torch import Tensor, BoolTensor, Size
from typing import *
from zuko.distributions import Distribution, DiagNormal, NormalizingFlow
from zuko.flows import FlowModule, MAF
from zuko.transforms import FreeFormJacobianTransform
from zuko.utils import broadcast
from lampe.nn import MLP
from lampe.inference import NPE, NPELoss


class RNPE(NPE):
    r"""Posterior estimator built on :class:`lampe.inference.NPE` that
    additionally exposes reparameterized sampling via :meth:`rsample`.
    """

    def __init__(
        self,
        theta_dim: int,
        x_dim: int,
        build: Callable[[int, int], FlowModule] = MAF,
        **kwargs,
    ):
        # Forward the keyword arguments to NPE so that flow hyperparameters
        # (e.g. transforms, hidden_features) reach the `build` constructor
        # instead of being silently dropped.
        super().__init__(theta_dim, x_dim, build, **kwargs)

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        r"""
        Arguments:
            theta: The parameters :math:`\theta`, with shape :math:`(*, D)`.
            x: The observation :math:`x`, with shape :math:`(*, L)`.

        Returns:
            The log-density :math:`\log p_\phi(\theta | x)`, with shape :math:`(*,)`.
        """

        theta, x = broadcast(theta, x, ignore=1)

        return self.flow(x).log_prob(theta)

    def rsample(self, x: Tensor, shape: Size = ()) -> Tensor:
        r"""
        Arguments:
            x: The observation :math:`x`, with shape :math:`(*, L)`.
            shape: The shape :math:`S` of the samples.

        Returns:
            The reparameterized samples :math:`\theta \sim p_\phi(\theta | x)`,
            with shape :math:`S + (*, D)`, while preserving the gradient information.
        """

        return self.flow(x).rsample(shape)

    def sample(self, x: Tensor, shape: Size = ()) -> Tensor:
        r"""
        Arguments:
            x: The observation :math:`x`, with shape :math:`(*, L)`.
            shape: The shape :math:`S` of the samples.

        Returns:
            The samples :math:`\theta \sim p_\phi(\theta | x)`,
            with shape :math:`S + (*, D)`.
        """

        with torch.no_grad():
            return self.flow(x).sample(shape)
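

if __name__ == "__main__":
    # -------------------------------------------------------------------------
    # Usage sketch (illustration only, not part of the original module).
    # It shows how RNPE can be trained with lampe's NPELoss and then queried
    # through `forward`, `sample` and `rsample`. The toy linear-Gaussian
    # simulator, the prior, the dimensions (3 parameters, 5 observables) and
    # all hyperparameters below are arbitrary assumptions for demonstration.
    # -------------------------------------------------------------------------
    torch.manual_seed(0)

    prior = DiagNormal(torch.zeros(3), torch.ones(3))
    A = torch.randn(3, 5)  # fixed random projection used by the toy simulator

    def simulator(theta: Tensor) -> Tensor:
        # x = theta @ A + Gaussian observation noise
        return theta @ A + 0.1 * torch.randn(*theta.shape[:-1], 5)

    estimator = RNPE(3, 5, transforms=3, hidden_features=(64, 64))
    loss = NPELoss(estimator)
    optimizer = torch.optim.Adam(estimator.parameters(), lr=1e-3)

    for _ in range(64):  # a few steps only; real training would run much longer
        theta = prior.sample((256,))  # (256, 3) parameters drawn from the prior
        x = simulator(theta)          # (256, 5) simulated observations

        optimizer.zero_grad()
        loss(theta, x).backward()     # minimizes -E[log p_phi(theta | x)]
        optimizer.step()

    x_star = simulator(prior.sample(()))             # one observation, shape (5,)
    samples = estimator.sample(x_star, (1024,))      # (1024, 3), no gradients
    diff_samples = estimator.rsample(x_star, (64,))  # (64, 3), differentiable
    log_p = estimator(samples, x_star)               # (1024,) log-densities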