Add smoothing parameter to Discrete Cosine transform and reparam #2430
Conversation
```python
@torch.no_grad()
def _weight(self, y):
    size = y.size(-1)
    if self._weight_cache is None or self._weight_cache.size(-1) != size:
        # Weight by frequency**smooth, where the DCT-II frequencies are:
        freq = torch.linspace(0.5, size - 0.5, size, dtype=y.dtype, device=y.device)
        w = freq.pow_(self.smooth)
        w /= w.log().mean().exp()  # Ensure |jacobian| = 1.
        self._weight_cache = w
    return self._weight_cache
```
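The last two lines are the unit-Jacobian trick: an elementwise rescaling by `w` has log-abs-det-Jacobian `w.log().sum()`, and dividing `w` by its geometric mean `w.log().mean().exp()` drives that sum to exactly zero. A minimal standalone check (plain PyTorch, mirroring the weighting above; the `size` and `smooth` values are arbitrary):

```python
import torch

size, smooth = 16, 1.5
# Mirror the weighting logic from _weight() above.
freq = torch.linspace(0.5, size - 0.5, size, dtype=torch.float64)
w = freq.pow(smooth)
w = w / w.log().mean().exp()  # divide by the geometric mean
# log|det J| of the elementwise rescaling; vanishes up to rounding.
print(w.log().sum())  # ~0.0
```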
@fehiepsi could you please check my math here?
I checked weight by `frequency**smooth` (up to constant multiplication, which does not matter after normalization) and `|jacobian| = 1`, so this looks correct to me. However, I don't understand the phrase:

> When 0, this transforms white noise to white noise; when 1, this transforms continuous brownian-like motion to white noise; when 2, this transforms doubly-cumsummed white noise to white noise.

Could you elaborate on it (maybe give some reference)? It is fine with me if you confirm that it is correct.
The terminology comes from https://en.wikipedia.org/wiki/Colors_of_noise. I have also added a blurb to the `DiscreteCosineReparam` docstring.

The intuition is that we are treating this kind of like a linear normalizing flow whose codomain distribution should be approximately iid standard normal, a.k.a. white noise. Setting `smooth=1` means the domain distribution is brownian motion (i.e. the cumsum or integral of white noise); setting `smooth=2` means the domain distribution is doubly-integrated white noise, which is continuously differentiable.

Another piece of intuition is that this is a discrete analog of the Matern kernels for continuous GPs: `smooth=0` is like Matern 1/2, `smooth=1` is like Matern 3/2, `smooth=2` is like Matern 5/2, etc.
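In power-spectrum terms (following the colors-of-noise convention above), integrating white noise $s$ times scales the power spectrum by roughly $f^{-2s}$, hence the amplitude by $f^{-s}$; rescaling each frequency component by $f^{s}$ then flattens the spectrum back to white, up to normalization:

$$y_f \sim \mathcal{N}\bigl(0,\, f^{-2s}\bigr) \quad\Longrightarrow\quad f^{s}\, y_f \sim \mathcal{N}(0,\, 1).$$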
Thanks, I made some plots to test your claims. It looks correct!
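A sketch of one such test (plain NumPy/SciPy, re-implementing the weighting above rather than calling Pyro; the size, seed, and trial count are arbitrary):

```python
import numpy as np
from scipy.fftpack import dct

rng = np.random.default_rng(0)
size, trials, smooth = 128, 5000, 1.0

# Brownian motion = cumsum of white noise; take its orthonormal DCT-II.
brown = rng.standard_normal((trials, size)).cumsum(axis=-1)
coefs = dct(brown, type=2, norm="ortho", axis=-1)

# Rescale each coefficient by freq**smooth, as in _weight() above.
freq = np.linspace(0.5, size - 0.5, size)
whitened = coefs * freq ** smooth

# If smooth=1 whitens brownian motion, the per-frequency std should be
# roughly flat (the match is approximate, especially at low frequencies).
std = whitened.std(axis=0)
print(std[:5], std[-5:])
```

Plotting `std` against `freq` on log-log axes makes the flattening easy to eyeball.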
```python
# Weight by frequency**smooth, where the DCT-II frequencies are:
freq = torch.linspace(0.5, size - 0.5, size, dtype=y.dtype, device=y.device)
w = freq.pow_(self.smooth)
w /= w.log().mean().exp()  # Ensure |jacobian| = 1.
```
nice trick to prevent blow-up!
LGTM!
Addresses #2426

This adds a `smooth` parameter to the `DiscreteCosineTransform` and `DiscreteCosineReparam`, allowing the reparameterizer to automatically encode C-smooth functions. That is, `DiscreteCosineReparam(smooth=s)` transforms C-s functions to white noise, e.g. `DCR(smooth=1)` transforms brownian motion to white noise, and `DCR(smooth=2)` transforms cumsum(cumsum(white noise)) to white noise.

I am hoping this will be useful in SVI and HMC, where setting e.g. `smooth=1` would avoid the need for `AutoNormal` or `HMC` adaptation to learn the frequency-proportional variance (or rather would warm-start those learned variances at a better place).

Tested
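For illustration, a usage sketch (assuming the reparam API added by this PR; the model, site name, and data below are hypothetical):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.reparam import DiscreteCosineReparam

def model(data):
    # Brownian-motion prior over a latent time series.
    drift = pyro.sample(
        "drift",
        dist.GaussianRandomWalk(torch.tensor(1.0), num_steps=len(data)),
    )
    pyro.sample("obs", dist.Normal(drift, 0.1).to_event(1), obs=data)

# Reparameterize "drift" so HMC / AutoNormal operate on ~white-noise
# coordinates; smooth=1 matches the brownian prior.
reparam_model = pyro.poutine.reparam(
    model, config={"drift": DiscreteCosineReparam(smooth=1.0)}
)
```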