
some issues with data loading and training #34

Open
RKS-labs opened this issue Jan 7, 2025 · 1 comment
RKS-labs commented Jan 7, 2025

Here are some problems I encountered while exploring this project:

  • It turns out that argparse can't parse manually passed boolean values for --use_wandb. See e.g. https://stackoverflow.com/questions/60999816/argparse-not-parsing-boolean-arguments. I also hope the logging can be improved when wandb is disabled, since some servers have no internet access. The logging looks good in the v1.0-lrs-gems branch, but the use_wandb parameter is missing from config.yaml.
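The argparse pitfall can be reproduced in isolation. Here is a minimal sketch (the flag name mirrors the issue; str2bool is a hypothetical helper, not mlff code). On Python >= 3.9, argparse.BooleanOptionalAction is an alternative that gives paired --use-wandb / --no-use-wandb flags.

```python
import argparse

def str2bool(v):
    # argparse's type=bool treats any non-empty string (even "False") as
    # truthy, so "--use_wandb False" silently enables wandb.
    # Parse the string explicitly instead.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--use_wandb", type=str2bool, default=True)

args = parser.parse_args(["--use_wandb", "False"])
print(args.use_wandb)  # False; with type=bool this would have been True
```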

  • Also, when specifying a train_file and a valid_file with different maximum atom counts, an error is raised when the two datasets are concatenated: each is padded to its own maximum atom count, and the resulting shapes do not match. --n_valid 1 also seems to raise an error, likely due to incorrect tensor shapes. I'm also wondering whether I can turn off validation/testing entirely and do it by hand with ASE or similar, to verify what is being calculated: just save checkpoints every n-th epoch and don't let valid_ds prematurely discard them.
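The shape mismatch described above can be illustrated with a small numpy sketch (pad_positions and the arrays are hypothetical, not mlff's actual loader): padding each split to its own maximum atom count produces incompatible shapes, while padding both to the joint maximum concatenates cleanly.

```python
import numpy as np

def pad_positions(positions_list, n_max):
    """Zero-pad each (n_atoms, 3) array to (n_max, 3) and stack."""
    out = np.zeros((len(positions_list), n_max, 3))
    for i, pos in enumerate(positions_list):
        out[i, : len(pos)] = pos
    return out

train = [np.random.rand(n, 3) for n in (10, 12)]  # max 12 atoms
valid = [np.random.rand(n, 3) for n in (15,)]     # max 15 atoms

# Padding each split to its own maximum gives (2, 12, 3) and (1, 15, 3),
# which np.concatenate rejects. Padding both to the joint maximum works:
n_max = max(max(len(p) for p in train), max(len(p) for p in valid))
both = np.concatenate([pad_positions(train, n_max), pad_positions(valid, n_max)])
print(both.shape)  # (3, 15, 3)
```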

  • When the batch size is very large, the error message is confusing, with no reference to CUDA OOM or similar. (It cost me some time digging into the code to realize it was a batch-size problem, since the default batch size scales with the training-set size.)

Traceback (most recent call last):
  File "/miniconda3/envs/so3krates/bin/train_so3krates", line 33, in <module>
    sys.exit(load_entry_point('mlff', 'console_scripts', 'train_so3krates')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mlff/mlff/cAPI/mlff_train_so3krates.py", line 482, in train_so3krates
    coach.run(train_state=train_state,
  File "/mlff/mlff/training/coach.py", line 60, in run
    run_training(state=train_state,
  File "/mlff/mlff/training/run.py", line 194, in run_training
    best_valid_metrics, _ = valid_epoch(state, valid_ds, metric_fn, bs=valid_bs)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mlff/mlff/training/run.py", line 86, in valid_epoch
    epoch_metrics_np = {k: np.mean([metrics[k] for metrics in batch_metrics_np]) for k in batch_metrics_np[0]}
                                                                                          ~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range
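Reading run.py line 86 in the traceback, batch_metrics_np[0] raises IndexError whenever the list is empty, i.e. when no validation batch was produced at all. A hedged sketch of a guard that would surface the root cause (average_metrics is a hypothetical stand-in, not mlff's actual function):

```python
import numpy as np

def average_metrics(batch_metrics_np):
    # run.py indexes batch_metrics_np[0], which raises a bare IndexError
    # when no validation batch was produced (e.g. the validation batch
    # size exceeds the number of validation points). Failing with an
    # explicit message makes the root cause obvious.
    if not batch_metrics_np:
        raise ValueError(
            "No validation batches were produced; check that the "
            "validation batch size does not exceed n_valid."
        )
    return {k: np.mean([m[k] for m in batch_metrics_np])
            for k in batch_metrics_np[0]}

print(average_metrics([{"loss": 1.0}, {"loss": 3.0}]))  # {'loss': 2.0}
```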

  • Training using structures with varying atom counts doesn't seem to work. I'm using a home-made dataset of monometallic structures, and I only obtain reasonable results when trimming the dataset down to only $\text{M}_n$ clusters with a fixed $n$, e.g. $n=55$. If I add more structures with different $n$, the evaluated energies deviate drastically, and MD just explodes after a few steps.

  • Update: okay, I found the v1.0-lrs-gems branch. The calls to jraph pass arguments such as n_pairs which are not in the function signature (I'm on jraph 0.0.6.dev0). Simply removing them seems fine, and the sparse model works now (at least for non-periodic structures)! For a completely local model, though, I'm seeing roughly half the speed of the default model with similar hyperparameters.
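The "simply removing them" workaround can be automated in a version-agnostic way. A generic sketch (not mlff's actual code; segment_sum here is a toy stand-in for a jraph utility) that drops keyword arguments the callee does not accept:

```python
import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    """Drop kwargs the callee does not accept, e.g. an n_pairs argument
    passed to an older jraph release whose functions lack that parameter."""
    accepted = inspect.signature(fn).parameters
    # If the callee takes **kwargs, pass everything through unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()):
        return fn(*args, **kwargs)
    return fn(*args, **{k: v for k, v in kwargs.items() if k in accepted})

def segment_sum(data, segment_ids):  # toy stand-in, ignores segment_ids
    return sum(data)

# n_pairs is silently dropped because segment_sum does not accept it:
print(call_with_supported_kwargs(segment_sum, [1, 2, 3],
                                 segment_ids=None, n_pairs=4))  # 6
```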

  • Update: apart from the neighbor-list rebuild issue ("Impossible to make prediction on dataset with different size molecules using the ASE calculator", general-molecular-simulations/so3lr#9), the mlffCalculatorSparse ASE calculator throws errors, referring to glp, for cells whose atoms.cell.cellpar()[:3] are smaller than r_cut. For a simple bcc/hcp metal I have to use 5x5x5 supercells (and still get qualitatively incorrect results; is PBC not supported for training?):

[glp]           ** warning: total cutoff 4.50 does not fit within simulation cell (needs to be < 3.90)               
[glp]           ** this will yield incorrect results!               
[glp]           ** suggestion: N is 44, cl.id is 210 consider using different buffer_size_multiplier, so that cl.id is slightly bigger than N. Now it is 1)               
[glp]           ** warning: total cutoff 4.50 does not fit within simulation cell (needs to be < 3.90)               
[glp]           ** this will yield incorrect results!               
[glp]           ** suggestion: N is 44, cl.id is 210 consider using different buffer_size_multiplier, so that cl.id is slightly bigger than N. Now it is 1)   
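The glp warning above is geometric: the total cutoff must fit within the perpendicular widths of the simulation cell (glp may additionally apply a safety factor, which would explain the exact "< 3.90" limit). A pure-numpy sketch of that geometry check, independent of glp and ASE:

```python
import numpy as np

def perpendicular_widths(cell):
    """Perpendicular widths of a 3x3 cell matrix (rows = lattice vectors).

    For a neighbor list under the minimum-image convention, the total
    cutoff (model cutoff plus any buffer) must fit inside these widths;
    otherwise a code like glp warns that results will be incorrect.
    Width along direction i is volume / area of the face spanned by the
    other two lattice vectors.
    """
    cell = np.asarray(cell, dtype=float)
    volume = abs(np.linalg.det(cell))
    widths = []
    for i in range(3):
        j, k = (i + 1) % 3, (i + 2) % 3
        face_area = np.linalg.norm(np.cross(cell[j], cell[k]))
        widths.append(volume / face_area)
    return np.array(widths)

# A 3x3x3 cubic supercell with a = 3.16 Å has widths of ~9.48 Å,
# comfortably above a 4.5 Å cutoff; a single unit cell would not fit it.
print(perpendicular_widths(np.eye(3) * 3 * 3.16))  # [9.48 9.48 9.48]
```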

I'm sorry for not patching these through PRs, as I'm quite busy right now. I don't have a model for production runs yet, but the speed is impressive. Fantastic work!

@RKS-labs changed the title from "--use_wandb False doesn't work; and other issues with datasets and training" to "some issues with data loading and training" on Jan 8, 2025
@thorben-frank (Owner)

Hi, thank you for digging into this and collecting the different issues. As you have already figured out yourself, there is a new version on the way in v1.0 or v1.0-lrs-gems (the latter is by now merged into v1.0), but I have not fully completed it yet, and there are no proper docs at the moment. I hope to complete this within the next few weeks.

Thanks, you are right; I will add use_wandb to the config.

So are you using train_so3krates_sparse at the moment?

How come you assume PBC is not supported by training? I think the warning you describe comes from the mlffCalculatorSparse, which is used for the ASE calculator; it is not used during training. But maybe I misunderstood?

How do you create the mlffCalculatorSparse? What does the config.yaml you use during training look like?

Regarding jraph and the n_pairs argument: this is because the new version supports long-range indices in addition to the short-range indices (senders and receivers in jraph). n_pairs gives the number of global indices.

Best,
Thorben
