(0.93.2) Update Adapt.jl compat and fix Float32 CATKE on GPU #3876

Merged: ali-ramadhan merged 16 commits into main from ali/fix-catke-f32 on Oct 30, 2024


ali-ramadhan (Member) commented Oct 28, 2024:

Just opening this PR with what is currently a band-aid fix for Float32 CATKE on GPUs. Hoping I can figure out what's actually wrong. But for now it'll be useful to have a working branch.

Resolves #3870 (eventually, hopefully)

ali-ramadhan changed the title from "Band aid for Float32 CATKE" to "Fixing Float32 CATKE on GPU" on Oct 28, 2024.
glwagner (Member):

Just some minor comments:

  1. Use eltype(grid) instead of a type parameter, following YASGuide
  2. Does the annotation ::FT work?
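The first review point contrasts two ways of getting the float type inside a method. A minimal sketch, using a hypothetical `ToyGrid` type (an illustrative stand-in, not an Oceananigans grid) that defines `Base.eltype`: style A pulls `FT` out of a type parameter in the signature, while style B, the one the review recommends and attributes to YASGuide, calls `eltype(grid)` in the method body. Both return a `Float32` here; they differ only in how the method is written.

```julia
# Hypothetical toy grid parameterized by its floating-point type FT
# (an illustrative stand-in, not an Oceananigans grid).
struct ToyGrid{FT}
    Δz :: FT
end

# Let `eltype(grid)` report the grid's float type.
Base.eltype(::ToyGrid{FT}) where FT = FT

# Style A: extract FT from a type parameter in the method signature.
half_spacing_param(grid::ToyGrid{FT}) where FT = grid.Δz / FT(2)

# Style B (the style suggested in review): ask the grid for its element type.
function half_spacing_eltype(grid)
    FT = eltype(grid)
    return grid.Δz / FT(2)
end

grid = ToyGrid(1.0f0)
half_spacing_param(grid)    # 0.5f0
half_spacing_eltype(grid)   # 0.5f0
```

Style B keeps the signature untyped, so the same method body works for any grid type that implements `eltype`.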

navidcy added the labels bug 🐞, turbulence closures 🎐, and GPU 👾 on Oct 28, 2024.
ali-ramadhan marked this pull request as ready for review on October 29, 2024 18:07.
ali-ramadhan (Member, Author):

Thanks for the review @glwagner. I made the changes and also added a test that fails (well, CUDA crashes) without this PR and passes with it.

> Does the annotation ::FT work?

Unfortunately not. I ended up getting GPU exceptions instead: #3870 (comment)

    @@ -228,7 +227,7 @@ end
         Jᵇᵋ = closure.minimum_convective_buoyancy_flux
         Jᵇᵢⱼ = @inbounds Jᵇ[i, j, 1]
         Jᵇ⁺ = max(Jᵇᵋ, Jᵇᵢⱼ, Jᵇ★) # selects fastest (dominant) time-scale
    -    t★ = (ℓᴰ^2 / Jᵇ⁺)^(1/3)
    +    t★ = cbrt(ℓᴰ^2 / Jᵇ⁺)
Member:

Can you check if this fixes the issue without the need for the extra convert?

It would be good to avoid "over-converting", because this could cause us to miss spurious promotion, which hurts performance (e.g., by removing some of the benefit of using Float32).
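The `cbrt` change targets a standard Julia promotion rule: `1/3` is a `Float64` literal, so `x^(1/3)` with a `Float32` `x` promotes the result to `Float64`. A minimal sketch of that rule in isolation (plain Base Julia, not the CATKE code itself), including a second type-preserving alternative that builds the exponent in `x`'s own type:

```julia
x = 8.0f0

promoted  = x^(1/3)         # 1/3 is a Float64, so Float32 ^ Float64 promotes to Float64
preserved = cbrt(x)         # cbrt of a Float32 returns a Float32
also_ok   = x^(one(x) / 3)  # an exponent built in x's own type also stays Float32

println(typeof(promoted))   # Float64
println(typeof(preserved))  # Float32
println(typeof(also_ok))    # Float32
```

This only shows why `cbrt` is the type-preserving choice for a cube root; whether it was the only promotion on the CATKE path is what the rest of the thread investigates.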

Member Author:

I checked, but unfortunately it was not enough.

Agreed that it would be good not to over-convert.

Member:

The point is just that you may lose much of the advantage of using Float32 in the first place if you throw convert around.

Comment on lines 265 to 268

        κu★ = min(κu, κu_max)
        FT = eltype(grid)
        return convert(FT, κu★)
    end
end
Member:

For this PR I suggest changing this to a type annotation (κu★::FT), since this is what we will want in the long run.

Member Author:

Unfortunately that does not work, but if I can pinpoint where the conversion is happening, then we won't need any conversion here.

Member:

Sorry, I am not suggesting this as a solution, but rather as a way to catch a bug in the future.

Member Author:

Ah, I don't think I actually know how type annotations work. So if we say

    return κu★::FT

and κu★ is not of type FT, then an error/exception will be thrown?

Member:

I think so, isn't that what you found?

Member:

It's a bug-checking mechanism that may be more broadly useful as we try to prevent spurious promotion.

Member Author:

No idea why I thought it was a different mechanism to convert types lol. But yes that is what I found.
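For reference, Julia uses `::` in two different ways, which may explain the confusion above. In expression position (as in `return κu★::FT`), `x::T` is a type assertion: it returns `x` unchanged or throws a `TypeError`. On a local variable declaration or on a method's return type, the annotation instead calls `convert`. A minimal sketch with hypothetical helper names (`assert_type`, `declare_type`, `annotated_return` are illustrative only):

```julia
# Expression position: `x::FT` asserts the type and never converts.
assert_type(x, ::Type{FT}) where FT = x::FT

# Typed local declaration: assigning to `y::FT` calls convert(FT, x).
function declare_type(x, ::Type{FT}) where FT
    y::FT = x
    return y
end

# Return-type annotation: the returned value is convert(Float32, x).
annotated_return(x)::Float32 = x

assert_type(1.0f0, Float32)     # 1.0f0, already a Float32
declare_type(1.0, Float32)      # 1.0f0, silently converted
annotated_return(1.0)           # 1.0f0, silently converted

try
    assert_type(1.0, Float32)   # a Float64 is not a Float32
catch err
    println(typeof(err))        # TypeError
end
```

Only the first form acts as the bug check discussed here; the other two would hide a spurious promotion in the same way an explicit `convert` does.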

glwagner (Member):

I think to preserve the work in this PR, we should add a Float32 test which will fail if a spurious promotion undermines performance.

ali-ramadhan (Member, Author):

> I think to preserve the work in this PR, we should add a Float32 test which will fail if a spurious promotion undermines performance.

Agreed. I'll revisit this PR later to see if I can find where the conversion happens. The test I added only checks to see if we can take a time step. But I should be able to also add a test to ensure no spurious promotion occurred.
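One possible shape for such a test, sketched with Julia's standard `Test` library and hypothetical functions (`capped_with_literal`, `capped_in_type`, and `preserves_type` are illustrative, not Oceananigans code): feed a `Float32` in and check that a `Float32` comes out. The example puts a `Float64` literal inside `min` because `min` promotes mixed arguments, much like the `min(κu, κu_max)` limiter discussed above would if `κu_max` were a `Float64`.

```julia
using Test

# Hypothetical stand-ins for a limiter like `κu★ = min(κu, κu_max)`.
# Capping with the Float64 literal 0.1 makes `min` promote a Float32 input.
capped_with_literal(κ) = min(κ, 0.1)

# Building the cap in the input's own type keeps the result in that type.
capped_in_type(κ) = min(κ, oftype(κ, 0.1))

# True if `f` returns the same float type FT that it was given.
preserves_type(f, FT, x) = f(convert(FT, x)) isa FT

@testset "No spurious Float32 promotion" begin
    @test preserves_type(capped_in_type, Float32, 0.5)
    @test !preserves_type(capped_with_literal, Float32, 0.5)  # documents the promotion
end
```

A check like this only fails once the `convert` that masks the promotion is removed, which is the point raised in the reply that follows.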

glwagner (Member) commented Oct 29, 2024:

> > I think to preserve the work in this PR, we should add a Float32 test which will fail if a spurious promotion undermines performance.
>
> Agreed. I'll revisit this PR later to see if I can find where the conversion happens. The test I added only checks to see if we can take a time step. But I should be able to also add a test to ensure no spurious promotion occurred.

Ah, that will work as a test if we remove the convert.

The convert is a good sanity check to find where the problem is, but it's not a solution, since it merely allows the code to run without error; it doesn't actually let us realize the benefits of using Float32. Arguably, with the convert it is actually worse to use Float32, since the numerics are degraded but the performance benefit is not fully realized.

ali-ramadhan (Member, Author):

Following #3870 (comment) this PR now just changes how grid coordinate ranges are constructed. Curious to see if any tests fail. But locally it fixed CATKE + Float32.

@glwagner I ended up doing this if it looks okay:

    κu★ = min(κu, κu_max)
    FT = eltype(grid)
    return κu★::FT

I assume there's a small cost associated with the type annotation ::FT?

ali-ramadhan changed the title from "Fixing Float32 CATKE on GPU" to "Make grid coordinate ranges type-safe and fix Float32 CATKE on GPU" on Oct 30, 2024.
Comment on lines 117 to 118

    F = StepRangeLen{FT, FT, FT, Int}(F)
    C = StepRangeLen{FT, FT, FT, Int}(C)

glwagner (Member), Oct 30, 2024:

I don't think it is correct to describe this change as "making the grid coordinate ranges type-safe".

Perhaps more conservatively, we only require the first type parameter of StepRangeLen to be FT. The next two should be twice precision, e.g., Float64 if the first is Float32, or TwicePrecision{Float64} if FT=Float64.

Now, if we want to support non-standard ranges, we can perhaps consider that. But I suspect there is something going on that isn't completely understood here.

Member:

This is discussed on #3870, but just to express the problem here: we expect the ranges to have the type StepRangeLen{Float32, Float64, Float64, Int} for a Float32 grid. Therefore, this PR does some damage by producing an unexpected range type.

It is unexpected, or a bug, that a range of type StepRangeLen{Float32, Float64, Float64, Int} would produce Float64. Therefore, the first course of action is to verify whether StepRangeLen{Float32, Float64, Float64, Int} produces Float64, on either the CPU or the GPU.
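A quick CPU-side check of that expectation, assuming Julia 1.7 or later (where `StepRangeLen` takes four type parameters): construct a range of exactly this type and confirm that its elements come back as `Float32`. This says nothing about the GPU, where the outcome depends on how the range is adapted.

```julia
# The range type expected for a Float32 grid: Float32 elements, with the
# reference value and step stored in Float64 for extra precision.
r = StepRangeLen{Float32, Float64, Float64, Int}(0.0, 0.25, 5)

println(eltype(r))     # Float32
println(typeof(r[3]))  # Float32: indexing computes ref + k * step, then converts to Float32
println(r[3])          # 0.5
```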

I have a hunch that StepRangeLen is somehow adapted for GPU incorrectly. Or, if it is intentional, then I think we should fix the output of xnode, ynode, znode...

If there's good motivation for reducing the precision of ranges (either in addition or alternatively to the above suggestion) then I think we can entertain it. I'm not sure we want to hardcode this change though, it might be better to provide it as an option. The imprecision of ranges is already a bit annoying.

Member Author:

Since this makes the test pass and the MWE in #3870 not error, maybe the hint is that StepRangeLen just needs to be adapted for the GPU?

But yeah I'm not sure why StepRangeLen{Float32, Float64, Float64, Int} seems to produce Float64 numbers in GPU kernels. Perhaps this is worth opening an issue on CUDA.jl with a MWE?

Member:

Yes, though it is not for us to adapt (that would be type piracy) because we own neither Adapt nor StepRangeLen. Plus it seems to be adapted, so what is happening...

https://github.com/JuliaGPU/Adapt.jl/blob/5ef7c5329609df7ffb5b19942d6747b3dcc162c2/src/base.jl#L79-L80
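Purely as an illustration of how rebuilding a range could lose its element type (a hypothetical failure mode, not a reproduction of Adapt.jl's actual method or of the change in JuliaGPU/Adapt.jl#88): if a `StepRangeLen` is reconstructed from its `ref`, `step`, `len`, and `offset` fields without passing the element type, the constructor infers the element type from `ref + step`, which is `Float64` here.

```julia
r = StepRangeLen{Float32, Float64, Float64, Int}(0.0, 0.25, 5)

# Rebuilding from the stored fields without an element type lets the
# constructor infer it from ref + step, which is Float64 here.
lossy = StepRangeLen(r.ref, r.step, r.len, r.offset)

# Passing the original element type explicitly keeps Float32 elements.
faithful = StepRangeLen{eltype(r)}(r.ref, r.step, r.len, r.offset)

println(eltype(lossy))     # Float64
println(eltype(faithful))  # Float32
```

Any kernel indexing `lossy` would then see `Float64` values, which matches the symptom described in this thread, though the sketch does not prove that this was the cause.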

glwagner (Member):

This may help JuliaGPU/Adapt.jl#88

ali-ramadhan changed the title from "Make grid coordinate ranges type-safe and fix Float32 CATKE on GPU" to "Update Adapt.jl compat and fix Float32 CATKE on GPU" on Oct 30, 2024.
ali-ramadhan (Member, Author):

I just changed the Adapt.jl compat entry to make use of the new version with the StepRangeLen fix. The MWE from #3870 does not error with Adapt.jl v4.1.1 locally.

ali-ramadhan changed the title from "Update Adapt.jl compat and fix Float32 CATKE on GPU" to "(0.93.2) Update Adapt.jl compat and fix Float32 CATKE on GPU" on Oct 30, 2024.
    @@ -43,7 +43,7 @@ OceananigansEnzymeExt = "Enzyme"
     OceananigansMakieExt = ["MakieCore", "Makie"]
     
     [compat]
    -Adapt = "3, 4"
    +Adapt = "^4.1.1"
navidcy (Collaborator), Oct 30, 2024:

Suggested change:

    -Adapt = "^4.1.1"
    +Adapt = "4.1.1"

(same but cleaner)
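On the compat question in this suggestion: in Julia's Pkg compat syntax, a bare version number already uses caret semantics, so `"4.1.1"` and `"^4.1.1"` specify the same range. An annotated `[compat]` fragment (the comments are added here for illustration):

```toml
[compat]
# Caret semantics are Pkg's default, so these two entries are equivalent:
# both allow Adapt versions >= 4.1.1 and < 5.0.0.
Adapt = "4.1.1"
# Adapt = "^4.1.1"

# The previous entry allowed any 3.x or 4.x release (>= 3.0.0, < 5.0.0).
# Adapt = "3, 4"
```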

navidcy (Collaborator) left a comment:

this is heroic debugging!!

navidcy (Collaborator) commented Oct 30, 2024:

you can forget my minor suggestion to remove the ^! let's merge this!

ali-ramadhan (Member, Author):

Haha it did take a while but with a satisfying ending!

And thanks for the suggestion! I didn't realize that 4.1.1 and ^4.1.1 would be the same here. But since it's okay with you, I'll merge to avoid waiting on another round of tests to pass 🙃

ali-ramadhan merged commit f2a8fb3 into main on Oct 30, 2024 (46 checks passed).
ali-ramadhan deleted the ali/fix-catke-f32 branch on October 30, 2024 20:28.
navidcy (Collaborator) commented Oct 30, 2024:

> to avoid waiting on another round of tests to pass 🙃

Exactly! Takes forever...!

Labels: bug 🐞, GPU 👾, turbulence closures 🎐
3 participants