Skip to content

Commit

Permalink
add vila
Browse files Browse the repository at this point in the history
(cherry picked from commit 701d12e570c08f45a91b2700fd10dd349b5f683a)
  • Loading branch information
ZhangYuanhan-AI authored and Luodian committed Jul 9, 2024
1 parent d270223 commit e6844db
Show file tree
Hide file tree
Showing 5 changed files with 417 additions and 11 deletions.
1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"llava_onevision": "Llava_OneVision",
"llava_hf": "LlavaHf",
"longva": "LongVA",
"vila": "VILA",
}

for model_name, model_class in AVAILABLE_MODELS.items():
Expand Down
16 changes: 11 additions & 5 deletions lmms_eval/models/llava_vid.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(
mm_spatial_pool_mode: str = "average",
overwrite: bool = True,
video_decode_backend: str = "pyav",
delay_load: bool = False,
tie_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -86,16 +88,19 @@ def __init__(
self.mm_spatial_pool_out_channels = int(mm_spatial_pool_out_channels)
self.mm_spatial_pool_mode = mm_spatial_pool_mode
self.max_frames_num = int(max_frames_num)
print(self.max_frames_num)
self.mm_resampler_location = mm_resampler_location
self.delay_load = delay_load
if self.overwrite == True:
overwrite_config = {}
overwrite_config["mm_resampler_type"] = self.mm_resampler_type
overwrite_config["mm_spatial_pool_stride"] = self.mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = self.mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
overwrite_config["mm_resampler_location"] = "before"
overwrite_config["patchify_video_feature"] = False
overwrite_config["attn_implementation"] = attn_implementation
overwrite_config["mm_pooling_position"] = self.mm_resampler_location
overwrite_config["mm_newline_position"] = mm_newline_position
overwrite_config["add_faster_video"] = False
overwrite_config["delay_load"] = self.delay_load
# overwrite_config["attn_implementation"] = attn_implementation

cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

Expand Down Expand Up @@ -146,7 +151,8 @@ def __init__(

self._config = self._model.config
self.model.eval()
self.model.tie_weights()
if tie_weights:
self.model.tie_weights()
self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
self.conv_template = conv_template
Expand Down
Loading

0 comments on commit e6844db

Please sign in to comment.