Length scale gradient NaN when x not fully continuous #121
Hmmm, do you know at which values of the hyperparameters this happens? From the perspective of inference, this is a pretty adversarial case at the minute, so I'm reluctant to think of it as a bug. Were you to consider it a bug, the fix would probably involve attempting to detect whether or not you've got repeated inputs and then doing something to exploit the structure (possibly the thing in your other PR). This is definitely a tradeoff, because if your inputs are quite high-dimensional, this check may take a while.
I didn't set a seed for the random generation of the hyperparameters, but the problem is reproducible every time I have run it, without adding jitter to the x values. I set up a reproducible example that may look a bit adversarial, but for our actual project, this issue has prevented us from using Stheno (for now) for 2 out of 3 of our datasets. For our use case, I agree that this is best addressed by moving forward with #120. One way to avoid computational issues is to have users flag in some way that the low-rank structure exploitation from discrete x's should be checked, rather than doing this by default. A sketch of the duplicate-detection step is below.
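(For concreteness, here is a minimal sketch of the kind of repeated-input check discussed above. `deduplicate` is a hypothetical helper, not part of Stheno's API; the idea is that the kernel matrix over `x` can be assembled from the smaller matrix over the unique inputs `ux`, which is the structure #120 would exploit.)

```julia
# Hypothetical helper: recover the unique inputs and an index map back to the
# original ordering, so that the full kernel matrix over x equals the kernel
# matrix over ux indexed as K_ux[idx, idx].
function deduplicate(x::AbstractVector)
    ux = unique(x)                                  # order-preserving
    lookup = Dict(v => i for (i, v) in enumerate(ux))
    idx = [lookup[v] for v in x]
    return ux, idx
end

x = vcat(randn(3), randn(3))  # inputs with exact repeats
ux, idx = deduplicate(x)
@assert x == ux[idx]
```

As noted above, even this linear-time check is not free, which is one argument for opting in via a flag rather than running it by default.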
Yes, I appreciate your problem, but it would still be helpful to know at which values of the hyperparameters things are breaking -- you can either jitter the inputs, or print out the hyperparameter values when the failure occurs. If, for example, it's collapsing to a degenerate length scale, that would point to a different problem than the one discussed here.
Understood, sorry for not providing a specific example earlier. If the inputs contain exact repeats, e.g. `x = vcat(x, x)`, rather than being fully random, then the NaN appears every time.
Thanks for providing the example input 🙂 So I've done some digging around, and it actually doesn't appear to be related to the numerical instabilities that you might get with the Cholesky factorisation when everything is low-rank, as we had suspected. Rather, it appears to be a problem in the AD itself:

```julia
using Distances, Stheno, Zygote

init = randn()  # scalar length scale (the quoted gradient below is a scalar)

x = randn(100)
x = vcat(x, x)  # duplicate every input, so some pairwise distances are exactly 0

function foo(l)
    x′ = x .* l
    return Distances.pairwise(Euclidean(), x′)
end

_, back = Zygote.pullback(foo, init)
```

```julia
julia> back(randn(200, 200))
(NaN,)
```

So, if you use a kernel whose implementation computes pairwise Euclidean distances, the gradient is NaN whenever any two inputs coincide. This issue has cropped up before and I thought we had fixed it. Clearly not. I'll dig around further and try to fix it.
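(A minimal illustration of the mechanism, independent of Stheno: the Euclidean distance involves a square root, the derivative of `sqrt` at 0 is `Inf`, and the chain rule multiplies that `Inf` by the zero gradient of the squared difference, giving `0 * Inf = NaN`.)

```julia
using Zygote

# d(x, y) = sqrt((x - y)^2) is the 1-D Euclidean distance. At x == y the inner
# derivative 2(x - y) is 0 and the outer derivative 1/(2*sqrt(0)) is Inf, so
# the chain rule produces 0 * Inf = NaN.
d(x, y) = sqrt((x - y)^2)

Zygote.gradient(d, 1.0, 1.0)  # (NaN, NaN)
Zygote.gradient(d, 1.0, 2.0)  # (-1.0, 1.0) — fine away from the diagonal
```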
I've done a bit of thinking. Could you please open an issue in Zygote.jl relating to this? In short, there's a bug in the rule for doing reverse-mode AD through `pairwise` with the `Euclidean` metric when two inputs coincide. I'm happy to help push this along in Zygote -- we should be able to resolve this fairly quickly (over the next few days).
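(To make the proposed Zygote issue concrete, here is a sketch of what a NaN-safe rule could look like for the scalar-input case. `safe_pairwise` is hypothetical, and this is not the actual Zygote or Stheno implementation; the point is that for the 1-D Euclidean distance, ∂D[i, j]/∂x[i] = sign(x[i] - x[j]), which is exactly 0 rather than NaN for repeated inputs.)

```julia
using Distances, Zygote
using Zygote: @adjoint

# Hypothetical NaN-safe wrapper around pairwise Euclidean distances for a
# vector of scalar inputs.
safe_pairwise(x::AbstractVector{<:Real}) = Distances.pairwise(Euclidean(), x)

@adjoint function safe_pairwise(x::AbstractVector{<:Real})
    D = Distances.pairwise(Euclidean(), x)
    function safe_pairwise_pullback(Δ)
        # ∂D[i, j]/∂x[i] = sign(x[i] - x[j]): 0 when x[i] == x[j], avoiding
        # the 0 * Inf = NaN that the sqrt-based rule produces.
        S = sign.(x .- x')
        x̄ = vec(sum((Δ .+ Δ') .* S; dims=2))
        return (x̄,)
    end
    return D, safe_pairwise_pullback
end

x = vcat(randn(3), randn(3))               # repeated inputs
_, back = Zygote.pullback(safe_pairwise, x)
back(ones(6, 6))                           # finite gradient, no NaNs
```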
Yes, I'm happy to. Do you have a working example of the problem that you found with Zygote that doesn't rely on Stheno's Euclidean function? It would be helpful to post in the issue.
If you just add a direct call to `Distances.pairwise`, as in the example above, it doesn't go through Stheno at all. I should really just remove the current pairwise implementation in Stheno and rely on Zygote's rule instead.
Scratch that, I forgot that the implementation in Stheno is there precisely to handle a case that Zygote can't. I'll try and fix that issue.
If you can provide an example where the problem shows up without going through Stheno, that would still be worth posting in the Zygote issue.
Sorry, I'm still not clear on what kind of issue I should open in Zygote to get this fixed.
If I get the gradient of the logpdf using Zygote, then the gradient of the length-scale parameter is NaN whenever x contains any repeated values. This is related to #120, and might be fixed by the suggested implementation there, but I wanted to note this in a separate issue to show that the issues around #120 are not purely computational. See the code that generates this issue below. The first println(gradient) works fine, as per the tutorial, but once some x values are repeated, the gradient of the length scale is NaN.
Output of the first print statement:
Output of the second print statement:
I wonder if this is related to numerical instabilities in the Cholesky decomposition when K(x, x) is not full rank? But overall, this seems like a bug. It can be worked around by adding some small jitter to the x values so that no x value in the dataset is repeated, but in many real-world applications it seems reasonable to expect some discrete x's. A sketch of the jitter workaround is below.
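(A minimal sketch of the jitter workaround just described; `jitter_duplicates` is an illustrative name, not part of Stheno. Only repeated entries are perturbed, so the rest of the data is untouched.)

```julia
# Illustrative workaround: nudge repeated inputs by a tiny amount so that no
# two x values coincide exactly. A collision after jittering is possible in
# principle, but vanishingly unlikely with continuous noise.
function jitter_duplicates(x::AbstractVector{<:Real}; scale=1e-8)
    seen = Set{eltype(x)}()
    return map(x) do xi
        if xi in seen
            xi + scale * randn()
        else
            push!(seen, xi)
            xi
        end
    end
end

jitter_duplicates([1.0, 2.0, 1.0, 3.0])  # third entry nudged away from 1.0
```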