AssertionError: ClassiferFreeGuidancePlugin requires embedding #71

Open
gg4u opened this issue Jun 18, 2023 · 0 comments
gg4u commented Jun 18, 2023

Hi, I tested the example you gave for conditioning on text, but got an error:

# Train model with audio waveforms
audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
loss = model(
    audio_wave,
    text=['The audio description'], # Text conditioning, one element per batch
    embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
)
loss.backward()

# Turn noise into new audio sample with diffusion
noise = torch.randn(1, 2, 2**18)
sample = model.sample(
    noise,
    text=['The audio description'],
    embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
    num_steps=2 # Higher for better quality, suggested num_steps: 10-100
)

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[49], line 3
      1 # Train model with audio waveforms
      2 audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
----> 3 loss = model(
      4     audio_wave,
      5     text=['The audio description'], # Text conditioning, one element per batch
      6     embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
      7 )
      8 loss.backward()
     10 # Turn noise into new audio sample with diffusion

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/models.py:40, in DiffusionModel.forward(self, *args, **kwargs)
     39 def forward(self, *args, **kwargs) -> Tensor:
---> 40     return self.diffusion(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/diffusion.py:93, in VDiffusion.forward(self, x, **kwargs)
     91 v_target = alphas * noise - betas * x
     92 # Predict velocity and return loss
---> 93 v_pred = self.net(x_noisy, sigmas, **kwargs)
     94 return F.mse_loss(v_pred, v_target)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:594, in TimeConditioningPlugin.<locals>.Net.<locals>.forward(x, time, features, **kwargs)
    592 # Merge time features with features if provided
    593 features = features + time_features if exists(features) else time_features
--> 594 return net(x, features=features, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:534, in ClassifierFreeGuidancePlugin.<locals>.Net.<locals>.forward(x, embedding, embedding_scale, embedding_mask_proba, **kwargs)
    526 def forward(
    527     x: Tensor,
    528     embedding: Optional[Tensor] = None,
   (...)
    531     **kwargs,
    532 ):
    533     msg = "ClassiferFreeGuidancePlugin requires embedding"
--> 534     assert exists(embedding), msg
    535     b, device = embedding.shape[0], embedding.device
    536     embedding_mask = fixed_embedding(embedding)

AssertionError: ClassiferFreeGuidancePlugin requires embedding
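To make sense of the assertion, I sketched the plugin's check in isolation (a simplified reproduction of the logic in `a_unet/blocks.py`; names and signatures are approximate, not the library's exact API):

```python
from typing import Optional


def exists(x) -> bool:
    return x is not None


def cfg_forward(x, embedding: Optional[object] = None, **kwargs):
    # Simplified version of ClassifierFreeGuidancePlugin.forward:
    # it insists on receiving a precomputed conditioning embedding.
    assert exists(embedding), "ClassiferFreeGuidancePlugin requires embedding"
    return x


# Passing raw text (as in my snippet) without a text embedder in the
# pipeline means `embedding` is never filled in, so the assert fires:
try:
    cfg_forward("noise")  # no embedding -> AssertionError
    raised = False
except AssertionError:
    raised = True
print(raised)
```

If I read this right, the `text=[...]` kwarg is only turned into an embedding when the model is built with text conditioning enabled, which makes me suspect my model construction rather than my dependencies.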

Is this about dependencies? Which dependencies am I supposed to install?

P.S. Could you please share two simple Colab examples:

  • to train on own wav files
  • use pretrained networks and finetune on own wav files

I am trying to understand how to condition on text to validate a research idea in bioacoustics, but I don't yet have a strong enough foundation to understand your code well, so a tutorial would be really helpful.
