
Commit: update
Signed-off-by: Roger Wang <[email protected]>
ywang96 committed Dec 30, 2024
1 parent 135fd5c commit 0a8dbe0
Showing 1 changed file with 6 additions and 3 deletions.
examples/offline_inference_vision_language.py: 9 changes (6 additions & 3 deletions)
@@ -28,6 +28,7 @@ def run_aria(question: str, modality: str):
               tokenizer_mode="slow",
               trust_remote_code=True,
               dtype="bfloat16",
+              max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
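For reference, max_num_seqs caps how many sequences the vLLM scheduler batches in a single engine step, which also bounds how many multimodal inputs must be resident at once. A minimal sketch of the Aria configuration after this change follows; the model name and the literal cache flag are assumptions standing in for the rest of the example script:

# Sketch only, not part of the diff: the Aria engine as configured after this
# change. The model name "rhymes-ai/Aria" and the literal False for the cache
# flag are assumptions; the example script takes them from its own arguments.
from vllm import LLM

llm = LLM(model="rhymes-ai/Aria",
          tokenizer_mode="slow",
          trust_remote_code=True,
          dtype="bfloat16",
          max_num_seqs=2,  # at most 2 sequences are scheduled per engine step
          disable_mm_preprocessor_cache=False)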
@@ -191,8 +192,10 @@ def run_llava_next(question: str, modality: str):
 
     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
-              max_model_len=8192,
-              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+              max_num_batched_tokens=32768,
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+              limit_mm_per_prompt={"image": 4},
+              enable_prefix_caching=False)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
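The new arguments change how the LLaVA-NeXT example schedules and accepts multimodal inputs: max_num_batched_tokens sets the per-step token budget of the scheduler, limit_mm_per_prompt={"image": 4} allows up to four images in a single prompt, and enable_prefix_caching=False turns off automatic prefix caching. A multi-image request against this configuration might look like the sketch below; the image paths, the PIL loading, and the two-image prompt are illustrative assumptions, not part of the example script:

# Sketch only, not part of the diff: a two-image request against the
# reconfigured LLaVA-NeXT engine. Image paths and prompt text are made up.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
          max_num_batched_tokens=32768,      # per-step token budget for the scheduler
          disable_mm_preprocessor_cache=False,
          limit_mm_per_prompt={"image": 4},  # accept up to 4 images per prompt
          enable_prefix_caching=False)       # automatic prefix caching off

images = [Image.open("a.jpg"), Image.open("b.jpg")]
prompt = "[INST] <image>\n<image>\nWhat differs between these two images? [/INST]"
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)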

@@ -591,7 +594,7 @@ def main(args):
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2,
+    sampling_params = SamplingParams(temperature=0,
                                      max_tokens=64,
                                      stop_token_ids=stop_token_ids)
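In vLLM, temperature=0 selects greedy decoding, so identical prompts in a batch now yield identical outputs. A minimal sketch of the resulting sampling setup; stop_token_ids is left as None here, mirroring the llava helper above, while other model helpers in the script may return their own ids:

# Sketch only, not part of the diff: the sampling configuration after this change.
from vllm import SamplingParams

stop_token_ids = None
sampling_params = SamplingParams(temperature=0,   # greedy decoding
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)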
