This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

fix bloom
Signed-off-by: changwangss <[email protected]>
changwangss committed Jun 26, 2024
1 parent 8c0242e commit ca4aa2e
Showing 3 changed files with 3 additions and 12 deletions.
File 1 of 3
@@ -38,8 +38,7 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
-        if self.config.model_type == "bloom":
-            return self._reorder_cache_bloom(past_key_values, beam_idx)
 
         if self.config.model_type == "chatglm":
             return tuple(
                 tuple(
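
For context, a minimal standalone sketch (not this repository's code; the function name and tensor layout are assumptions) of what cache reordering does for models whose past_key_values follow the standard layout of per-layer (key, value) tensors with the batch dimension first. Older bloom checkpoints kept their cache with the batch and head dimensions fused, which is presumably why the dedicated _reorder_cache_bloom helper existed in the first place.

import torch

def reorder_standard_cache(past_key_values, beam_idx):
    # Illustrative sketch only: for every layer, select the cache rows that
    # belong to the beams chosen at this generation step.  Assumes each tensor
    # is laid out as (batch, num_heads, seq_len, head_dim).
    return tuple(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer)
        for layer in past_key_values
    )

# Toy usage: 2 layers, 4 beams, swap beams 0 and 1.
cache = tuple((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)) for _ in range(2))
reordered = reorder_standard_cache(cache, torch.tensor([1, 0, 2, 3]))
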
File 2 of 3
@@ -946,11 +946,7 @@ def collate_batch(batch):
         )
 
         last_ind.append(input_ids.shape[0] - 1)
-        if model_type in ["bloom"]:
-            attention_mask = torch.ones(len(input_ids) + 1)
-            attention_mask[0] = 0
-        else:
-            attention_mask = torch.ones(len(input_ids))
+        attention_mask = torch.ones(len(input_ids))
         position_ids = torch.arange(len(input_ids))
         input_ids_padded.append(input_ids)
         attention_mask_padded.append(attention_mask)
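
As a rough illustration of the unified branch above (the helper and variable names below are made up for the example, not the repository's API): every model type now gets an all-ones attention mask with exactly one entry per prompt token, instead of the bloom-only mask that was one element longer and started with a zero.

import torch

def collate_sample(input_ids):
    # Illustrative sketch: per-sample tensors built the way the simplified
    # branch does -- the mask length equals the number of input tokens.
    attention_mask = torch.ones(len(input_ids))
    position_ids = torch.arange(len(input_ids))
    last_ind = input_ids.shape[0] - 1  # index of the last real token
    return attention_mask, position_ids, last_ind

ids = torch.tensor([101, 2023, 2003, 1037, 3231, 102])
mask, pos, last = collate_sample(ids)
print(mask.shape, pos.shape, last)  # torch.Size([6]) torch.Size([6]) 5
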
File 3 of 3
@@ -375,11 +375,7 @@ def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4):
     past_key_values = generate_dummy_past_key_values(config=model_config, input_bs=batch_size)
 
     input_ids = input_ids[:, :512]
-    if model_type in ["bloom", "qwen"]:
-        attention_mask = torch.ones(input_ids.shape[0], input_ids.shape[1] + 1)
-        attention_mask[:,0] = 0
-    else:
-        attention_mask = torch.ones(input_ids.shape)
+    attention_mask = torch.ones(input_ids.shape)
     position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1)
 
     if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
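
Along the same lines, a hypothetical sketch of how batched dummy inputs can be assembled once the bloom/qwen special case is gone (the helper below is illustrative, not get_example_inputs itself): the attention mask simply mirrors the (batch, seq_len) shape of input_ids, and position ids are tiled across the batch.

import torch

def build_dummy_inputs(input_ids, max_len=512):
    # Illustrative sketch of the unified path: truncate the prompt, then give
    # the attention mask exactly the same (batch, seq_len) shape as input_ids.
    input_ids = input_ids[:, :max_len]
    batch_size, seq_len = input_ids.shape
    attention_mask = torch.ones(input_ids.shape)
    position_ids = torch.arange(seq_len).repeat(batch_size, 1)
    return input_ids, attention_mask, position_ids

dummy = torch.randint(0, 1000, (1, 8))
ids, mask, pos = build_dummy_inputs(dummy)
print(ids.shape, mask.shape, pos.shape)  # all torch.Size([1, 8])
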
