Commit d0e9be9

not working state
NathanHB committed Feb 1, 2024
1 parent 0cf83ce commit d0e9be9
Showing 4 changed files with 38 additions and 21 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -82,8 +82,8 @@ optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
 adapters = ["peft==0.3.0"]
 nanotron = [
-  "nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e",
-  "brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b",
+  "nanotron@git+https://github.com/huggingface/nanotron@main",
+  "brrr@git+https://github.com/huggingface/brrr@fix-lighteval",
   "tensorboardX"
 ]
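
Note (not part of the commit): main and fix-lighteval are moving branch refs, so this swap trades the reproducibility of the previous SHA pins for tracking the in-flight fix. A quick way to confirm which builds an environment actually resolved, assuming standard pip metadata:

from importlib.metadata import PackageNotFoundError, version

# Print the installed versions of the two git dependencies, if present.
for pkg in ("nanotron", "brrr"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")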

52 changes: 34 additions & 18 deletions src/lighteval/models/brrr_models.py
@@ -28,9 +28,16 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer, BatchEncoding

-from lighteval.data import GenDataset, GenDistributedSampler, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
+from lighteval.data import (
+    GenDistributedSampler,
+    GenerativeTaskDataset,
+    LoglikelihoodDataset,
+    LoglikelihoodSingleTokenDataset,
+)
 from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
-from lighteval.utils import as_list, find_executable_batch_size
+from lighteval.tasks.requests import GreedyUntilRequest
+from lighteval.utils import as_list
+from lighteval.utils_parallelism import find_executable_batch_size


 # from .brrr_generation import GenerationConfig, GenerationInputs, SamplerType, greedy_search_tokenized
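
For orientation, a minimal sketch of the request objects this file now consumes, inferred from the attribute accesses in the hunks below (req.context, req.stop_sequence, req.generation_size); the canonical class lives in lighteval.tasks.requests and may carry more fields.

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class GreedyUntilRequest:  # sketch only; see lighteval.tasks.requests for the real definition
    context: str                     # prompt text to encode
    stop_sequence: Union[str, list]  # ending condition(s) for generation
    generation_size: Optional[int]   # maximum number of new tokens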
@@ -1166,7 +1173,7 @@ def _loglikelihood_tokens(
     @torch.inference_mode()
     def greedy_until(
         self,
-        requests: List[Tuple[str, dict]],
+        requests: List[GreedyUntilRequest],
         task_names: Optional[List[str]] = None,
         returns_logits=False,
         disable_tqdm: bool = False,
@@ -1178,15 +1185,24 @@ def greedy_until(
         # pull longest context sample from request
         if task_names:
             enc_inputs = [
-                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), task_name)
+                (
+                    self.tok_encode(req.context),
+                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
+                    task_name,
+                )
                 for req, task_name in zip(requests, task_names)
             ]
         else:
             enc_inputs = [
-                (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), None) for req in requests
+                (
+                    self.tok_encode(req.context),
+                    self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
+                    None,
+                )
+                for req in requests
             ]

-        dataset = GenDataset(requests=enc_inputs)
+        dataset = GenerativeTaskDataset(requests=enc_inputs, dataset_splits=dataset_splits)
         res = []

         # Dataset is sorted in descending size.
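
To make the new tuple layout concrete, here is the shape of one enc_inputs entry with invented values; the middle element stands in for the return of homogeneize_ending_conditions((req.stop_sequence, req.generation_size)), whose exact format this diff does not show:

# Illustration only -- the token ids and the ending-conditions format are assumptions.
entry = (
    [312, 489, 17],  # self.tok_encode(req.context)
    (["\n"], 16),    # homogeneized ending conditions: (stop sequences, max generation size)
    "my_task",       # task name, or None in the else branch
)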
@@ -1195,20 +1211,20 @@ def greedy_until(

         total_length, subset_length = self._get_subsets(dataset, dataset_splits)

-        for s, subset_start in enumerate(
+        for s, _ in enumerate(
             tqdm(
-                range(0, total_length, subset_length),
-                disable=disable_tqdm,
-                position=0,
+                dataset.splits_start_end_iterator(),
+                total=dataset_splits,
                 desc=f"greedy -- Node {dist.get_rank(self.parallel_context.world_pg)}",
+                position=0,
+                disable=disable_tqdm,
             )
         ):
-            dataset.split_start = subset_start
-            dataset.split_end = min(subset_start + subset_length, total_length)
-
-            _, (context_enc, _, _) = dataset[0]
-            max_gen = max(d[1][1][1] for d in dataset)
-            max_input_length = min(len(context_enc) + max_gen, self.max_length)
+            # print(dataset[0])
+            (context_enc, _, _) = dataset[0]
+            # max_gen = max(d[1][1][1] for d in dataset)
+            # max_input_length = min(len(context_enc) + max_gen, self.max_length)
+            max_input_length = len(context_enc)
             batch_size = self._get_batch_size(
                 override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
             )
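
The split bookkeeping the loop body used to do (split_start/split_end plus the subset arithmetic) now appears to live inside the dataset. A plausible sketch of splits_start_end_iterator under that assumption, inferred from this diff rather than taken from lighteval.data:

def splits_start_end_iterator(self):
    # Sketch: walk the sorted dataset in dataset_splits windows, update the
    # split bounds as a side effect, and yield them for the tqdm loop above.
    subset_length = max(self.total_size // self.dataset_splits, 1)
    for start in range(0, self.total_size, subset_length):
        self.split_start = start
        self.split_end = min(start + subset_length, self.total_size)
        yield self.split_start, self.split_end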
@@ -1243,7 +1259,7 @@ def greedy_until(
                 rank=0,
             )
             iteration_start_time = time.time()
-            example_index, batch_data = zip(*all_batch)
+            batch_data = zip(*all_batch)
             context = [c[0] for c in batch_data]
             task_names = [c[2] for c in batch_data]
             # we take the longest asked generation in the batch
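
One consequence worth flagging, consistent with the commit title: zip(*all_batch) now binds the column iterator itself instead of unpacking it, so the first comprehension walks columns (yielding the first row's fields) and the second finds the iterator already exhausted. A minimal reproduction, assuming all_batch holds (context, ending_conditions, task_name) tuples:

all_batch = [("ctx_a", ("\n", 8), "task_x"),
             ("ctx_b", ("\n", 8), "task_y")]

batch_data = zip(*all_batch)             # iterator over columns, consumable once
context = [c[0] for c in batch_data]     # ['ctx_a', ('\n', 8), 'task_x'] -- the first row, not the contexts
task_names = [c[2] for c in batch_data]  # [] -- the iterator is already exhausted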
@@ -1304,7 +1320,7 @@ def greedy_until(
             generations = batch_generations.numpy(force=True)
             input_ids = batch_input_ids.numpy(force=True)

-            batch_example_index = torch.tensor(example_index, device=self.device)
+            batch_example_index = torch.tensor(0, device=self.device)
             batch_example_index = self.gather(batch_example_index)
             batch_truncated = torch.tensor(batch_model.truncated, device=self.device)
             batch_truncated = self.gather(batch_truncated)
2 changes: 1 addition & 1 deletion src/lighteval/models/model_config.py
@@ -146,7 +146,7 @@ class TGIModelConfig:
     inference_server_auth: str


-def create_model_config(args, accelerator: Accelerator):  # noqa C901
+def create_model_config(args, accelerator: "Accelerator"):  # noqa C901
     # Tests
     if args.inference_server_address is not None and args.model_args is not None:
         raise ValueError("You cannot both use an inference server and load a model from its checkpoint.")
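
Quoting the annotation makes it a forward reference, so the module no longer needs the Accelerator name resolvable at import time (useful when accelerate is an optional dependency). The usual companion pattern, shown as a general-practice sketch rather than a claim about this file's imports:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from accelerate import Accelerator

def create_model_config(args, accelerator: "Accelerator"):  # noqa C901
    ...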
1 change: 1 addition & 0 deletions src/main_brrr.py
@@ -195,6 +195,7 @@ def main(args):
         lm=model,
         max_samples=lighteval_config.tasks.max_samples,
         evaluation_tracker=evaluation_tracker,
+        use_chat_template=False,
     )

     with htrack_block("Setting seeds and waiting for all processes"):
