Skip to content

Commit

Permalink
update context sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 30, 2024
1 parent c9f759e commit 9979b03
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
2 changes: 1 addition & 1 deletion lmms_eval/api/instance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Tuple
from typing import Literal, Tuple, Iterable, Callable


@dataclass
Expand Down
91 changes: 70 additions & 21 deletions lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,65 @@
import datasets
from typing import Callable


class LazyLoadedImages:
    """Placeholder for the visuals of one dataset row, resolved on demand.

    Only the dataset reference and a row index are stored; nothing is read
    from the dataset until `get_images` is called.
    """

    def __init__(self, data_frame, index):
        """Remember which dataset (`data_frame`) and which row (`index`) to use."""
        self.index = index
        self.data_frame: datasets.Dataset = data_frame

    def get_images(self, doc_to_visual):
        """Fetch the stored row and return `doc_to_visual` applied to it."""
        row = self.data_frame[self.index]
        return doc_to_visual(row)


class Context(object):
    """Ordered list of prompt pieces (visual placeholders, text, delimiters).

    Pieces are accumulated in `self.contexts` in the order they are added.
    Visuals are stored as `LazyLoadedImages` placeholders rather than being
    loaded when the example is added.
    """

    def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str = "\n"):
        """Bind the task's doc_to_* callables and start with an empty context.

        Args:
            task: object exposing `_config`, `doc_to_visual`, `doc_to_text`,
                `doc_to_target` and `doc_to_choice`.
            few_shot_delimiter: piece appended after each in-context example.
            target_delimiter: piece appended between a question and its target.
        """
        self.task = task
        self.config = task._config

        self.doc_to_visual = self.task.doc_to_visual
        self.doc_to_text = self.task.doc_to_text
        self.doc_to_target = self.task.doc_to_target
        self.doc_to_choice = self.task.doc_to_choice

        self.target_delimiter = target_delimiter
        self.few_shot_delimiter = few_shot_delimiter

        self.contexts = []

    def _has_choices(self):
        """Return True when both the task and its config provide `doc_to_choice`.

        `get_question` and `get_target` share this check so they cannot
        disagree about whether a non-string value is an index into the choices.
        """
        if self.doc_to_choice is None:
            return False
        return getattr(self.config, "doc_to_choice", None) is not None

    @staticmethod
    def _lazy_visual(data_frame, index):
        """Build a lazy visual placeholder, or return None if either argument is None.

        The check is `is not None` on purpose: index 0 is a valid row and must
        not be treated as "no index".
        """
        if data_frame is None or index is None:
            return None
        return LazyLoadedImages(data_frame, index)

    def get_question(self, doc, model_specific_prompt_kwargs=None):
        """Return the question for `doc`.

        `doc_to_text` is called once, with `model_specific_prompt_kwargs`.
        A string result, or any result when no choices are configured, is
        returned unchanged; otherwise the result is used as an index into
        `doc_to_choice(doc)`.
        """
        text = self.doc_to_text(doc, model_specific_prompt_kwargs)
        if not self._has_choices() or isinstance(text, str):
            return text
        return self.doc_to_choice(doc)[text]

    def get_target(self, doc):
        """Return the target for `doc`.

        A list target yields `str()` of its first element. A string target,
        or any target when no choices are configured, is returned unchanged.
        Otherwise the target is used as an index into `doc_to_choice(doc)` and
        the selected choice is returned as a string.
        """
        target = self.doc_to_target(doc)
        if isinstance(target, list):
            return str(target[0])
        if not self._has_choices() or isinstance(target, str):
            return target
        return str(self.doc_to_choice(doc)[target])

    def add_in_context_example(self, doc, model_specific_prompt_kwargs=None, data_frame=None, index=None):
        """Append one solved example: [visual], question, delimiter, target, delimiter.

        The visual placeholder is added only when both `data_frame` and
        `index` are given (not None).
        """
        question = self.get_question(doc, model_specific_prompt_kwargs)
        visual = self._lazy_visual(data_frame, index)
        # NOTE(review): the raw `doc_to_target` output is appended here, not
        # `get_target(doc)`, so list / choice-index targets are not converted
        # to text. Kept as-is to avoid changing what callers receive -- confirm
        # whether `get_target` was intended.
        target = self.doc_to_target(doc)
        if visual is not None:
            self.contexts.append(visual)
        self.contexts.append(question)
        self.contexts.append(self.target_delimiter)
        self.contexts.append(target)
        self.contexts.append(self.few_shot_delimiter)

    def add_question(self, doc, model_specific_prompt_kwargs=None, data_frame=None, index=None):
        """Append the unanswered question: [visual], question, target delimiter.

        The question is the direct `doc_to_text` output; the visual placeholder
        is added only when both `data_frame` and `index` are given (not None).
        """
        question = self.doc_to_text(doc, model_specific_prompt_kwargs)
        visual = self._lazy_visual(data_frame, index)
        if visual is not None:
            self.contexts.append(visual)
        self.contexts.append(question)
        self.contexts.append(self.target_delimiter)


class FewShotDataset(object):
Expand All @@ -22,7 +83,7 @@ def get_dataset(self) -> datasets.Dataset:

def sample(self, n, rnd):
indices = rnd.sample(range(len(self.get_dataset())), n)
return self.get_dataset().select(indices)
return indices, self.get_dataset().select(indices)

def __getitem__(self, item):
return self.get_dataset()[item]
Expand All @@ -47,33 +108,21 @@ def __init__(self, docs: FewShotDataset, task, fewshot_indices=None, rnd=None) -
if fewshot_indices: # subset few-shot docs from
self.docs.fewshot_indices = fewshot_indices

def get_context(self, doc, num_fewshot):
def get_context(self, doc, num_fewshot, model_specific_prompt_kwargs=None):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = num_fewshot + 1 if self.docs.same_as_eval else num_fewshot

# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
indices, fewshotex = self.sample(n_samples)

# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = (
self.fewshot_delimiter.join(
[
# TODO: is separating doc_to_text and doc_to_target by one space always desired?
(self.doc_to_text(doc) if (self.config.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)])
+ self.target_delimiter
+ (
str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list
else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
)
for doc in selected_docs
]
)
+ self.fewshot_delimiter
)
selected_docs = [(idx, x) for idx, x in zip(indices, fewshotex) if x != doc][:num_fewshot]

labeled_examples = Context(self.task, self.fewshot_delimiter, self.target_delimiter)

for idx, doc in selected_docs:
labeled_examples.add_in_context_example(doc, model_specific_prompt_kwargs, self.docs, idx)

return labeled_examples

Expand Down
8 changes: 8 additions & 0 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ def _collate(x):
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
############### for debugging ###################
# TODO: remove this block
if len(visuals) > 1:
for i in range(len(visuals)):
path = f"./logs/llava/{i}.png"
visuals[i].save(path)
pass
#################################################
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies = [
"tiktoken",
"pre-commit",
"pydantic",
"antlr4-python3-runtime==4.11",
]

[tool.setuptools.packages.find]
Expand Down

0 comments on commit 9979b03

Please sign in to comment.