feat: llava next hf implementation #170
base: main
Conversation
(24.04.23)
- Training code for text-only data is not included
- Code for multi-image training is not included
class LlavaNextConfig(PretrainedConfig):
Starting with the Config.
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()

>>> # Initializing a Llama config
>>> text_config = LlamaConfig()

>>> # Initializing a Llava-Next llava-hf/llava-v1.6-mistral-7b-hf style configuration
>>> configuration = LlavaNextConfig(vision_config, text_config)
The CLIPVisionConfig is passed as vision_config and the LlamaConfig as text_config, so the config is nested.
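A quick way to see the nesting (my own illustration, assuming the sub-configs are stored as vision_config / text_config attributes, as in the upstream Llava config):

>>> configuration = LlavaNextConfig(vision_config=CLIPVisionConfig(), text_config=LlamaConfig())
>>> configuration.vision_config.image_size   # CLIP vision settings live under vision_config
>>> configuration.text_config.hidden_size    # Llama settings live under text_config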
def __init__(
    self,
    vision_config=None,
    text_config=None,
    ignore_index=-100,
    image_token_index=32000,
    projector_hidden_act="gelu",
    vision_feature_select_strategy="default",
    vision_feature_layer=-2,
    image_grid_pinpoints=None,
    **kwargs,
The extra arguments it accepts are image_token_index, vision_feature_select_strategy, vision_feature_layer (which layer's features to use), and image_grid_pinpoints.
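For reference, a hedged sketch (not code from this PR) of passing these extra arguments explicitly; the names follow the __init__ signature above and the values are only illustrative:

config = LlavaNextConfig(
    vision_config=CLIPVisionConfig(),
    text_config=LlamaConfig(),
    image_token_index=32000,                   # vocab id of the <image> placeholder token
    vision_feature_select_strategy="default",  # how features are selected from the vision tower output
    vision_feature_layer=-2,                   # which hidden layer of the vision tower to take features from
    image_grid_pinpoints=[[336, 672], [672, 336], [672, 672]],  # candidate AnyRes resolutions
)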
image_grid_pinpoints = (
    image_grid_pinpoints
    if image_grid_pinpoints is not None
    else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
)
self.image_grid_pinpoints = image_grid_pinpoints
If image_grid_pinpoints is not specified, the default grids are [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]].
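Roughly how these pinpoints get used (my own sketch, not code from this PR; the real selection lives in the image processor): each candidate (height, width) is scored by how much of the original image it preserves after an aspect-preserving resize, and the best candidate becomes the AnyRes target resolution.

def pick_grid(orig_h, orig_w, pinpoints):
    best, best_effective, best_waste = None, 0, float("inf")
    for h, w in pinpoints:
        scale = min(w / orig_w, h / orig_h)          # aspect-preserving fit into the candidate grid
        effective = min(int(orig_w * scale) * int(orig_h * scale), orig_w * orig_h)
        wasted = h * w - effective                   # area of the candidate grid left uncovered
        if effective > best_effective or (effective == best_effective and wasted < best_waste):
            best, best_effective, best_waste = (h, w), effective, wasted
    return best

pick_grid(800, 600, [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]])  # -> (672, 672)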
if isinstance(vision_config, dict):
    vision_config["model_type"] = (
        vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
    )
    vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
When vision_config is given as a dict, its model_type (defaulting to clip_vision_model) is used to look up the config class in CONFIG_MAPPING, and the dict is unpacked into that class via **vision_config.
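To illustrate the pattern (my own example; CONFIG_MAPPING is the transformers registry from model_type string to config class, and the field values here are just placeholders):

from transformers import CONFIG_MAPPING

vision_config = {"model_type": "clip_vision_model", "hidden_size": 1024, "image_size": 336}
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
print(type(vision_config).__name__)  # CLIPVisionConfig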
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    image_features, inputs_embeds, input_ids, attention_mask, labels
)
The step that splices image_features into inputs_embeds is factored out into a separate function.
if labels is None:
    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
When labels is None, it is handled as shown here.
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
The elif branch handles generation with the cache.
There is no code here for training on text-only data!
outputs = self.language_model(
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
The merged inputs_embeds are fed into language_model.
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    if attention_mask is not None:
        shift_attention_mask = attention_mask[..., 1:]
        shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
        shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
    else:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
    )

if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output

return LlavaNextCausalLMOutputWithPast(
    loss=loss,
    logits=logits,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
)
When labels is not None, the loss is computed and returned with the output.
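A toy illustration of the shift above (my own example, not PR code): logits at position t are scored against the label at position t + 1, so tokens before n predict token n.

import torch
import torch.nn as nn

logits = torch.randn(1, 4, 10)         # (batch, seq_len, vocab_size)
labels = torch.tensor([[5, 2, 7, 1]])  # (batch, seq_len)

shift_logits = logits[..., :-1, :].contiguous()  # predictions for positions 0..2
shift_labels = labels[..., 1:].contiguous()      # targets are the next tokens, positions 1..3
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, 10), shift_labels.view(-1))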
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
👀
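To convince myself of the position arithmetic in step 2, a tiny example (mine, not PR code) with one image token that expands to 4 patches:

import torch

input_ids = torch.tensor([[101, 32000, 102, 103]])  # ["hey", "<image>", "how", "are"], image_token_index = 32000
num_image_patches = 4                                # pretend each <image> expands to 4 patch embeddings

special_image_token_mask = input_ids == 32000
# each image token pushes later tokens forward by (num_image_patches - 1) slots
new_token_positions = torch.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, -1) - 1
print(new_token_positions)  # tensor([[0, 4, 5, 6]]) -> text goes to 0, 5, 6; slots 1..4 are left for image features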
(24.04.23)
LLaVA NeXT reading.