Skip to content

Commit

Permalink
add prompt parameter, remove prints, add quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
kelleyl committed Jun 26, 2024
1 parent 939cf47 commit b15b28e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 87 deletions.
114 changes: 66 additions & 48 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
import logging
from typing import Union
import torch
from PIL import Image
import requests
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import numpy as np
import cv2
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig

from clams import ClamsApp, Restifier
from mmif import Mmif, View, Annotation, Document, AnnotationTypes, DocumentTypes

# For an NLP tool we need to import the LAPPS vocabulary items
from lapps.discriminators import Uri
from mmif import Mmif, View, Document, AnnotationTypes, DocumentTypes
from mmif.utils import video_document_helper as vdh

class InstructblipCaptioner(ClamsApp):

def __init__(self):
    """Load the 4-bit quantized InstructBLIP model and its processor.

    The model is loaded once, with a bitsandbytes NF4 quantization config
    and ``device_map="auto"``. ``self.device`` records where model inputs
    are moved to before generation.
    """
    # ``torch.float8`` is not a dtype torch provides; float16 is the
    # compute dtype used for the de-quantized matmuls.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    # Single load only: an earlier full-precision ``from_pretrained`` call
    # would be overwritten immediately and just waste time and memory.
    self.model = InstructBlipForConditionalGeneration.from_pretrained(
        "Salesforce/instructblip-vicuna-7b",
        quantization_config=quantization_config,
        device_map="auto",
    )
    self.processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
    # Inputs are sent to the GPU when one is available, otherwise the CPU.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): no explicit ``self.model.to(self.device)`` here.
    # ``device_map="auto"`` already places the weights, and moving a
    # bitsandbytes-quantized model with ``.to()`` is rejected by some
    # transformers versions -- confirm against the pinned version.
    super().__init__()

def _appmetadata(self):
Expand All @@ -30,39 +30,61 @@ def _appmetadata(self):
# When using the ``metadata.py`` leave this do-nothing "pass" method here.
pass

def get_prompt(self, label: str, prompt_map: dict, default_prompt: str) -> Union[str, None]:
    """Build the instruction-formatted prompt for a timeframe label.

    The label is looked up in ``prompt_map``; labels without an entry fall
    back to ``default_prompt``.

    :param label: label of the timeframe being captioned
    :param prompt_map: mapping from timeframe labels to prompt texts;
        ``None`` or an empty mapping means "no per-label overrides"
    :param default_prompt: prompt text for labels not in ``prompt_map``
    :return: the prompt wrapped as ``[INST] <image>\\n...\\n[/INST]``, or
        ``None`` when the resolved prompt is the skip marker ``-`` or no
        prompt is configured at all
    """
    # A missing map must not crash the lookup.
    prompt = (prompt_map or {}).get(label, default_prompt)
    # ``-`` is the explicit skip marker; ``None`` means nothing was
    # configured. Formatting ``None`` would send the literal text "None".
    if prompt is None or prompt == "-":
        return None
    return f"[INST] <image>\n{prompt}\n[/INST]"

def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
print ("called annotate")
label_map = parameters.get('promptMap')
default_prompt = parameters.get('defaultPrompt')
frame_interval = parameters.get('frameInterval', 10) # Default to every 10th frame if not specified

video_doc: Document = mmif.get_documents_by_type(DocumentTypes.VideoDocument)[0]
input_view: View = mmif.get_views_for_document(video_doc.properties.id)[0]

new_view: View = mmif.new_view()
self.sign_view(new_view, parameters)
print("Starting annotation process for timeframes.")
for timeframe in input_view.get_annotations(AnnotationTypes.TimeFrame):
try:
print(f"Processing timeframe: {timeframe.id}")
if "representatives" in timeframe.properties and timeframe.properties["representatives"]:
representative_id = timeframe.get("representatives")[0]
print(f"Found representative: {representative_id}")
representative: AnnotationTypes.TimePoint = input_view.get_annotation_by_id(representative_id)
frame_index = vdh.convert(representative.get("timePoint"), "milliseconds",
"frame", vdh.get_framerate(video_doc))
print(f"Frame index for representative: {frame_index}")

timeframes = input_view.get_annotations(AnnotationTypes.TimeFrame)

if timeframes:
for timeframe in timeframes:
label = timeframe.get_property('label')
prompt = self.get_prompt(label, label_map, default_prompt)
if not prompt:
continue

representatives = timeframe.get("representatives") if "representatives" in timeframe.properties else None
if representatives:
image = vdh.extract_representative_frame(mmif, timeframe)
else:
start_frame = timeframe.get_property("start")
end_frame= timeframe.get_property("end")
if end_frame - start_frame < 30:
continue
print(f"Calculating frame index from start {start_frame}ms and end {end_frame}ms")
frame_index = (start_frame + end_frame) // 2
print(f"Frame index calculated: {frame_index}")

image: Image.Image = vdh.extract_frames_as_images(video_doc, [frame_index], as_PIL=True)[0]

prompt = "Describe this frame from a television program."
print(f"Using prompt: '{prompt}'")
image = vdh.extract_mid_frame(mmif, timeframe)

inputs = self.processor(images=image, text=prompt, return_tensors="pt")
outputs = self.model.generate(
**inputs,
do_sample=False,
num_beams=5,
max_length=256,
min_length=1,
repetition_penalty=1.5,
length_penalty=1.0,
temperature=1,
)
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

text_document = new_view.new_textdocument(generated_text)
alignment = new_view.new_annotation(AnnotationTypes.Alignment)
alignment.add_property("source", timeframe.id)
alignment.add_property("target", text_document.id)
else:
total_frames = vdh.get_frame_count(video_doc)
for frame_number in range(0, total_frames, frame_interval):
image = vdh.extract_frames_as_images(video_doc, [frame_number], as_PIL=True)[0]
prompt = default_prompt
inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
print("Inputs prepared for the model.")
outputs = self.model.generate(
**inputs,
do_sample=False,
Expand All @@ -74,20 +96,16 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
temperature=1,
)
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print (generated_text)

text_document = new_view.new_textdocument(generated_text)
self.create_alignment(new_view, timeframe.id, text_document.id)
timepoint = new_view.new_annotation(AnnotationTypes.TimePoint)
timepoint.add_property("timePoint", frame_number)
alignment = new_view.new_annotation(AnnotationTypes.Alignment)
alignment.add_property("source", timepoint.id)
alignment.add_property("target", text_document.id)

except Exception as e:
self.logger.error(f"Error processing timeframe: {e}")
continue
return mmif

def create_alignment(self, view, source_id, target_id):
    """Add an Alignment annotation to ``view`` linking two annotation ids.

    :param view: view on which ``new_annotation`` is called
    :param source_id: id stored as the alignment's ``source`` property
    :param target_id: id stored as the alignment's ``target`` property
    """
    link = view.new_annotation(AnnotationTypes.Alignment)
    props = link.properties
    props.source = source_id
    props.target = target_id

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", action="store", default="5000", help="set port to listen")
Expand Down
21 changes: 16 additions & 5 deletions metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
DO NOT CHANGE the name of the file
"""

from mmif import DocumentTypes
from mmif import DocumentTypes, AnnotationTypes

from clams.app import ClamsApp
from clams.appmetadata import AppMetadata
Expand Down Expand Up @@ -33,14 +33,25 @@ def appmetadata() -> AppMetadata:
)
# and then add I/O specifications: an app must have at least one input and one output
metadata.add_input(DocumentTypes.VideoDocument)
metadata.add_input(AnnotationTypes.TimeFrame)
metadata.add_output(AnnotationTypes.Alignment)
metadata.add_output(DocumentTypes.TextDocument)

# (optional) and finally add runtime parameter specifications
# metadata.add_parameter(name='a_param', description='example parameter description',
# type='boolean', default='false')
# metadata.add_parameter(more...)
metadata.add_parameter(
name='defaultPrompt', type='string', default='What is shown in this video frame?',
description='default prompt to use for timeframes not specified in the promptMap. If set to `-`, '
'timeframes not specified in the promptMap will be skipped.'
)
metadata.add_parameter(
name='promptMap', type='map', default=[],
description=('mapping of labels of the input timeframe annotations to new prompts. Must be formatted as '
'\"IN_LABEL:PROMPT\" (with a colon). To pass multiple mappings, use this parameter multiple '
'times. By default, any timeframe labels not mapped to a prompt will be used with the default'
'prompt. In order to skip timeframes with a particular label, pass `-` as the prompt value.'
'in order to skip all timeframes not specified in the promptMap, set the defaultPrompt'
'parameter to `-`'))

# CHANGE this line and make sure to return the compiled `metadata` instance
return metadata


Expand Down
37 changes: 3 additions & 34 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,39 +1,8 @@
accelerate
aniso8601
attrs
bitsandbytes
blinker
certifi
charset-normalizer
clams-python==1.1.3
click
deepdiff
clams-python==1.2.4
ffmpeg-python
filelock
Flask
Flask-RESTful
fsspec
future
gunicorn
huggingface-hub
lxml
MarkupSafe
mmif-python==1.0.10
mpmath
numpy
mmif-python==1.0.16
opencv-python
ordered-set
packaging
pillow
platformdirs
pydantic
pytz
PyYAML
referencing
regex
safetensors
sympy
tokenizers
torch
tqdm
transformers
Pillow

0 comments on commit b15b28e

Please sign in to comment.