a logic error in _preprocess function of Qwen2VLImageProcessor Class #37064

Open
InsaneGe opened this issue Mar 28, 2025 · 1 comment

@InsaneGe

System Info

In the `_preprocess` function of the `Qwen2VLImageProcessor` class, the temporal padding step is written as follows:

```python
if patches.shape[0] % temporal_patch_size != 0:
    repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - 1, axis=0)
    patches = np.concatenate([patches, repeats], axis=0)
grid_t = patches.shape[0] // temporal_patch_size
```
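To see when this rule breaks, the padded frame count can be computed directly. The helpers below are hypothetical and model only the frame counts, not the pixel data. `leftover_frames_current` reports how many frames remain after `grid_t * temporal_patch_size` frames are grouped into temporal patches; what the rest of `_preprocess` does with such leftovers is not shown in the excerpt.

```python
def padded_len_current(num_frames: int, temporal_patch_size: int) -> int:
    """Frame count after the current rule, which always appends t - 1 copies."""
    if num_frames % temporal_patch_size != 0:
        return num_frames + (temporal_patch_size - 1)
    return num_frames


def leftover_frames_current(num_frames: int, temporal_patch_size: int) -> int:
    """Frames that do not fill a complete temporal patch after padding."""
    padded = padded_len_current(num_frames, temporal_patch_size)
    grid_t = padded // temporal_patch_size  # floor division, as in the excerpt
    return padded - grid_t * temporal_patch_size


# List every (temporal_patch_size, num_frames) pair the current rule fails on.
for t in (2, 3, 4):
    for n in range(1, 9):
        leftover = leftover_frames_current(n, t)
        if leftover:
            print(f"temporal_patch_size={t}, num_frames={n}: "
                  f"padded to {padded_len_current(n, t)}, {leftover} frame(s) left over")
```

Nothing is printed for `temporal_patch_size=2`; for 3 and 4, every listed case has a remainder `num_frames % temporal_patch_size` greater than 1.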

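It should repeat the last frame `temporal_patch_size - (patches.shape[0] % temporal_patch_size)` times instead of `temporal_patch_size - 1` times, so that `patches.shape[0]` is always divisible by `temporal_patch_size`. Appending `temporal_patch_size - 1` frames only works when `patches.shape[0] % temporal_patch_size == 1`; for example, with `temporal_patch_size = 3` and 5 frames, it pads to 7 frames, which is still not a multiple of 3.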

```python
if patches.shape[0] % temporal_patch_size != 0:
    repeats = np.repeat(patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0)
    patches = np.concatenate([patches, repeats], axis=0)
```
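A minimal sketch of the proposed fix as a standalone function, assuming frames are stacked along axis 0 as in the excerpt. `pad_to_temporal_multiple` is a hypothetical name, not part of the transformers API, and the dummy frame shape is arbitrary.

```python
import numpy as np


def pad_to_temporal_multiple(patches: np.ndarray, temporal_patch_size: int) -> np.ndarray:
    """Repeat the last frame until the frame count (axis 0) is a multiple of
    temporal_patch_size. Hypothetical helper mirroring the proposed fix."""
    remainder = patches.shape[0] % temporal_patch_size
    if remainder != 0:
        num_missing = temporal_patch_size - remainder
        # patches[-1][np.newaxis] has shape (1, ...); repeat it num_missing times along axis 0
        repeats = np.repeat(patches[-1][np.newaxis], num_missing, axis=0)
        patches = np.concatenate([patches, repeats], axis=0)
    return patches


# Usage: 5 dummy frames of shape (C, H, W) = (3, 4, 4) with temporal_patch_size = 3
frames = np.arange(5 * 3 * 4 * 4, dtype=np.float32).reshape(5, 3, 4, 4)
padded = pad_to_temporal_multiple(frames, temporal_patch_size=3)
grid_t = padded.shape[0] // 3
print(padded.shape, grid_t)  # (6, 3, 4, 4) 2
```

Storing the remainder in a local variable also avoids evaluating `patches.shape[0] % temporal_patch_size` twice.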

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

nothing

Expected behavior

nothing

InsaneGe added the bug label Mar 28, 2025
@zucchini-nlp
Member

@InsaneGe hey, I agree it is not very generalizable; it seems the previous PR was merged to fix only certain cases of `num_frames`. Feel free to make a PR :)
