Hello! Do I need to make any changes to the scripts before using the checkpoints for evaluation?
I downloaded the vmoe_b16_imagenet21k_randaug_strong_ft_cifar10 checkpoint files (both .index and .data-00000-of-00001) and renamed them ckpt_1.index and ckpt_1.data-00000-of-00001, respectively. Also, to run the script on a single partition, I changed "num_expert_partitions" in the config dict to 1.
With these changes, I tried to run the script on Google Colab using the following command (where saved_checkpoints is the directory containing the checkpoint files):
python vmoe/main.py --workdir=./vmoe/saved_checkpoints --config=vmoe/configs/vmoe_paper/vmoe_b16_imagenet21k_randaug_strong_ft_cifar10.py
I got the following error:
ValueError: Missing field step in state dict while restoring an instance of TrainState
Any help is appreciated!
Complete error stack:
Traceback (most recent call last):
  File "vmoe/main.py", line 77, in <module>
    app.run(main)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "vmoe/main.py", line 70, in main
    trainer.train_and_evaluate(config=FLAGS.config, workdir=FLAGS.workdir)
  File "./vmoe/train/trainer.py", line 527, in train_and_evaluate
    return _train_and_evaluate(config, workdir, mesh)
  File "./vmoe/train/trainer.py", line 575, in _train_and_evaluate
    thread_pool=ThreadPool())
  File "./vmoe/train/trainer.py", line 291, in restore_or_create_train_state
    thread_pool=thread_pool)
  File "./vmoe/checkpoints/partitioned.py", line 83, in restore_checkpoint
    'index': tree if tree is not None else axis_resources,
  File "./vmoe/checkpoints/base.py", line 137, in restore_checkpoint
    return serialization.from_bytes(tree, checkpoint_contents)
  File "./vmoe/checkpoints/serialization.py", line 71, in from_bytes
    return from_state_dict(target, state_dict)
  File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 65, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 128, in _restore_dict
    for key, value in xs.items()}
  File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 128, in <dictcomp>
    for key, value in xs.items()}
  File "/usr/local/lib/python3.7/dist-packages/flax/serialization.py", line 65, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/usr/local/lib/python3.7/dist-packages/flax/struct.py", line 146, in from_state_dict
    raise ValueError(f'Missing field {name} in state dict while restoring'
ValueError: Missing field step in state dict while restoring an instance of TrainState
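The traceback passes through restore_or_create_train_state, which suggests the trainer treats any ckpt_<step> files in --workdir as a checkpoint from its own earlier run and tries to resume from it. The sketch below is a simplified, hypothetical model of that discovery step. The regex, the helper names, and the rule that the .data shard must also exist are assumptions, not V-MoE's code. It illustrates why renaming the release files to ckpt_1.* makes the trainer attempt a full restore instead of starting fresh.

```python
import os
import re
import tempfile
from typing import Optional

# Hypothetical pattern for checkpoints written by a previous run: ckpt_<step>.index
_CKPT_INDEX_RE = re.compile(r"^ckpt_(\d+)\.index$")


def find_latest_checkpoint(workdir: str) -> Optional[str]:
    """Returns the prefix of the highest-step ckpt_<step> in workdir, or None."""
    steps = []
    for name in os.listdir(workdir):
        match = _CKPT_INDEX_RE.match(name)
        if not match:
            continue
        data_file = os.path.join(
            workdir, f"ckpt_{match.group(1)}.data-00000-of-00001")
        # Assumed rule: only count checkpoints whose data shard is also present.
        if os.path.exists(data_file):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(workdir, f"ckpt_{max(steps)}")


def restore_or_create(workdir: str) -> str:
    """Describes what a trainer would do for this workdir (no real restore)."""
    latest = find_latest_checkpoint(workdir)
    if latest is None:
        return "create"
    # A full restore expects TrainState fields such as `step`; a params-only
    # release checkpoint does not have them, hence the ValueError above.
    return f"restore {latest}"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as workdir:
        print(restore_or_create(workdir))  # empty workdir: "create"
        for suffix in (".index", ".data-00000-of-00001"):
            open(os.path.join(workdir, "ckpt_1" + suffix), "w").close()
        print(restore_or_create(workdir))  # renamed release files: "restore .../ckpt_1"
```

In other words, placing the release files in the workdir under the ckpt_<step> naming scheme is what routes them into the resume path.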
The checkpoints we released contain only the model parameters, not the full training state. You can load them with the function initialize_from_vmoe_release.
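To illustrate the point: restoring into a dataclass-style TrainState is strict, so a state dict that holds only parameters fails on the first missing field, here step. The sketch below uses a simplified stand-in TrainState and hypothetical helpers (state_from_dict, init_from_params_only); it is not Flax's or V-MoE's actual code. The second helper shows the general idea presumably behind initializing from a release: build a fresh state and swap in the released parameters.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class TrainState:
    """Simplified stand-in for a training state (not the real V-MoE class)."""
    step: int
    params: Dict[str, Any]
    opt_state: Dict[str, Any]


def state_from_dict(target_cls, state_dict):
    """Strict restore: every dataclass field must be present in state_dict."""
    kwargs = {}
    for field in dataclasses.fields(target_cls):
        if field.name not in state_dict:
            raise ValueError(
                f"Missing field {field.name} in state dict while restoring "
                f"an instance of {target_cls.__name__}")
        kwargs[field.name] = state_dict[field.name]
    return target_cls(**kwargs)


def init_from_params_only(fresh_state, released_params):
    """Keeps step and optimizer state from a fresh state; swaps in params."""
    return dataclasses.replace(fresh_state, params=released_params)


# A release checkpoint holds only model parameters (illustrative values).
released = {"params": {"head/kernel": [0.1, 0.2]}}

if __name__ == "__main__":
    try:
        state_from_dict(TrainState, released)
    except ValueError as err:
        print(err)  # fails on `step`, the first field of TrainState
    fresh = TrainState(step=0, params={}, opt_state={})
    state = init_from_params_only(fresh, released["params"])
    print(state.step, state.params)
```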
Thanks a lot, @jpuigcerver! I was able to load the checkpoint using the function you mentioned.
(I also had to change the "keep" key in the config dict of the corresponding checkpoint so that the entire checkpoint is loaded.)
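A plausible reading of why "keep" mattered: fine-tuning configs may list parameters (for example, the classification head) that should retain their fresh initialization instead of being loaded, because the new task has a different number of classes. Since this checkpoint is already fine-tuned on CIFAR-10, keeping nothing loads the whole checkpoint. The sketch below assumes "keep" is a list of regexes matched against flattened parameter names; the function name and exact matching semantics are hypothetical, not V-MoE's implementation.

```python
import re
from typing import Dict, Sequence


def merge_released_params(fresh: Dict[str, object],
                          released: Dict[str, object],
                          keep: Sequence[str]) -> Dict[str, object]:
    """Hypothetical merge: names fully matching any `keep` regex retain their
    fresh value; every other name is taken from the released checkpoint."""
    patterns = [re.compile(p) for p in keep]
    merged = {}
    for name, value in fresh.items():
        if name not in released or any(p.fullmatch(name) for p in patterns):
            merged[name] = value  # keep the fresh initialization
        else:
            merged[name] = released[name]  # load from the checkpoint
    return merged


if __name__ == "__main__":
    fresh = {"head/kernel": "fresh", "Encoder/block_0/kernel": "fresh"}
    released = {"head/kernel": "cifar10", "Encoder/block_0/kernel": "cifar10"}
    # Keeping the head: the CIFAR-10 head from the checkpoint is discarded.
    print(merge_released_params(fresh, released, keep=["head/.*"]))
    # Keeping nothing: the entire checkpoint is loaded.
    print(merge_released_params(fresh, released, keep=[]))
```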