-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
25 normalize data on gpu #39
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a lot cleaner and also good to move computations from CPU to GPU. There is a question mark about forcing standardization that should be double-checked. In the future we might want to handle standardization of forcing the same for all forcing variables, or have a more robust way to control this (i.e. which forcing should be standardized or not), but let's leave that to then.
init_states, target_states, forcing_features = batch | ||
init_states = (init_states - self.data_mean) / self.data_std | ||
target_states = (target_states - self.data_mean) / self.data_std | ||
forcing_features = (forcing_features - self.flux_mean) / self.flux_std |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now all forcing seem to be normalized with the flux statistics, but there is more forcing than flux. Note how this was only applied to the flux in WeatherDataset
before.
Have you tested that this gives exactly the same tensors as before? (e.g. save the first batch to disk on main, check this out, save first batch and compare).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, the forcings are now handled differently. I suggest to implement a new logic to handle forcings in #54. The user can define combined_vars that share statistics and also define vars that should not be normalized.
I merged with the latest updates from main @leifdenby and commented on your suggestions @joeloskarsson. In the following I want to show that the output tensor was not affected by this change. As you suggested I stored the last batch for both gpu/cpu normalization with Then I compared the tensors like this: So except for the forcings tensor which are handled differently the two approaches create identical output. |
How do you think we should progress with this @sadamov ? If the forcing is handled differently in #54, would it make more sense to try to merge this after that? (I guess baking this change into #54 would just make it even bigger). Or should we merge this first so that #54 can build on it and be adapted to use it? I would not be happy to merge this without fixing so also the forcing tensors match the previous implementation. But we could just do a quick fix for now so the standardization is only applied to the flux dimensions of the forcing tensor? |
I propose to merge #54 first and leave this open until then. We now know that |
As we look at this after #66 we should think about more options for rescaling of the different variables. In #66 all variables (including state/forcing/static) are standardized. There are benefits to allowing for all three of:
for different variables. We do however need a way to specify what should be used for each variable, as well as computing the needed data statistics. Some of this might have to be done in mllam_data_prep, but at least the final computation should be in scope for this PR, as it is what the code has to do on the GPU. |
Summary
This PR introduces
on_after_batch_transfer
logic to normalize data on GPU instead of on CPU before transfer.Rationale
Normalization is faster on GPU than CPU. In the current code the data was normalized in the pytorch dataset class in the get_item method potentially slowing down training; especially on systems with fewer CPU cores.
Changes
on_after_batch_transfer
in thear_model.py
scriptcreate_parameter_weights.py
script to work with the new changes (not reloading standardized dataset)Testing
Both training and evaluation was successful. The training loss of 3.230 on the meps_example is identical to before the changes. The create_parameter_weights script was executed to successfully generate the stats.
Not-In-Scope
The normalization stats and other static features will all become zarr archives in the future. Their path defined in the data_config.yaml file.