Token selection for qwen2-vl and qwen2.5-vl #36

Closed · wants to merge 5 commits
30 changes: 29 additions & 1 deletion src/arguments.py
@@ -52,7 +52,35 @@ class ModelArguments:
        default=16,
        metadata={"help": "number of crops used in image encoder"}
    )

    uigraph_train: bool = field(
        default=True,
        metadata={"help": "Enable UI graph construction during training"}
    )
    uigraph_test: bool = field(
        default=False,
        metadata={"help": "Enable UI graph construction during inference"}
    )
    uigraph_diff: int = field(
        default=1,
        metadata={"help": "Pixel-difference threshold used for constructing the UI graph"}
    )
    uigraph_rand: bool = field(
        default=False,
        metadata={"help": "Enable random graph construction (for ablation studies)"}
    )
    uimask_pre: bool = field(
        default=True,
        metadata={"help": "Prebuild the patch selection mask in the preprocessor (not in model layers) for efficiency"}
    )
    uimask_ratio: float = field(
        default=0.5,
        metadata={"help": "Fraction of patch tokens to skip per component"}
    )
    uimask_rand: bool = field(
        default=False,
        metadata={"help": "Enable random token selection instead of uniform selection"}
    )


@dataclass
class DataArguments:
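These fields parse like any other `HfArgumentParser` flag. A minimal sketch, assuming the remaining `ModelArguments` fields all have defaults; the flag values shown are illustrative:

```python
# Minimal sketch: the new UI-graph fields parse like any other HfArgumentParser flag.
# Assumes every other ModelArguments field has a default.
from transformers import HfArgumentParser
from src.arguments import ModelArguments

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--uigraph_train", "True", "--uigraph_diff", "1", "--uimask_ratio", "0.5"]
)
print(model_args.uigraph_train, model_args.uigraph_diff, model_args.uimask_ratio)
# True 1 0.5
```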
11 changes: 9 additions & 2 deletions src/model_utils.py
@@ -75,15 +75,22 @@ def load_processor(model_args):
        processor = Qwen2VLProcessor.from_pretrained(
            model_name,
            image_processor=image_processor, tokenizer=tokenizer,
            min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
            min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28,
            uigraph_train=model_args.uigraph_train, uigraph_test=model_args.uigraph_test,
            uigraph_diff=model_args.uigraph_diff, uigraph_rand=model_args.uigraph_rand,
            uimask_pre=model_args.uimask_pre, uimask_ratio=model_args.uimask_ratio, uimask_rand=model_args.uimask_rand
        )
    elif model_args.model_backbone == QWEN2_5_VL:
        from src.vlm_backbone.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
        from src.vlm_backbone.qwen2_5_vl.image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor
        from src.vlm_backbone.qwen2_vl.tokenization_qwen2_fast import Qwen2TokenizerFast
        image_processor = Qwen2_5_VLImageProcessor.from_pretrained(model_name)
        tokenizer = Qwen2TokenizerFast.from_pretrained(model_name)
        processor = Qwen2_5_VLProcessor.from_pretrained(model_name, image_processor=image_processor, tokenizer=tokenizer)
        processor = Qwen2_5_VLProcessor.from_pretrained(model_name, image_processor=image_processor, tokenizer=tokenizer,
            uigraph_train=model_args.uigraph_train, uigraph_test=model_args.uigraph_test,
            uigraph_diff=model_args.uigraph_diff, uigraph_rand=model_args.uigraph_rand,
            uimask_pre=model_args.uimask_pre, uimask_ratio=model_args.uimask_ratio, uimask_rand=model_args.uimask_rand
        )
    elif model_args.model_backbone == INTERN_VL:
        from src.vlm_backbone.intern_vl.tokenization_internlm2_fast import InternLM2TokenizerFast
        tokenizer = InternLM2TokenizerFast.from_pretrained(model_name)
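For orientation, a minimal sketch of what this wiring amounts to when constructing the patched processor directly; the `qwen2_vl` import path and the checkpoint name are assumptions by analogy with the `qwen2_5_vl` imports shown above:

```python
# Sketch: constructing the patched Qwen2-VL processor with UI-graph options by hand.
# The import path and checkpoint name below are illustrative assumptions.
from src.vlm_backbone.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor

processor = Qwen2VLProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28,
    uigraph_train=True, uigraph_test=False,  # build the UI graph during training only
    uigraph_diff=1, uigraph_rand=False,      # pixel-difference threshold; no random ablation
    uimask_pre=True,                         # prebuild the selection mask in the preprocessor
    uimask_ratio=0.5, uimask_rand=False,     # skip half the patch tokens per component, uniformly
)
```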
45 changes: 44 additions & 1 deletion src/utils.py
@@ -1,8 +1,51 @@
import torch

import numpy as np
from src.logging import get_logger
logger = get_logger(__name__)

# Union-Find (disjoint-set) structure used to group patches into UI components
class UnionFind:
    def __init__(self, size):
        self.parent = np.arange(size)

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        px = self.find(x)
        py = self.find(y)
        if px != py:
            self.parent[py] = px

def get_select_mask(tensor, skip_ratio=0, rand=False):
    # Use tensor operations for efficiency
    retain_mask = (tensor == -1).clone()
    unique_vals, counts = torch.unique(tensor, return_counts=True)

    for val, count in zip(unique_vals, counts):
        if val == -1:
            continue
        positions = (tensor == val).nonzero(as_tuple=True)[0]
        num_positions = len(positions)

        if num_positions == 1:
            retain_mask[positions] = True
        else:
            num_to_skip = int(round(num_positions * skip_ratio))
            num_to_retain = max(1, num_positions - num_to_skip)
            if rand:
                # Randomly select the retained subset (used for the layer-wise selection ablation)
                perm = torch.randperm(num_positions, device=tensor.device)
                positions_to_retain = positions[perm[:num_to_retain]]
            else:
                # Uniformly spaced selection across the component
                indices = torch.linspace(0, num_positions - 1, steps=num_to_retain).long()
                positions_to_retain = positions[indices]

            retain_mask[positions_to_retain] = True
    return retain_mask

def print_rank(message):
    """If distributed is initialized, print the rank."""
    if torch.distributed.is_initialized():
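A toy run of `get_select_mask` under the conventions above (`-1` marks always-retained tokens; other values are component ids). The input tensor is illustrative:

```python
# Toy demo of get_select_mask: tokens labeled -1 are always retained;
# within each component, roughly (1 - skip_ratio) of tokens are kept.
import torch
from src.utils import get_select_mask

# Component assignment for 8 tokens: two components (0 and 1) plus two -1 tokens.
assign = torch.tensor([-1, 0, 0, 0, 0, 1, 1, -1])
mask = get_select_mask(assign, skip_ratio=0.5, rand=False)
print(mask)
# tensor([ True,  True, False, False,  True,  True, False,  True])
# Component 0 keeps 2 of its 4 tokens at uniformly spaced positions,
# component 1 keeps 1 of 2, and both -1 tokens stay True.
```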
171 changes: 165 additions & 6 deletions src/vlm_backbone/qwen2_5_vl/image_processing_qwen2_5_vl.py
@@ -26,7 +26,10 @@
import math
from typing import Dict, List, Optional, Union

import PIL
import numpy as np
from sklearn.preprocessing import LabelEncoder
from skimage.segmentation import mark_boundaries

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
@@ -48,7 +51,7 @@
    validate_preprocess_arguments,
)
from transformers.utils import TensorType, is_vision_available, logging

from ...utils import UnionFind

if is_vision_available():
    from PIL import Image
@@ -200,6 +203,79 @@ def __init__(
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb

    def rerank_values(self, arr):
        # Remap labels to consecutive integers starting at 0, in order of first appearance
        mapping = {}
        new_arr = np.empty_like(arr)
        next_value = 0

        for idx, x in enumerate(arr):
            if x not in mapping:
                mapping[x] = next_value
                next_value += 1
            new_arr[idx] = mapping[x]
        return new_arr

    def _build_uigraph(self, patches,
                       grid_t, grid_h, grid_w,
                       grid_h_half, grid_w_half,
                       uigraph_threshold,
                       channel):
        num_patches = grid_t * grid_h_half * grid_w_half
        uf = UnionFind(num_patches)

        def idx(t, i, j):
            return t * grid_h_half * grid_w_half + i * grid_w_half + j

        # Merge adjacent patches whose difference falls below the threshold
        for t in range(grid_t):
            for i in range(grid_h_half):
                for j in range(grid_w_half):
                    current_idx = idx(t, i, j)
                    current_patch = patches[t, i, j]  # Shape: (merge_size, merge_size, channel, temporal_patch_size, patch_size, patch_size)

                    # Compare with the right neighbor
                    if j + 1 < grid_w_half:
                        right_patch = patches[t, i, j + 1]
                        # Compute the difference between the patches
                        diff = np.linalg.norm(current_patch - right_patch)
                        if diff < uigraph_threshold:
                            uf.union(current_idx, idx(t, i, j + 1))

                    # Compare with the bottom neighbor
                    if i + 1 < grid_h_half:
                        bottom_patch = patches[t, i + 1, j]
                        # Compute the difference between the patches
                        diff = np.linalg.norm(current_patch - bottom_patch)
                        if diff < uigraph_threshold:
                            uf.union(current_idx, idx(t, i + 1, j))

        uigraph_assign_flat = np.array([uf.find(x) for x in range(num_patches)])
        le = LabelEncoder()
        uigraph_assign_flat = le.fit_transform(uigraph_assign_flat)
        uigraph_assign = uigraph_assign_flat.reshape((grid_t, grid_h_half, grid_w_half))
        return uigraph_assign

    def _vis_uigraph(self, uigraph_assign, image_size, patch_size, image):
        resized_height, resized_width = image_size[0]
        uigraph_assign = uigraph_assign[0]

        upscaled_uigraph_assign = np.repeat(np.repeat(uigraph_assign, patch_size, axis=0), patch_size, axis=1)
        upscaled_uigraph_assign = upscaled_uigraph_assign[:resized_height, :resized_width]

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if image.shape[0] in [1, 3]:  # channels-first grayscale or RGB
            image = image.transpose(1, 2, 0)
        elif image.shape[2] in [1, 3]:  # already channels-last
            pass
        else:
            raise ValueError("Unexpected image shape: {}".format(image.shape))

        boundaries_image = mark_boundaries(image, upscaled_uigraph_assign, color=(1, 0.4, 0.4))
        boundaries_image = (boundaries_image * 255).astype(np.uint8)
        return Image.fromarray(boundaries_image)

    def _preprocess(
        self,
@@ -214,6 +290,9 @@ def _preprocess(
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        uigraph_use: bool = False,
        uigraph_diff: float = 0.0,
        uigraph_rand: bool = False,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
@@ -249,6 +328,13 @@
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            uigraph_use (`bool`, *optional*, defaults to `False`):
                Whether to build the UI graph.
            uigraph_diff (`float`, *optional*, defaults to `0.0`):
                If building the graph, the patch-wise difference threshold.
                A larger threshold results in sparser components, while a smaller threshold leads to denser components.
            uigraph_rand (`bool`, *optional*, defaults to `False`):
                If building the graph, whether to build it randomly for ablation studies.
        """
        images = make_list_of_images(images)

@@ -270,6 +356,7 @@
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        processed_resize = []  # resized (height, width) per image, kept for visualization
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
@@ -293,6 +380,7 @@

            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            processed_images.append(image)
            processed_resize.append((resized_height, resized_width))

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
@@ -303,6 +391,12 @@
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size

        # Default assignment: each merged patch starts as its own component (identity UI graph)
        grid_h_half = grid_h // self.merge_size
        grid_w_half = grid_w // self.merge_size
        uigraph_assign = np.arange(grid_t * grid_h_half * grid_w_half).reshape((grid_t, grid_h_half, grid_w_half))

        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
@@ -315,11 +409,20 @@
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)

        # ShowUI's UI graph construction
        if uigraph_use:
            uigraph_assign = self._build_uigraph(patches=patches,
                                                 grid_t=grid_t, grid_h=grid_h, grid_w=grid_w,
                                                 grid_h_half=grid_h_half, grid_w_half=grid_w_half,
                                                 uigraph_threshold=uigraph_diff,
                                                 channel=channel)

        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )

        return flatten_patches, (grid_t, grid_h, grid_w)
        return flatten_patches, (grid_t, grid_h, grid_w), uigraph_assign, processed_resize

    def preprocess(
        self,
@@ -337,6 +440,10 @@
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        uigraph_use: bool = False,
        uigraph_diff: float = 0.0,
        uigraph_rand: bool = False,
        vis_dir: str = None,
    ):
        """
        Args:
@@ -385,7 +492,15 @@
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            uigraph_use (`bool`, *optional*, defaults to `False`):
                Whether to build the UI graph.
            uigraph_diff (`float`, *optional*, defaults to `0.0`):
                If building the graph, the patch-wise difference threshold.
                A larger threshold results in sparser components, while a smaller threshold leads to denser components.
            uigraph_rand (`bool`, *optional*, defaults to `False`):
                If building the graph, whether to build it randomly for ablation studies.
            vis_dir (`str`, *optional*, defaults to `None`):
                If building the graph, the directory where the UI graph visualization is saved.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
@@ -420,8 +535,13 @@

        if images is not None:
            pixel_values, vision_grid_thws = [], []

            patch_assign_sep = []     # patch-wise component assignment, stored separately per UI graph
            patch_assign_len = []     # number of components per UI graph
            patch_assign_shared = []  # patch-wise assignment with component ids shared across images

            for image in images:
                patches, image_grid_thw = self._preprocess(
                patches, image_grid_thw, uigraph_assign, image_resize = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
@@ -433,17 +553,56 @@
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                    uigraph_use=uigraph_use,
                    uigraph_diff=uigraph_diff,
                    uigraph_rand=uigraph_rand,
                )


                # UI graph post-processing
                if uigraph_use:
                    # Random-graph ablation: replace the assignment with random component ids
                    if uigraph_rand:
                        C = len(np.unique(uigraph_assign))
                        _, H, W = uigraph_assign.shape
                        uigraph_assign = np.random.randint(0, C + 1, size=(1, H, W))

                    # Flatten the 2D graph to 1D and rerank the component ids
                    uigraph_assign_1d = uigraph_assign.flatten()
                    uigraph_assign_1d = self.rerank_values(uigraph_assign_1d)
                    uigraph_assign_len = len(np.unique(uigraph_assign_1d))

                    uigraph_assign_1d += sum(patch_assign_len)  # offset shared component ids to distinguish different images
                    patch_assign_shared.extend(uigraph_assign_1d)
                    patch_assign_sep.extend(uigraph_assign_1d)
                    patch_assign_len.append(uigraph_assign_len)

                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)

                if vis_dir is not None:
                    image_vis = self._vis_uigraph(uigraph_assign, image_resize, self.patch_size * self.merge_size, image)
                    # pre_num = np.prod(uigraph_assign.shape).item()
                    # post_num = len(np.unique(uigraph_assign))
                    # img_size = f'{image_resize[0][0]}x{image_resize[0][1]}'
                    # image_vis.save(f'{vis_dir}/{img_size}_{pre_num}_{post_num}.png')
                    image_vis.save(f'{vis_dir}/demo.png')

            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
            patch_assign_shared = np.array(patch_assign_shared)

            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws,
                    "patch_assign": patch_assign_shared,
                    "patch_assign_sep": patch_assign_sep,
                    "patch_assign_len": patch_assign_len
                    }

        if videos is not None:
            pixel_values, vision_grid_thws = [], []
            for images in videos:
                patches, video_grid_thw = self._preprocess(
                # The UI graph does not support video yet
                patches, video_grid_thw, _, _ = self._preprocess(
                    images,
                    do_resize=do_resize,
                    resample=resample,
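To close the loop, a minimal sketch of invoking the patched image processor end to end with the UI graph enabled. It assumes `preprocess` wraps `data` in a `BatchFeature` as the upstream processor does; the checkpoint name and image path are illustrative:

```python
# Sketch: running the patched Qwen2.5-VL image processor with the UI graph enabled.
# Checkpoint name and image path are illustrative assumptions.
from PIL import Image
from src.vlm_backbone.qwen2_5_vl.image_processing_qwen2_5_vl import Qwen2_5_VLImageProcessor

image_processor = Qwen2_5_VLImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
out = image_processor.preprocess(
    images=Image.open("screenshot.png"),
    uigraph_use=True,   # build the UI graph from adjacent-patch differences
    uigraph_diff=1.0,   # patch-wise difference threshold
    vis_dir=None,       # set to a directory path to save the boundary visualization
    return_tensors="np",
)
print(out["image_grid_thw"])      # (grid_t, grid_h, grid_w) per image
print(out["patch_assign"].shape)  # one component id per merged patch
print(out["patch_assign_len"])    # number of components per image
```

With `uimask_ratio` applied downstream, `patch_assign` is what `get_select_mask` in `src/utils.py` consumes to decide which patch tokens to keep per component.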