Confidence intervals for an embeddings layer #256

Open
PinchOfData opened this issue Nov 23, 2024 · 1 comment

Comments

@PinchOfData

Hi Alex,

Thanks for the great package.

I am currently adapting Variational Autoencoders to estimate common latent variable models used in the social sciences.

I am mainly interested in building confidence intervals for the resulting encodings (i.e., the encoder's output) and less interested in the decoding (i.e., the decoder's output).

Is this feasible with your package? It is equivalent to building confidence intervals for intermediate outputs (e.g., an embedding layer) in a neural network.

Thanks,
Germain

@wiseodd
Collaborator

wiseodd commented Nov 27, 2024

Something like this?

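import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from laplace import Laplace

model = ...
train_loader = ...
x_test = ...  # test inputs to encode
model = train_model(model, train_loader)  # SGD

# Laplace only on encoder
for p in model.encoder.parameters():
    p.requires_grad = True

# Don't quantify uncertainty on decoder
for p in model.decoder.parameters():
    p.requires_grad = False

# Presumably needs subset_of_weights="all": the default "last_layer" would not cover the encoder
la = Laplace(model, ...)
la.fit(train_loader)
la.optimize_prior_precision()

# Getting confidence estimate on encoder outputs

# See https://aleximmer.github.io/Laplace/api_reference/parametriclaplace/#laplace.baselaplace.ParametricLaplace.sample
N_SAMPLES = 10
# Shape (N_SAMPLES, n_params); assumes only requires_grad=True (encoder) parameters are included
encoder_params_samples = la.sample(n_samples=N_SAMPLES)

# Keep a copy of the MAP weights, since vector_to_parameters overwrites them in place
map_params = parameters_to_vector(model.encoder.parameters()).detach().clone()

encoder_output_samples = []

with torch.no_grad():
    for sample in encoder_params_samples:
        # Load one posterior sample into the encoder's weights
        vector_to_parameters(sample, model.encoder.parameters())
        encoder_output_samples.append(model.encoder(x_test))

# Restore the MAP weights
vector_to_parameters(map_params, model.encoder.parameters())

# Shape (N_SAMPLES, *encoder_output_shape): stack adds a new sample dimension,
# whereas vstack would merge the samples into the batch dimension
encoder_output_samples = torch.stack(encoder_output_samples)

enc_output = encoder_output_samples.mean(dim=0)
enc_output_var = encoder_output_samples.var(dim=0)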
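
The last two lines are Monte Carlo estimates of the per-coordinate mean and variance of the encoder output under the Laplace posterior. Written out below, with notation introduced here for illustration ($f_\theta$ the encoder, $\theta^{(s)}$ the $s$-th of $S$ weight samples, $\Sigma$ the Laplace posterior covariance over the encoder weights; all operations elementwise over output coordinates). The percentile interval in the last line is one assumed way to turn the samples into an interval; strictly speaking, it is a Bayesian credible interval rather than a frequentist confidence interval.

```math
\theta^{(s)} \sim \mathcal{N}\big(\theta_{\mathrm{MAP}}, \Sigma\big), \qquad z^{(s)} = f_{\theta^{(s)}}(x), \qquad s = 1, \dots, S
\\
\hat{\mu}(x) = \frac{1}{S} \sum_{s=1}^{S} z^{(s)}, \qquad \hat{\sigma}^2(x) = \frac{1}{S-1} \sum_{s=1}^{S} \big(z^{(s)} - \hat{\mu}(x)\big)^2
\\
\mathrm{CI}_{1-\alpha}(x) = \Big[\, q_{\alpha/2}\big(\{z^{(s)}\}_{s=1}^{S}\big),\; q_{1-\alpha/2}\big(\{z^{(s)}\}_{s=1}^{S}\big) \,\Big]
```

The $1/(S-1)$ factor matches the default unbiased estimator used by `Tensor.var`.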

Untested; there is no guarantee that laplace-torch supports this out of the box. But a PR is welcome!
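To exercise the sampling loop without a fitted `Laplace` object, here is a self-contained sketch. It assumes a toy encoder and replaces `la.sample` with a hypothetical diagonal Gaussian around the current weights (`sample_params`, `posterior_std` and `encoder_intervals` are illustrative names, not laplace-torch API). It also turns the samples into 95% percentile intervals with `torch.quantile`, which is one possible choice rather than something laplace-torch provides.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)

# Toy stand-in for model.encoder: 5 input features -> 2 latent dimensions
encoder = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 2))
x_test = torch.randn(8, 5)  # hypothetical batch of 8 test inputs

# Hypothetical diagonal Gaussian posterior around the current (MAP) weights,
# standing in for la.sample(); the output shape matches: (n_samples, n_params)
map_params = parameters_to_vector(encoder.parameters()).detach().clone()
posterior_std = 0.05 * torch.ones_like(map_params)


def sample_params(n_samples: int) -> torch.Tensor:
    eps = torch.randn(n_samples, map_params.numel())
    return map_params + posterior_std * eps


def encoder_intervals(n_samples: int = 1000, level: float = 0.95):
    """Monte Carlo mean, variance and percentile interval of the encoder output."""
    outputs = []
    with torch.no_grad():
        for theta in sample_params(n_samples):
            # Load one weight sample into the encoder and encode the batch
            vector_to_parameters(theta, encoder.parameters())
            outputs.append(encoder(x_test))
    # Put the MAP weights back so the encoder is left unchanged
    vector_to_parameters(map_params, encoder.parameters())

    z = torch.stack(outputs)  # (n_samples, batch, latent_dim)
    alpha = 1.0 - level
    q = torch.tensor([alpha / 2, 1 - alpha / 2])
    lower, upper = torch.quantile(z, q, dim=0)  # each (batch, latent_dim)
    return z.mean(dim=0), z.var(dim=0), lower, upper


mean, var, lower, upper = encoder_intervals()
print(mean.shape, lower.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

With a real `Laplace` object, `sample_params(n)` would be replaced by `la.sample(n_samples=n)`. Tail quantiles need many samples: with only 10, as in `N_SAMPLES = 10` above, the 2.5% and 97.5% quantiles sit close to the most extreme samples, so a larger sample count is probably more reliable.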
