Commit

Merge pull request #87 from ddlBoJack/dev-mzy
Fix a bug on modality mask
ddlBoJack authored May 21, 2024
2 parents ba10359 + 92f21a1 commit 386e768
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/slam_llm/models/slam_model.py
@@ -367,13 +367,14 @@ def forward(self,
                     inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
 
         if modality_mask is not None:
-            modality_unmask_start = (modality_mask == True).float().argmax(dim=1)
+            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
+            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
 
             encoder_outs_pad = torch.zeros_like(inputs_embeds)
             for i in range(encoder_outs.shape[0]):
                 encoder_outs_pad[
-                    i, modality_unmask_start[i]:modality_unmask_start[i]+modality_mask[i].sum().item()
-                ] = encoder_outs[i]
+                    i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
+                ] = encoder_outs[i][:modality_lengths[i]]
 
             inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
