
Marginal Laplace approximation #344

Open
theorashid opened this issue May 15, 2024 · 2 comments

Comments

@theorashid
Contributor

This is part of #340.

We have a Laplace approximation, but we want to apply it only to a subset of the variables (the latent field) and use some other inference method on the remaining variables (the hyperparameters).

This could perhaps use a step sampler, or maybe even something like the marginal model.
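For reference, the standard identity behind this (a sketch of the algebra, independent of any particular implementation): the latent field is integrated out with a Gaussian approximation at its conditional mode,

$$
p(y \mid \theta) = \frac{p(x \mid \theta)\, p(y \mid x, \theta)}{p(x \mid y, \theta)} \approx \frac{p(x^* \mid \theta)\, p(y \mid x^*, \theta)}{\tilde{p}(x^* \mid y, \theta)},
$$

where $x^*$ is the mode of $p(x \mid y, \theta)$ and $\tilde{p}$ is its Laplace (Gaussian) approximation. The first equality holds exactly for any $x$; the approximation enters only through the denominator.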

@theorashid
Contributor Author

Build on top of the QMC implementation in #353.
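A rough sketch of how that might slot in, assuming the Laplace step exposes an (approximate) log marginal over the hyperparameters; log_marginal is hypothetical here, and scipy's Sobol sampler stands in for whatever #353 provides:

import numpy as np
from scipy.stats import qmc

def log_marginal(theta):
    # hypothetical stand-in for log p~(y | theta) from the Laplace step
    return -0.5 * float(theta @ theta)

# quasi-random nodes over an (arbitrary, for illustration) box in hyperparameter space
sampler = qmc.Sobol(d=2, scramble=True, seed=0)
nodes = qmc.scale(sampler.random(64), l_bounds=[-3, -3], u_bounds=[3, 3])

log_ps = np.array([log_marginal(t) for t in nodes])
weights = np.exp(log_ps - log_ps.max())
weights /= weights.sum()          # self-normalised weights over the QMC grid
posterior_mean = weights @ nodes  # e.g. a posterior summary of the hyperparameters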

@theorashid
Contributor Author

theorashid commented Jul 5, 2024

WIP:

Requires pymc-devs/pytensor#887.

We need to limit this to three-tier models so that we can access y (the observations/measurements), i.e. the observations sit exactly one level above the latent field. A minimal example of that structure is sketched below.
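For concreteness, a minimal sketch of the three-tier structure in PyMC (hypothetical data, purely illustrative):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

y_obs = np.random.default_rng(0).poisson(3, size=10)

with pm.Model(coords={"site": range(10)}) as model:
    # tier 1: hyperparameters (handled by some other inference method)
    sigma = pm.HalfNormal("sigma", 1.0)
    # tier 2: latent Gaussian field (target of the Laplace approximation)
    x = pm.Normal("x", mu=0.0, sigma=sigma, dims="site")
    # tier 3: observations, exactly one level above the latent field
    pm.Poisson("y", mu=pt.exp(x), observed=y_obs, dims="site")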

# WIP sketch of the logprob registration; imports are approximate
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import matrix_inverse
from pymc.logprob.abstract import _logprob

@_logprob.register(LaplaceMarginalNormalRV)
def laplace_marginal_logprob(op, values, *inputs, **kwargs):
    # for MvNormal:
    #   y are the observations
    #   x is the latent field
    #   params are the hyperparameters
    mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
    y = ...  # marginalized_rv_op.owner...?
    f = ...  # log(p(y | x, params)): takes a latent value and a measurement, returns a float
    df = grad(f)         # schematic: gradient of f w.r.t. x
    d2f = grad(grad(f))  # schematic: second derivative of f w.r.t. x
                         # (a vector: the Hessian is diagonal when the likelihood factorises)

    # precision matrix; assume a rewrite so it takes a user-supplied tau
    # rather than inverting
    Q = matrix_inverse(cov)

    def newton_step(x, y, gaussian_mean, gaussian_prec):
        # solve A x_new = b, where A = (precision - d2f);
        # d2f only touches the diagonal, so
        # pt.fill_diagonal(Q, Q.diagonal() - d2f(x, y)) may be more efficient
        # than materialising the dense update
        A = gaussian_prec - pt.diag(d2f(x, y))
        b = Q @ gaussian_mean + df(x, y) - d2f(x, y) * x
        return pt.linalg.solve(A, b)

    x_mode = ...  # iterate newton_step to convergence, with gaussian_mean=mean, gaussian_prec=Q

    d2f_mode = d2f(x_mode, y)  # evaluate d2f(x_mode, y) once
    A = Q - pt.diag(d2f_mode)  # precision of the Laplace approximation
    b = Q @ mean + df(x_mode, y) - d2f_mode * x_mode  # = A @ x_mode at the Newton fixed point
    dim = x_mode.shape[0]

    quadratic_part = -0.5 * x_mode.T @ A @ x_mode
    linear_part = b.T @ x_mode
    # the constant comes from the log probability of the Gaussian with the
    # same quadratic and linear terms
    const = (
        0.5 * pt.linalg.slogdet(A)[1]
        - 0.5 * b.T @ pt.linalg.solve(A, b)
        - dim * pt.log(2 * np.pi) / 2
    )
    # since b = A @ x_mode, at the mode this collapses to
    # 0.5 * slogdet(A) - dim * log(2 * pi) / 2
    log_laplace_approx = quadratic_part + linear_part + const

    # this is the full likelihood, but we only need to add the marginal part;
    # which of these do we need to replace, or should we replace the entire
    # model log-likelihood?
    likelihood = (                 # log P(y | params) =
        gaussian.logpdf(x_mode)    #   log P(x | params)      # PrecisionMvNormal?
        + f(x_mode, y).sum()       # + log P(y | x, params)   # probably have this from p(y | x)
        - log_laplace_approx       # - log P(x | y, params)
    )
    return likelihood
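As a sanity check on the identity, here is a minimal self-contained NumPy sketch using a Gaussian observation model, for which the Laplace approximation is exact and the marginal likelihood is available in closed form (illustrative only, not the eventual PyMC implementation):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 4
mean = rng.normal(size=n)               # latent-field prior mean
L = rng.normal(size=(n, n))
cov = L @ L.T + n * np.eye(n)           # latent-field prior covariance
Q = np.linalg.inv(cov)                  # prior precision
sigma2 = 0.5                            # observation noise variance
y = rng.normal(size=n)                  # "observations"

f = lambda x: stats.norm.logpdf(y, loc=x, scale=np.sqrt(sigma2)).sum()
df = lambda x: (y - x) / sigma2         # gradient of f w.r.t. x
d2f = -np.ones(n) / sigma2              # diagonal of the Hessian (constant here)

# one Newton step reaches the mode exactly because f is quadratic in x
A = Q - np.diag(d2f)                    # posterior precision
x0 = np.zeros(n)
x_mode = np.linalg.solve(A, Q @ mean + df(x0) - d2f * x0)

# Laplace approximation of log p(x* | y, params) evaluated at its own mode
log_laplace = 0.5 * np.linalg.slogdet(A)[1] - 0.5 * n * np.log(2 * np.pi)

approx = (
    stats.multivariate_normal.logpdf(x_mode, mean, cov)  # log p(x* | params)
    + f(x_mode)                                          # log p(y | x*, params)
    - log_laplace                                        # -log p~(x* | y, params)
)
exact = stats.multivariate_normal.logpdf(y, mean, cov + sigma2 * np.eye(n))
print(approx, exact)  # agree to numerical precision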
