
Length scale gradient NaN when x not fully continuous #121

Closed
evanmunro opened this issue Jul 29, 2020 · 10 comments · Fixed by #129
Comments

@evanmunro

If I take the gradient of the logpdf using Zygote, the gradient of the length scale parameter is NaN whenever x contains any repeated values. This is related to #120, and might be fixed by the suggested implementation there, but I wanted to note it in a separate issue to show that the problems around #120 are not purely computational. See the code that generates this issue below. The first println(gradient(...)) works fine, as per the tutorial, but once some x values are repeated, the gradient of the length scale is NaN.

using Zygote: gradient
using Stheno
using Optim: optimize, BFGS, minimizer

function unpack(θ)
    σ² = exp(θ[1]) + 1e-6
    l = exp(θ[2]) + 1e-6
    σ²_n = exp(θ[3]) + 1e-6
    return σ², l, σ²_n
end

# x is continuous 
init = randn(3)
x = randn(100); y = randn(100)

function nlml(θ)
    σ², l, σ²_n = unpack(θ)
    k = σ² * stretch(Matern52(), l)
    f = GP(k, GPC())
    return -logpdf(f(x, σ²_n), y)
end
println(gradient(nlml, init)[1])

# some repeated values of x  
x = vcat(x, x); y = randn(200)
function nlml(θ)
    σ², l, σ²_n = unpack(θ)
    k = σ² * stretch(Matern52(), l)
    f = GP(k, GPC())
    return -logpdf(f(x, σ²_n), y)
end
println(gradient(nlml, init)[1])

Output of the first print statement:

julia> println(gradient(nlml, init)[1])
[-1.0713468383523934, -12.447387099066987, -131.494269354701]

Output of the second print statement:

julia> println(gradient(nlml, init)[1])
[0.9055697178496014, NaN, -294.0797397890903]

I wonder if this is related to numerical instabilities in the Cholesky decomposition when K(x, x) is not full rank? Overall, this seems like a bug. It can be worked around by adding a small jitter to the x values so that no x value is ever repeated, but in many real-world applications it seems reasonable to expect some discrete x's.
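The jitter workaround I mention can be sketched like this (a minimal illustration; the 1e-8 scale is an arbitrary choice, small relative to the data spacing but enough to break exact ties):

```julia
x = randn(100); x = vcat(x, x)   # duplicated inputs, as in the example above

# Perturb inputs slightly so that no two x values coincide exactly.
# The 1e-8 scale here is arbitrary and purely illustrative.
x_jittered = x .+ 1e-8 .* randn(length(x))
```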

@willtebbutt
Member

willtebbutt commented Aug 2, 2020

Hmmm, do you know what values σ², l, σ²_n came out to in this example? In particular, as you say, this shouldn't be a problem provided that σ²_n is large enough.

From the perspective of inference at the minute, this is a pretty adversarial case, so I'm reluctant to think of it as a bug.

Were you to consider it a bug, the fix would probably involve attempting to detect whether or not you've got repeated inputs and then doing something to exploit the structure (possibly the thing in your other PR). This is definitely a tradeoff, because if your inputs are quite high-dimensional, this check may take a while.

@evanmunro
Author

evanmunro commented Aug 3, 2020

I didn't set a seed for the random generation of the hyperparameters, but the problem is reproducible every time I have run it, without adding jitter to the x values.

I set up a reproducible example that may look a bit adversarial, but for our actual project, this issue has prevented us from using Stheno (for now) for 2 out of 3 of our datasets.

For our use case x is always low-dimensional. In some cases our x variable is a test score on a point scale from, for example, 1 to 60, while the y outcome is continuous (some kind of z-score). With 6k observations there are many repeated x values, and as written, I get the same kind of error as in the simple example above. I understand that in many other use cases inputs might be high-dimensional, in which case this issue very quickly disappears and becomes a low-probability event.

Agree that this is best addressed by moving forward with #120. One way to avoid computational issues is to have users flag in some way that the low-rank structure arising from discrete x's should be checked for and exploited (rather than doing this by default).

@willtebbutt
Member

Yes, I appreciate your problem, but it would still be helpful to know at which values of the hyperparameters things are breaking -- you can either jitter x or you can ensure that σ²_n is always sufficiently large to ensure that the Cholesky factorisation succeeds. Typically I tend to go for ensuring that σ²_n is large enough, so I would really like to know if this isn't working in this example for some reason.

If, for example, it's collapsing to the 1e-6 lower-bound that you've imposed, then possibly that needs to change. If, on the other hand, it's still quite large and things are failing, then I would like to understand why.
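For concreteness, checking whether the 1e-6 floor is the culprit could look like this (a hypothetical variant of your unpack; the 1e-2 floor is an arbitrary value for illustration):

```julia
# Hypothetical variant of unpack with a larger observation-noise floor,
# to test whether the 1e-6 lower bound is what's collapsing.
function unpack_bounded(θ)
    σ²   = exp(θ[1]) + 1e-6
    l    = exp(θ[2]) + 1e-6
    σ²_n = exp(θ[3]) + 1e-2   # larger floor than the original 1e-6
    return σ², l, σ²_n
end
```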

@evanmunro
Author

evanmunro commented Aug 3, 2020

Understood, sorry for not providing a specific example earlier. If, rather than random, init = [log(1), log(1), log(10)] or init = [log(1), log(1), log(100)], then the error remains. So it is having issues even for quite large values of σ²_n.

@willtebbutt
Member

Thanks for providing the example input 🙂

So I've done some digging around, and it actually doesn't appear to be related to the numerical instabilities you might get with the Cholesky factorisation when everything is low-rank, as we had suspected. Rather, it appears to be Stheno.jl's implementation of the Euclidean distance. So this is most definitely a bug!

using Distances, Stheno, Zygote

x = randn(100)
x = vcat(x, x)   # repeated inputs trigger the NaN
function foo(l)
    x′ = x .* l
    return Distances.pairwise(Euclidean(), x′)
end

_, back = Zygote.pullback(foo, randn())

julia> back(randn(200, 200))
(NaN,)

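For intuition (my reading of the failure mode, not confirmed output): the NaN comes from differentiating the Euclidean distance at zero, which is a 0/0 — the pullback of sqrt is infinite at zero, and it multiplies an inner gradient of zero. A one-line illustration of the same pattern:

```julia
using Zygote

# d/dx sqrt((x - y)^2) at x == y: sqrt's pullback is Inf at zero,
# the inner gradient 2(x - y) is zero, and Inf * 0 = NaN.
g = Zygote.gradient(x -> sqrt((x - 2.0)^2), 2.0)[1]
```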
So if, for example, you use the EQ kernel rather than the Matern52, your above example will work fine.

This issue has cropped up before and I thought we had fixed it. Clearly not. I'll dig around further and try to fix it.

@willtebbutt
Member

willtebbutt commented Aug 7, 2020

I've done a bit of thinking. Could you please open an issue in Zygote.jl relating to this issue? In short, there's a bug in the rule for doing reverse-mode AD through pairwise(Euclidean(), x), so it's technically Zygote's concern at the minute.

I'm happy to help push this along in Zygote -- we should be able to resolve this fairly quickly (over the next few days).

@evanmunro
Author

Yes, I'm happy to. Do you have a working example of the problem that you found with Zygote that doesn't rely on Stheno's Euclidean function? It would be helpful to post in the issue.

@willtebbutt
Member

If you just add a dims argument to pairwise in the above example, you'll hit Zygote's implementation:

https://github.com/FluxML/Zygote.jl/blob/86d1dd559f02ed7fb6e4435383382af7badcd2ad/src/lib/distances.jl#L77

I should really just remove the current Euclidean @adjoint from Stheno.

@willtebbutt
Copy link
Member

Scratch that, I forgot that the implementation in Stheno is there precisely to handle a case that Zygote can't.

I'll try and fix that issue.

> Do you have a working example of the problem that you found with Zygote that doesn't rely on Stheno's Euclidean function?

If you can provide an example where x is a Matrix you'll hit Zygote's implementation.
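Something like the following might serve as a Stheno-free reproducer (an untested sketch; with Matrix inputs and a dims keyword this should dispatch to Zygote's own pairwise adjoint, and the repeated columns give zero distances that get divided by in the pullback):

```julia
using Distances, Zygote

# Duplicate columns produce zero pairwise distances; the pullback of
# pairwise(Euclidean(), ...) divides by those distances.
X = randn(2, 50)
X = hcat(X, X)   # repeated inputs
D, back = Zygote.pullback(M -> pairwise(Euclidean(), M; dims=2), X)
Δ = back(ones(size(D)))[1]
```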

@evanmunro
Author

Sorry, I'm still not clear on what kind of issue I should open in Zygote to get this fixed.
