feat: llava next hf implementation #170

Open

long8v wants to merge 1 commit into main
Conversation


@long8v long8v commented Apr 23, 2024

(24.04.23)
LLaVA NeXT reading.


@long8v long8v left a comment


(24.04.23)

  • Training code for text-only data is not included
  • Code for multi-image training is not included

class LlavaNextConfig(PretrainedConfig):

Starting with a look at the config.

Comment on lines +64 to +71
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()

>>> # Initializing a Llama config
>>> text_config = LlamaConfig()

>>> # Initializing a Llava-Next llava-hf/llava-v1.6-mistral-7b-hf style configuration
>>> configuration = LlavaNextConfig(vision_config, text_config)

CLIPVisionConfig is passed as vision_config and LlamaConfig as text_config, i.e. a nested config structure.

Comment on lines +83 to +93
def __init__(
    self,
    vision_config=None,
    text_config=None,
    ignore_index=-100,
    image_token_index=32000,
    projector_hidden_act="gelu",
    vision_feature_select_strategy="default",
    vision_feature_layer=-2,
    image_grid_pinpoints=None,
    **kwargs,

The additional arguments accepted here are image_token_index, vision_feature_select_strategy, vision_feature_layer (which layer's features to use), and image_grid_pinpoints.
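A minimal sketch of setting these knobs explicitly (values are illustrative, not tuned; assumes a transformers version that exports LlavaNextConfig):

import torch  # not required for the config itself, shown only for context
from transformers import LlavaNextConfig

config = LlavaNextConfig(
    image_token_index=32000,                        # id of the <image> placeholder token
    vision_feature_select_strategy="default",
    vision_feature_layer=-2,                        # take features from the second-to-last vision layer
    image_grid_pinpoints=[[336, 672], [672, 336]],  # candidate resolutions (illustrative subset of the default list)
)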

Comment on lines +107 to +112
image_grid_pinpoints = (
    image_grid_pinpoints
    if image_grid_pinpoints is not None
    else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
)
self.image_grid_pinpoints = image_grid_pinpoints

If image_grid_pinpoints is not specified, the default grids are [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]].

Comment on lines +114 to +118
if isinstance(vision_config, dict):
    vision_config["model_type"] = (
        vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
    )
    vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)

When vision_config comes in as a dict, its model_type (defaulting to "clip_vision_model" if missing) is looked up in CONFIG_MAPPING, and the dict is unpacked into the resulting config class via **vision_config.
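A small sketch of that dict-to-config pattern (the dict values are made up for illustration; CONFIG_MAPPING is the model_type-to-config-class registry in transformers):

from transformers.models.auto.configuration_auto import CONFIG_MAPPING

# Illustrative dict; in the real config this arrives via the `vision_config` argument.
vision_config = {"model_type": "clip_vision_model", "hidden_size": 1024, "image_size": 336}

# model_type selects the config class (here CLIPVisionConfig), then the dict is unpacked into it.
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)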

Comment on lines +555 to +557
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    image_features, inputs_embeds, input_ids, attention_mask, labels
)

They factor out the step of splicing image_features into inputs_embeds into its own method.

Comment on lines +558 to +559
if labels is None:
    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)

When labels is None, it is handled as in the quoted lines: filled with ignore_index.


# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:

The elif branch is there for generation with a cache.
There is no code path for training on text-only data!

Comment on lines +594 to +603
outputs = self.language_model(
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

inputs_embeds is passed to language_model.

Comment on lines +607 to +633
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    if attention_mask is not None:
        shift_attention_mask = attention_mask[..., 1:]
        shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
        shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
    else:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
    )

if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output

return LlavaNextCausalLMOutputWithPast(
    loss=loss,
    logits=logits,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
)

When labels is not None, the loss is computed and returned with the output.
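To make the shift explicit, a tiny self-contained sketch of the same next-token loss (shapes and values are made up for illustration):

import torch
import torch.nn as nn

logits = torch.randn(2, 5, 32)                     # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 5))              # targets laid out like input_ids
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])   # right padding

# Logits at position t predict the token at position t+1; padded positions are dropped.
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0]
shift_labels = labels[..., 1:][shift_attention_mask != 0]

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, logits.size(-1)), shift_labels.view(-1))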

Comment on lines +355 to +429
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids

👀
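A toy walk-through of step 2 above, to see how each <image> token reserves room for its patch embeddings (hypothetical values: a single image token with id 32000 expanding into 4 patches):

import torch

image_token_index = 32000
num_image_patches = 4
input_ids = torch.tensor([[101, 32000, 205, 206]])   # "hey <image> how are" (made-up ids)

special_image_token_mask = input_ids == image_token_index
# Each image token pushes every later token right by (num_image_patches - 1) positions.
new_token_positions = torch.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, -1) - 1
print(new_token_positions)                            # tensor([[0, 4, 5, 6]])
# Text tokens are copied to positions 0, 5, 6 (the non-image entries); the image
# token's entry (4) marks where its last patch lands, so the patches fill slots 1..4.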


@long8v long8v left a comment


.

@long8v changed the title from "feat: llava next" to "feat: llava next (hf)" on May 6, 2024
@long8v changed the title from "feat: llava next (hf)" to "feat: llava next hf implementation" on May 6, 2024