Weird bug for BroadcastPositionBiases #4

Open
alfaevc opened this issue Jan 29, 2024 · 0 comments

alfaevc commented Jan 29, 2024

I seem to hit a dimensionality issue when I run train_videogpt.py:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 272, in <module>
    main()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 79, in main
    visualize(sampler, ae, iteration, state, test_loader)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/scripts/train_videogpt.py", line 225, in visualize
    samples = sampler(variables, batch).copy()
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 90, in __call__
    _, cache = self._model_step(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/sampler.py", line 38, in fn
    logits, cache = self.model.apply(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/videogpt.py", line 38, in __call__
    return self.model(
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 31, in __call__
    position_bias = BroadcastPositionBiases(shape=self.shape)(x)
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 188, in __call__
    embs = [
  File "/gpfs/data/oermannlab/users/qp2040/viper_rl/viper_rl/videogpt/models/transformer.py", line 189, in <listcomp>
    self.param(f'd_{i}', nn.initializers.normal(stddev=0.02),
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (16, 85) but got shape (8, 85) instead for parameter "d_1" in "/model/BroadcastPositionBiases_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
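
For anyone trying to narrow this down: the error reports two shapes for "d_1" that agree on the second axis (85) and disagree on the first (16 vs 8). Below is a minimal sketch of a shape comparison in plain Python. It assumes the parameter shapes from the two sides (for example, the variables passed to the sampler and the shapes the model requests at call time) have already been collected into plain dicts. The helper name `find_shape_mismatches` and the dict layout are hypothetical and not part of viper_rl or Flax. Only the "d_1" shapes come from the error message; which side each shape belongs to is not established here.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    first: Dict[str, Shape], second: Dict[str, Shape]
) -> Dict[str, Tuple[Shape, Shape]]:
    """Return {name: (first_shape, second_shape)} for parameters present in
    both dicts whose shapes differ. Hypothetical helper for illustration."""
    mismatches = {}
    for name, second_shape in second.items():
        first_shape = first.get(name)
        if first_shape is not None and tuple(first_shape) != tuple(second_shape):
            mismatches[name] = (tuple(first_shape), tuple(second_shape))
    return mismatches


# The two shapes quoted in the error message for parameter "d_1".
# Assigning them to "first" and "second" is arbitrary in this sketch.
shapes_first = {"d_1": (16, 85)}
shapes_second = {"d_1": (8, 85)}

mismatches = find_shape_mismatches(shapes_first, shapes_second)
for name, (a, b) in mismatches.items():
    # List the axes on which the two shapes disagree.
    differing_axes = [i for i, (m, n) in enumerate(zip(a, b)) if m != n]
    print(f"{name}: {a} vs {b}, differs on axes {differing_axes}")
```

Running a comparison like this over every parameter under "/model/BroadcastPositionBiases_0" would show whether only "d_1" is affected or whether the other d_{i} entries disagree as well.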
