Draft implementation of point estimation #281
base: dev
Conversation
Find a notebook to look at training and inference for such a point estimator here: https://github.com/han-ol/bayesflow/blob/point-estimation/examples/draft-point-est.ipynb
After writing this up, in my opinion the best structure is probably to have:
For class names I would propose: […]

If we want to support convenient creation of heads by just specifying a loss function, this can be done by subclassing a […]. Additionally, I would prefer an implementation where estimate() returns a dictionary with named estimates corresponding to the individual heads.

After collecting some of your thoughts I would proceed with implementing whatever we converge on. What do you think?
Hans, thanks for the great ideas and the very mature first draft! Here are some initial thoughts from my side. More to come later:
This looks really cool already! Thank you for this PR! There is a lot of content here in this thread already and I may benefit from a call where you show me the current state. This would help me give reasonable feedback. I will contact you offline about it.
Ok cool, Paul! Stefan, thanks for your takes already! Some notes to some of them:
I am not sure about this. Tagging @LarsKue for this question; I think you mentioned a preference for parallel implementation rather than inheritance?
Agreed that some form of naming is necessary. How do the "data classes" you imagine differ from dictionaries of the type:

```python
# assuming a batch_size of two, two quantile levels and 3 inference variables:
dict(
    mean=[
        [1, 2, 3],
        [2, 1, 3],
    ],  # shape=(2, 3)
    quantiles=[
        [[-1, 0, 1], [3, 4, 5]],
        [[1, -1, -2], [3, 2, 5]],
    ],  # shape=(2, 2, 3)
    # ... further named estimates
)
```

For one thing, it might be good to make the quantile levels accessible somewhere close to the estimated quantiles. This could be part of a data class.
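Purely as an illustration of that last point (the PR does not define such a class, and the field names below are made up), a minimal data class keeping the quantile levels next to the estimates could look like:

```python
from dataclasses import dataclass
from typing import Sequence

import numpy as np


@dataclass
class QuantileEstimates:
    # Quantile estimates together with the levels they refer to, so downstream
    # code never has to guess which slice corresponds to which quantile level.
    quantiles: np.ndarray             # shape (batch_size, num_quantile_levels, num_variables)
    quantile_levels: Sequence[float]


estimates = QuantileEstimates(
    quantiles=np.array([[[-1, 0, 1], [3, 4, 5]], [[1, -1, -2], [3, 2, 5]]]),  # shape (2, 2, 3)
    quantile_levels=(0.1, 0.9),  # hypothetical levels, just for illustration
)
```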
👍
Ok! Just in their defence, I'd say point estimators are also fully Bayesian ;) Functionals of the proper Bayesian posterior distribution.
Good point! Using both […]
Regarding the output, I think we can go with dictionaries for now. I believe some custom data classes will come in handy rather soon. Just a thought: Can the heads simply be determined automatically assuming the scoring rules know their dims?
Yes, they can mostly be inferred, and I'd suggest linear layers followed by a reshape as an overwritable default. However, some scoring rules benefit from a specific architecture (e.g. quantiles, monotonically increasing) or need one (e.g. covariance matrix, positive semidefinite).
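For the quantile case, one standard way to enforce monotonically increasing outputs (a sketch of the general idea, not necessarily how the `Ordered`/`OrderedQuantiles` links on this branch implement it) is to predict the lowest quantile plus positive increments:

```python
import keras


def ordered_quantiles(raw):
    # raw: unconstrained head output of shape (batch, num_quantile_levels, num_variables).
    # Keep the first quantile as-is and add cumulative softplus increments for the rest,
    # so the result is monotonically increasing along the quantile axis.
    first, rest = raw[:, :1, :], raw[:, 1:, :]
    increments = keras.ops.softplus(rest)
    return keras.ops.concatenate([first, first + keras.ops.cumsum(increments, axis=1)], axis=1)
```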
…ltiple scoring rules
…low-org#291)

* Splines draft
* update keras requirement
* small improvements to error messages
* add rq spline function
* add spline transform
* update searchsorted utils for jax, also add padd util
* update tests
* add assert_allclose util for improved messages
* parametrize transform for flow tests
* update jacobian, jacobian trace, vjp, jvp, and corresponding usages and tests
* fix imports, remove old jacobian and jvp, fix application in free form flow
* improve logdet computation in free form flows
* Fix comparison for symbolic tensors under tf
* Add splines to twomoons notebook
* improve pad utility
* fix missing left edge in spline
* fix inside mask edge case
* explicitly set bias initializer
* add better expand utility
* small clean up, renaming
* fix indexing, fix inside check
* dump
* fix sign of log jacobian for inverse pass in rq spline
* fix parameter splitting for spline transform
* improve readability
* fix scale and shift trailing dimension
* fix inverse pass return value
* correctly choose bins once for each dimension, even for multi-dimensional inputs
* run formatter
* reduce searchsorted log spam
* log backend used at setup
* remove maximum message cache size
* Improve warning message for jax searchsorted
* Fix spline parameter binning for compiled contexts
* update inverse transform same as forward
* Update TwoMoons notebook with splines WIP [skip ci]
* fix spline inverse call for out of bounds values
* Add working splines

---------

Co-authored-by: stefanradev93 <[email protected]>
* rename scoring_rules module to scores
* mean and median wrappers for normed difference scores
* argument handling for quantile score
* scoring rule tests
Force-pushed from 9166a7b to 8a1986d
* rename link_functions to links
* separated activation function Ordered from OrderedQuantiles, one for generality, the other for automatic smart anchor selection based on quantile levels
* introduce link function for learnable positive semi-definite matrices
* full test coverage for links module
Force-pushed from 8a1986d to ab821bf
…onal estimation

* nested dictionary based heads
* move head configuration into get_head methods of scoring rules
* serialization of PointInferenceNetwork body
* serialization of ScoringRules including nested subnets and link arguments
* serialization of links
* serialization of heads taking care of keras bugs related to nested layer attributes
* unconditional point estimation
* training flag propagation
* draft of PointInferenceNetwork.sample() and PointInferenceNetwork.log_prob()
* forced conversion of reference_shape to tuple
* unit tests for links and ScoringRules
* tests for successful building and correct output structure of PointInferenceNetwork
* two moons integration tests (excluding multivariate normal)
* relax assert_layers_equal: do not require variable names to be identical anymore
This (draft) pull request is meant to hold discussion and development of point estimation within BayesFlow.
The functionality per se was also discussed in issue #121.
The implementation should make it easy to […]
Commit 093785d contains a first example including ONLY quantile estimation.
Writing down draft specifications and guiding thoughts
Names for everything
What is "inference", does it include point estimation? Are we calling networks discussed below PointInferenceNetworks, PointRegressors, or something else?
For now I stick with `ContinuousPointApproximator` and `PointInferenceNetwork` to make their roles in relation to the existing codebase obvious.

Components
A `ContinuousPointApproximator` parallels the `ContinuousApproximator` and bundles some feed-forward architecture with an optional summary network suitable for learning point estimates optimized to minimize some Bayes risk. Thus it serves the same roles (including Adapter calls, summary network, etc.), but instead of a `sample` method it provides an `estimate` method.

A `PointInferenceNetwork` parallels the `InferenceNetwork` by providing a base class to inherit from for feed-forward model classes suitable for point estimation (rather than generative model classes suitable for approximating a conditional distribution).

Convenient default estimator
The API for functionality that covers the needs of most users could be something like the following, with optional constructor arguments to tweak it a bit:
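The concrete snippet from the original write-up is not reproduced in this rendering; purely as an illustrative sketch (class names are the ones proposed in this PR, while module paths, constructor arguments, and the summary network choice are assumptions):

```python
import bayesflow as bf

# Hypothetical usage sketch: bundle an optional summary network with a point
# inference network, in analogy to ContinuousApproximator.
approximator = bf.ContinuousPointApproximator(
    summary_network=bf.networks.DeepSet(),                              # optional
    inference_network=bf.networks.PointInferenceNetwork(subnet="mlp"),  # "mlp" is the proposed default
)

# After adapting the data and training as usual, `estimate` takes the place of
# `sample` and (as proposed below) returns a dict of named point estimates.
estimates = approximator.estimate(data_dict)  # data_dict: adapted conditions/observations
```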
Choices in the output of `estimate`
For inference, the method `estimate(data_dict)` produces the point estimates for a given input/condition/observation. The default `PointInferenceNetwork` would produce 5 point estimates (mean, std and three quantile levels) of the marginal distribution for each inference variable.
These estimates need more explanation than samples provided by generative networks typically need. We need to communicate to the user, diagnostics code, or other downstream tasks which estimate lands where. It seems to me that it would be helpful if such `estimate` output is structured as a dictionary with the point estimates' names as keys, like `dict(mean=tensor, std=tensor, ...)`, rather than a tensor from concatenating all individual estimates.
Architecture
The architecture has one shared "body" network as well as separate "head" networks for each scoring rule.
The `subnet` keyword argument has a default of `"mlp"`; in general the argument is resolved by `find_network` from `bayesflow.utils`, which can take a string for predefined networks or a user-defined custom class. Currently the non-shared networks are just linear layers.
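As a rough sketch of that resolution logic (the exact `find_network` call signature and the head names here are assumptions, not the PR's code):

```python
import keras
from bayesflow.utils import find_network

# Shared "body": resolved from a string for predefined networks, or from a custom class.
subnet = "mlp"  # the proposed default
body = find_network(subnet)

# Non-shared "heads": currently just linear layers, one per point estimate.
num_variables = 3  # placeholder for the number of inference variables
heads = {
    "mean": keras.layers.Dense(num_variables),
    "std": keras.layers.Dense(num_variables),
}
```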
Extendable design
The first draft only includes a `PointInferenceNetwork` subclass for quantile estimation, currently called `QuantileRegressor` and found in `bayesflow/networks/regressors/quantile_regressor.py`.
This is just the first step, and we want to support different loss functions (which are typically called scoring rules in this context) that result in other point estimates.
A more flexible API including custom scoring rules/losses could use a `PointInferenceNetwork` that accepts a single `ScoringRule` or a sequence of `ScoringRule`s. This can generate the appropriate number of heads in the `build()` method. It can also pass the respective outputs to the corresponding scoring rules, which compute their actual loss contributions and sum them up.
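A minimal sketch of that flow (illustrative only, assuming `ScoringRule` objects expose `name`, a `prediction_size`, and a `score(estimates, targets)` method; this is not the PR's implementation):

```python
import keras


class PointInferenceNetworkSketch(keras.layers.Layer):
    def __init__(self, body, scoring_rules, **kwargs):
        super().__init__(**kwargs)
        self.body = body
        self.scoring_rules = scoring_rules  # a single rule could simply be wrapped in a list
        self.heads = {}

    def build(self, input_shape):
        # Generate the appropriate number of heads, one per scoring rule.
        for rule in self.scoring_rules:
            self.heads[rule.name] = keras.layers.Dense(rule.prediction_size)

    def score(self, conditions, targets):
        # Pass each head's output to its scoring rule and sum the loss contributions.
        embedding = self.body(conditions)
        return sum(rule.score(self.heads[rule.name](embedding), targets)
                   for rule in self.scoring_rules)
```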
A `ScoringRule`/`ScoringLoss` has a name (we want to distinguish multiple of them later) and a `score`/`compute` method.
method.A
ScoringRule
could also compute aprediction_shape
in the constructor to be accessed by generic code that generates a corresponding head for the scoring rule.choice we have here
- `prediction_shape` could also be a `prediction_size`, an int rather than a tuple.
- The head would have `output_shape=(*batch_shape, *prediction_shape, num_variables)`:
  - e.g., for a batch shape of `(32,)` and a `prediction_shape` of `(1,)` with 4 inference variables, `(32, 1, 4)` shall be the head `output_shape`
  - `32*1*4` units and a reshape operation together form a head (see the sketch after this list)
- Interaction with the choice of output of the `estimate` method: if we choose to allow multidimensional `prediction_shape`s rather than only `prediction_size`s, we cannot concatenate the output of different heads since each can have an individual `prediction_shape`. Thus the decision interacts with whether the `estimate` method should return a tensor or a dict of tensors (see above).
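A sketch of such a generic head (illustrative only; the batch dimension is handled implicitly by the layers, so only `prediction_shape` and `num_variables` determine the unit count):

```python
import math

import keras


def make_head(prediction_shape, num_variables):
    # Linear layer followed by a reshape to (*batch_shape, *prediction_shape, num_variables);
    # the batch dimension is implicit.
    units = math.prod(prediction_shape) * num_variables
    return keras.Sequential([
        keras.layers.Dense(units),
        keras.layers.Reshape((*prediction_shape, num_variables)),
    ])


# For the example above: prediction_shape (1,) with 4 inference variables and a batch of 32
# gives a head output_shape of (32, 1, 4).
head = make_head(prediction_shape=(1,), num_variables=4)
```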
Below are some notes on how `ScoringRule` definitions could look. If we want to support specific activation functions for different scoring rules, we might add a non-parametric activation function to the `ScoringRule`'s definition.
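The notes themselves are not preserved in this rendering; as an illustrative stand-in (names and the exact interface are assumptions), a quantile scoring rule could be defined along these lines:

```python
import keras


class QuantileScoreSketch:
    def __init__(self, quantile_levels=(0.1, 0.5, 0.9), name="quantiles"):
        self.name = name                                 # distinguishes multiple rules
        self.quantile_levels = quantile_levels
        self.prediction_shape = (len(quantile_levels),)  # used by generic head-building code

    def score(self, estimates, targets):
        # Pinball (quantile) loss, averaged over batch, levels and variables.
        # estimates: (batch, num_levels, num_variables); targets: (batch, num_variables)
        levels = keras.ops.reshape(keras.ops.convert_to_tensor(self.quantile_levels), (1, -1, 1))
        error = keras.ops.expand_dims(targets, axis=1) - estimates
        return keras.ops.mean(keras.ops.maximum(levels * error, (levels - 1.0) * error))
```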
We could also go all the way and have the `ScoringRule` also contain the head itself, taking the last joint embedding and mapping it to estimates. This would naturally include an optional nonlinearity in the end and simultaneously give users the option to tweak the architecture of the non-shared weights.
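In that variant, a rule could expose something like a `get_head()` method (a sketch under the same assumptions as above, with the optional link/nonlinearity at the end):

```python
import keras


class MeanScoreWithHeadSketch:
    def __init__(self, num_variables, link=None, name="mean"):
        self.name = name
        self.num_variables = num_variables
        self.link = link  # optional output transform, e.g. positivity or ordering; None = identity

    def get_head(self):
        # The rule owns its (non-shared) head: users can override this to tweak the architecture.
        layers = [keras.layers.Dense(self.num_variables)]
        if self.link is not None:
            layers.append(keras.layers.Lambda(self.link))
        return keras.Sequential(layers)

    def score(self, estimates, targets):
        # Squared error: its Bayes-optimal point estimate is the posterior mean.
        return keras.ops.mean((estimates - targets) ** 2)
```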
Other notes:

- `bf.Workflow`?