diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ec4856f..441ff70a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,4 +37,5 @@ repos: rev: 'v0.1.6' hooks: - id: ruff + args: ['--fix'] - id: ruff-format diff --git a/README.md b/README.md index c04a6611..edf0556a 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ It is still an early, internal version - it should be nice to use but don't expe In case of problems or question, feel free to open an issue! ## How to install and use -### Requirements +### Installation 0) Create your virtual environment using virtualenv or conda depending on your preferences. We require Python3.10 1) Clone the package using `git clone`, then `cd lighteval-harness`, `pip install -e .` Once the dependencies are installed, `cd src`. @@ -22,6 +22,12 @@ Optional: 2) Add your user token to the environment variable `HUGGING_FACE_HUB_TOKEN` if you want to push your results to the hub +For the linting: +```bash +pre-commit install +pre-commit run --config .pre-commit-config.yaml --all-files +``` + ### Usage - Launching on CPU diff --git a/pyproject.toml b/pyproject.toml index 0e67a947..56faf1b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,8 +82,7 @@ optimum = ["optimum==1.12.0"] quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"] adapters = ["peft==0.3.0"] nanotron = [ - "nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e", - "brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b", + "nanotron@git+https://github.com/huggingface/nanotron", "tensorboardX" ] diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py new file mode 100644 index 00000000..7002c874 --- /dev/null +++ b/run_evals_accelerate.py @@ -0,0 +1,92 @@ +import argparse + +from lighteval.main_accelerate import CACHE_DIR, main + + +def get_parser(): + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group(required=True) + task_type_group = parser.add_mutually_exclusive_group(required=True) + + # Model type 1) Base model + weight_type_group = parser.add_mutually_exclusive_group() + weight_type_group.add_argument( + "--delta_weights", + action="store_true", + default=False, + help="set to True of your model should be merged with a base model, also need to provide the base model name", + ) + weight_type_group.add_argument( + "--adapter_weights", + action="store_true", + default=False, + help="set to True of your model has been trained with peft, also need to provide the base model name", + ) + parser.add_argument( + "--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights" + ) + + task_type_group.add_argument("--model_args") + parser.add_argument("--model_dtype", type=str, default=None) + parser.add_argument( + "--multichoice_continuations_start_space", + action="store_true", + help="Whether to force multiple choice continuations to start with a space", + ) + parser.add_argument( + "--no_multichoice_continuations_start_space", + action="store_true", + help="Whether to force multiple choice continuations to not start with a space", + ) + parser.add_argument("--use_chat_template", default=False, action="store_true") + # Model type 2) TGI + task_type_group.add_argument("--inference_server_address", type=str) + parser.add_argument("--inference_server_auth", type=str, default=None) + # Model type 3) Inference endpoints + task_type_group.add_argument("--endpoint_model_name", type=str) + 
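The new `run_evals_accelerate.py` parser (continued below) wires up three `argparse` mutually exclusive groups: a required one for the model backend (`--model_args`, `--inference_server_address`, or `--endpoint_model_name`), a required one for task selection (`--tasks`), and an optional one for the weight type (`--delta_weights` vs `--adapter_weights`). A minimal, self-contained sketch of that grouping pattern (illustrative names only, not the real CLI) behaves like this:

```python
import argparse

# Hypothetical mirror of the grouping used in get_parser(); not the actual parser.
parser = argparse.ArgumentParser()

backend = parser.add_mutually_exclusive_group(required=True)
backend.add_argument("--model_args")                # local or hub model
backend.add_argument("--inference_server_address")  # TGI server
backend.add_argument("--endpoint_model_name")       # inference endpoint

weights = parser.add_mutually_exclusive_group()     # optional: delta vs adapter weights
weights.add_argument("--delta_weights", action="store_true")
weights.add_argument("--adapter_weights", action="store_true")

# Exactly one backend flag is accepted; passing two at once raises SystemExit.
args = parser.parse_args(["--model_args", "pretrained=gpt2"])
print(args.model_args)  # -> "pretrained=gpt2"
```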
parser.add_argument("--accelerator", type=str, default=None) + parser.add_argument("--vendor", type=str, default=None) + parser.add_argument("--region", type=str, default=None) + parser.add_argument("--instance_size", type=str, default=None) + parser.add_argument("--instance_type", type=str, default=None) + parser.add_argument("--reuse_existing", default=False, action="store_true") + # Debug + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="") + # Saving + parser.add_argument("--push_results_to_hub", default=False, action="store_true") + parser.add_argument("--save_details", action="store_true") + parser.add_argument("--push_details_to_hub", default=False, action="store_true") + parser.add_argument( + "--public_run", default=False, action="store_true", help="Push results and details to a public repo" + ) + parser.add_argument("--cache_dir", type=str, default=CACHE_DIR) + parser.add_argument( + "--results_org", + type=str, + help="Hub organisation where you want to store the results. Your current token must have write access to it", + ) + # Common parameters + parser.add_argument("--output_dir", required=True) + parser.add_argument("--override_batch_size", type=int, default=-1) + parser.add_argument("--dataset_loading_processes", type=int, default=1) + parser.add_argument( + "--custom_tasks_file", + type=str, + default=None, + help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)", + ) + group.add_argument( + "--tasks", + type=str, + default=None, + help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks", + ) + parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") + return parser + + +if __name__ == "__main__": + parser = get_parser() + args, unknowns = parser.parse_known_args() + main(args) diff --git a/run_evals_nanotron.py b/run_evals_nanotron.py new file mode 100644 index 00000000..9b98d005 --- /dev/null +++ b/run_evals_nanotron.py @@ -0,0 +1,33 @@ +# flake8: noqa: C901 +import argparse + +from lighteval.main_nanotron import main + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-config-path", + type=str, + required=True, + help="Path to the brr checkpoint YAML or python config file, potentially on S3", + ) + parser.add_argument( + "--lighteval-override", + type=str, + help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", + ) + parser.add_argument( + "--cache-dir", + type=str, + default="", + help="Cache directory", + ) + + return parser + + +if __name__ == "__main__": + parser = get_parser() + args, unknowns = parser.parse_known_args() + main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir) diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 969f4208..5af88ff3 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -198,6 +198,37 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsR return -(len(toks) + gen_length) +class GenerativeTaskDatasetNanotron(DynamicBatchDataset): + def __getitem__(self, index) -> Request: + """ + Get an item from the dataset depending on the split we are currently in. + For instance, if we are in split 0, we will get the item at index 0, if + we are in split 1, we will get the item at index self.split_size, etc. 
+ Used for dynamic batching. + + Args: + index (int): The index of the item. + + Returns: + Any: The item at the specified index. + """ + return index, self.sorted_data[index + self.split_start] + + def _sorting_criteria(self, request) -> int: + """ + Collate function for generating batches. + + Args: + x (Any): The input data. + + Returns: + Any: The collated data. + """ + toks = request.tokenized_context + gen_length = request.generation_size + return -(len(toks) + gen_length) + + class GenDistributedSampler(DistributedSampler): """A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches as our samples are sorted by length. diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py index c547bc1a..0b09ff23 100644 --- a/src/lighteval/evaluator.py +++ b/src/lighteval/evaluator.py @@ -5,6 +5,8 @@ import copy from typing import Dict, Union +from pytablewriter import LatexTableWriter, MarkdownTableWriter + from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.hierarchical_logger import hlog from lighteval.models.base_model import BaseModel @@ -99,8 +101,6 @@ def evaluate( # noqa: C901 def make_results_table(result_dict): """Generate table of results.""" - from pytablewriter import LatexTableWriter, MarkdownTableWriter - md_writer = MarkdownTableWriter() latex_writer = LatexTableWriter() md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"] diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 0515af46..68ac95f2 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -18,13 +18,11 @@ TaskConfigLogger, VersionsLogger, ) -from lighteval.utils import is_nanotron_available +from lighteval.utils import is_nanotron_available, obj_to_markdown if is_nanotron_available(): - from brrr.config import BrrrConfig - from brrr.experiment_loggers import obj_to_markdown - from nanotron.config import get_config_from_dict + from nanotron.config import Config, get_config_from_dict class EnhancedJSONEncoder(json.JSONEncoder): @@ -104,81 +102,81 @@ def save( """ hlog("Saving experiment tracker") - try: - date_id = datetime.now().isoformat().replace(":", "-") - - output_dir_results = Path(output_dir) / "results" / self.general_config_logger.model_name - output_dir_details = Path(output_dir) / "details" / self.general_config_logger.model_name - output_dir_details_sub_folder = output_dir_details / date_id - output_dir_results.mkdir(parents=True, exist_ok=True) - output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True) - - output_results_file = output_dir_results / f"results_{date_id}.json" - output_results_in_details_file = output_dir_details / f"results_{date_id}.json" - - hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}") - - to_dump = { - "config_general": asdict(self.general_config_logger), - "results": self.metrics_logger.metric_aggregated, - "versions": self.versions_logger.versions, - "config_tasks": self.task_config_logger.tasks_configs, - "summary_tasks": self.details_logger.compiled_details, - "summary_general": asdict(self.details_logger.compiled_details_over_all_tasks), - } - dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2) - - with open(output_results_file, "w") as f: - f.write(dumped) - - with open(output_results_in_details_file, "w") as f: - f.write(dumped) - - for task_name, task_details in 
self.details_logger.details.items(): - output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet" - # Create a dataset from the dictionary - try: - dataset = Dataset.from_list([asdict(detail) for detail in task_details]) - except Exception: - # We force cast to str to avoid formatting problems for nested objects - dataset = Dataset.from_list( - [{k: str(v) for k, v in asdict(detail).items()} for detail in task_details] - ) + # try: + date_id = datetime.now().isoformat().replace(":", "-") - # We don't keep 'id' around if it's there - column_names = dataset.column_names - if "id" in dataset.column_names: - column_names = [t for t in dataset.column_names if t != "id"] - - # Sort column names to make it easier later - dataset = dataset.select_columns(sorted(column_names)) - # Save the dataset to a Parquet file - dataset.to_parquet(output_file_details.as_posix()) - - if push_results_to_hub: - self.api.upload_folder( - repo_id=self.hub_results_repo if public else self.hub_private_results_repo, - folder_path=output_dir_results, - path_in_repo=self.general_config_logger.model_name, - repo_type="dataset", - commit_message=f"Updating model {self.general_config_logger.model_name}", - ) + output_dir_results = Path(output_dir) / "results" / self.general_config_logger.model_name + output_dir_details = Path(output_dir) / "details" / self.general_config_logger.model_name + output_dir_details_sub_folder = output_dir_details / date_id + output_dir_results.mkdir(parents=True, exist_ok=True) + output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True) - if push_details_to_hub: - self.details_to_hub( - model_name=self.general_config_logger.model_name, - results_file_path=output_results_in_details_file, - details_folder_path=output_dir_details_sub_folder, - push_as_public=public, - ) + output_results_file = output_dir_results / f"results_{date_id}.json" + output_results_in_details_file = output_dir_details / f"results_{date_id}.json" + + hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}") - if push_results_to_tensorboard: - self.push_results_to_tensorboard( - results=self.metrics_logger.metric_aggregated, details=self.details_logger.details + to_dump = { + "config_general": asdict(self.general_config_logger), + "results": self.metrics_logger.metric_aggregated, + "versions": self.versions_logger.versions, + "config_tasks": self.task_config_logger.tasks_configs, + "summary_tasks": self.details_logger.compiled_details, + "summary_general": asdict(self.details_logger.compiled_details_over_all_tasks), + } + dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2) + + with open(output_results_file, "w") as f: + f.write(dumped) + + with open(output_results_in_details_file, "w") as f: + f.write(dumped) + + for task_name, task_details in self.details_logger.details.items(): + output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet" + # Create a dataset from the dictionary + try: + dataset = Dataset.from_list([asdict(detail) for detail in task_details]) + except Exception: + # We force cast to str to avoid formatting problems for nested objects + dataset = Dataset.from_list( + [{k: str(v) for k, v in asdict(detail).items()} for detail in task_details] ) - except Exception as e: - hlog("WARNING: Could not save results") - hlog(repr(e)) + + # We don't keep 'id' around if it's there + column_names = dataset.column_names + if "id" in dataset.column_names: + column_names = [t for t in dataset.column_names 
if t != "id"] + + # Sort column names to make it easier later + dataset = dataset.select_columns(sorted(column_names)) + # Save the dataset to a Parquet file + dataset.to_parquet(output_file_details.as_posix()) + + if push_results_to_hub: + self.api.upload_folder( + repo_id=self.hub_results_repo if public else self.hub_private_results_repo, + folder_path=output_dir_results, + path_in_repo=self.general_config_logger.model_name, + repo_type="dataset", + commit_message=f"Updating model {self.general_config_logger.model_name}", + ) + + if push_details_to_hub: + self.details_to_hub( + model_name=self.general_config_logger.model_name, + results_file_path=output_results_in_details_file, + details_folder_path=output_dir_details_sub_folder, + push_as_public=public, + ) + + if push_results_to_tensorboard: + self.push_results_to_tensorboard( + results=self.metrics_logger.metric_aggregated, details=self.details_logger.details + ) + # except Exception as e: + # hlog("WARNING: Could not save results") + # hlog(repr(e)) def generate_final_dict(self) -> dict: """Aggregates and returns all the logger's experiment information in a dictionary. @@ -487,7 +485,7 @@ def push_results_to_tensorboard( # noqa: C901 if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard with having nanotron installed. Skipping") return - config: BrrrConfig = get_config_from_dict(self.general_config_logger.config, config_class=BrrrConfig) + config: Config = get_config_from_dict(self.general_config_logger.config, config_class=Config) lighteval_config = config.lighteval try: global_step = config.general.step diff --git a/src/lighteval/logging/hierarchical_logger.py b/src/lighteval/logging/hierarchical_logger.py index b0d3a239..898fd822 100644 --- a/src/lighteval/logging/hierarchical_logger.py +++ b/src/lighteval/logging/hierarchical_logger.py @@ -6,12 +6,12 @@ from lighteval.utils import is_accelerate_available, is_nanotron_available -if is_accelerate_available(): - from accelerate.logging import get_logger +if is_nanotron_available(): + from nanotron.logging import get_logger logger = get_logger(__name__, log_level="INFO") -elif is_nanotron_available(): - from nanotron.logging import get_logger +elif is_accelerate_available(): + from accelerate.logging import get_logger logger = get_logger(__name__, log_level="INFO") else: diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py index 38d4d7ab..194e65f5 100644 --- a/src/lighteval/logging/info_loggers.py +++ b/src/lighteval/logging/info_loggers.py @@ -19,7 +19,7 @@ if is_nanotron_available(): - from brrr.config import BrrrConfig + from nanotron.config import Config @dataclass(init=False) @@ -64,8 +64,8 @@ class GeneralConfigLogger: model_dtype: str = None model_size: str = None - # Nanotron/Brrr config - config: "BrrrConfig" = None + # Nanotron config + config: "Config" = None def __init__(self) -> None: """Stores the current lighteval commit for reproducibility, and starts the evaluation timer.""" @@ -79,7 +79,7 @@ def log_args_info( override_batch_size: Union[None, int], max_samples: Union[None, int], job_id: str, - config: "BrrrConfig" = None, + config: "Config" = None, ) -> None: """ Logs the information about the arguments passed to the method. @@ -91,7 +91,7 @@ def log_args_info( Else, the batch size is automatically inferred depending on what fits in memory. max_samples (Union[None, int]): maximum number of samples, if None, use all the samples available. job_id (str): job ID, used to retrieve logs. 
- config (optional): BrrrConfig + config (optional): Nanotron Config Returns: None diff --git a/src/main.py b/src/lighteval/main_accelerate.py similarity index 56% rename from src/main.py rename to src/lighteval/main_accelerate.py index 0fc663b3..b048fbd9 100644 --- a/src/main.py +++ b/src/lighteval/main_accelerate.py @@ -1,4 +1,3 @@ -import argparse import os import random import shutil @@ -32,89 +31,6 @@ accelerator = None -def get_parser(): - parser = argparse.ArgumentParser() - group = parser.add_mutually_exclusive_group(required=True) - task_type_group = parser.add_mutually_exclusive_group(required=True) - - # Model type 1) Base model - weight_type_group = parser.add_mutually_exclusive_group() - weight_type_group.add_argument( - "--delta_weights", - action="store_true", - default=False, - help="set to True of your model should be merged with a base model, also need to provide the base model name", - ) - weight_type_group.add_argument( - "--adapter_weights", - action="store_true", - default=False, - help="set to True of your model has been trained with peft, also need to provide the base model name", - ) - parser.add_argument( - "--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights" - ) - - task_type_group.add_argument("--model_args") - parser.add_argument("--model_dtype", type=str, default=None) - parser.add_argument( - "--multichoice_continuations_start_space", - action="store_true", - help="Whether to force multiple choice continuations to start with a space", - ) - parser.add_argument( - "--no_multichoice_continuations_start_space", - action="store_true", - help="Whether to force multiple choice continuations to not start with a space", - ) - parser.add_argument("--use_chat_template", default=False, action="store_true") - # Model type 2) TGI - task_type_group.add_argument("--inference_server_address", type=str) - parser.add_argument("--inference_server_auth", type=str, default=None) - # Model type 3) Inference endpoints - task_type_group.add_argument("--endpoint_model_name", type=str) - parser.add_argument("--accelerator", type=str, default=None) - parser.add_argument("--vendor", type=str, default=None) - parser.add_argument("--region", type=str, default=None) - parser.add_argument("--instance_size", type=str, default=None) - parser.add_argument("--instance_type", type=str, default=None) - parser.add_argument("--reuse_existing", default=False, action="store_true") - # Debug - parser.add_argument("--max_samples", type=int, default=None) - parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="") - # Saving - parser.add_argument("--push_results_to_hub", default=False, action="store_true") - parser.add_argument("--save_details", action="store_true") - parser.add_argument("--push_details_to_hub", default=False, action="store_true") - parser.add_argument( - "--public_run", default=False, action="store_true", help="Push results and details to a public repo" - ) - parser.add_argument("--cache_dir", type=str, default=CACHE_DIR) - parser.add_argument( - "--results_org", - type=str, - help="Hub organisation where you want to store the results. 
Your current token must have write access to it", - ) - # Common parameters - parser.add_argument("--output_dir", required=True) - parser.add_argument("--override_batch_size", type=int, default=-1) - parser.add_argument("--dataset_loading_processes", type=int, default=1) - parser.add_argument( - "--custom_tasks_file", - type=str, - default=None, - help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)", - ) - group.add_argument( - "--tasks", - type=str, - default=None, - help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks", - ) - parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") - return parser - - @htrack() def main(args): env_config = EnvConfig(token=TOKEN, cache_dir=args.cache_dir) @@ -208,9 +124,3 @@ def main(args): model.cleanup() return final_dict - - -if __name__ == "__main__": - parser = get_parser() - args, unknowns = parser.parse_known_args() - main(args) diff --git a/src/main_brrr.py b/src/lighteval/main_nanotron.py similarity index 50% rename from src/main_brrr.py rename to src/lighteval/main_nanotron.py index bd257eac..d8d87b86 100644 --- a/src/main_brrr.py +++ b/src/lighteval/main_nanotron.py @@ -1,115 +1,75 @@ # flake8: noqa: C901 -import argparse import os import random +from typing import Optional, Type import numpy as np -import torch -from brrr.config import BrrrConfig -from brrr.s3_checkpoints import fs_copy -from brrr.utils import check_env -from nanotron import distributed as dist -from nanotron import logging -from nanotron.config import get_config_from_file -from nanotron.logging import get_logger, log_rank -from nanotron.parallel.context import ParallelContext -from nanotron.utils import local_ranks_zero_first from lighteval.evaluator import evaluate, make_results_table from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block -from lighteval.models.brrr_models import BRRRModel +from lighteval.models.model_config import EnvConfig from lighteval.models.model_loader import ModelInfo +from lighteval.models.nanotron_model import NanotronLightevalModel from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector +from lighteval.utils import NO_NANOTRON_ERROR_MSG, is_nanotron_available +from lighteval.utils_parallelism import test_all_gather + + +if not is_nanotron_available(): + raise ImportError(NO_NANOTRON_ERROR_MSG) + +from nanotron import distributed as dist +from nanotron.config import Config, get_config_from_file +from nanotron.logging import get_logger +from nanotron.parallel.context import ParallelContext +from nanotron.utils import local_ranks_zero_first logger = get_logger(__name__) +SEED = 1234 TOKEN = os.getenv("HF_TOKEN") CACHE_DIR = os.getenv("HF_HOME", "/scratch") -def get_parser(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--checkpoint-config-path", - type=str, - required=True, - help="Path to the brr checkpoint YAML or python config file, potentially on S3", - ) - parser.add_argument( - "--lighteval-override", - type=str, - help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", - ) - parser.add_argument( - "--tokenizer", - type=str, - help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)", - ) - 
parser.add_argument( - "--s5cmd-path", - type=str, - default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd", - help="Path to s5cmd install", - ) - parser.add_argument( - "--s5cmd-numworkers", - type=int, - default=64, - help="s5cmd num workers (optional)", - ) - parser.add_argument( - "--s5cmd-concurrency", - type=int, - default=10, - help="s5cmd concurrency (optional)", - ) - parser.add_argument( - "--cache-dir", - type=str, - default="", - help="Cache directory", - ) - - return parser - - @htrack() -def main(args): - cache_dir = args.cache_dir or CACHE_DIR - check_env() +def main( + checkpoint_config_path: str, + lighteval_config_path: Optional[str] = None, + cache_dir: str = None, + config_cls: Type = Config, + model_config_cls: Optional[Type] = None, + model_cls: Optional[Type] = None, +): + if cache_dir is None: + cache_dir = CACHE_DIR + + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) dist.initialize_torch_distributed() with htrack_block("get config"): - if not args.checkpoint_config_path.endswith(".yaml"): + if not checkpoint_config_path.endswith(".yaml"): raise ValueError("The checkpoint path should point to a YAML file") - local_config_path = args.checkpoint_config_path - if args.checkpoint_config_path.startswith("s3:/"): - local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir) - with local_ranks_zero_first(): - if os.environ.get("LOCAL_RANK", None) == "0": - os.makedirs(os.path.dirname(local_config_path), exist_ok=True) - fs_copy(args.checkpoint_config_path, local_config_path) - - brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig) - - if args.lighteval_override: - local_override_path = args.lighteval_override.replace("s3:/", cache_dir) - if args.lighteval_override.startswith("s3:/"): - local_override_path = args.lighteval_override.replace("s3:/", cache_dir) - with local_ranks_zero_first(): - if os.environ.get("LOCAL_RANK", None) == "0": - os.makedirs(os.path.dirname(local_override_path), exist_ok=True) - fs_copy(args.lighteval_override, local_override_path) - lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig) - lighteval_config = lighteval_brrr_config.lighteval - brrr_config.lighteval = lighteval_config + + nanotron_config: config_cls = get_config_from_file( + checkpoint_config_path, + config_class=config_cls, + model_config_class=model_config_cls, + skip_unused_config_keys=True, + skip_null_keys=True, + ) + + if lighteval_config_path: + lighteval_nanotron_config: config_cls = get_config_from_file( + lighteval_config_path, config_class=config_cls + ) + lighteval_config = lighteval_nanotron_config.lighteval + nanotron_config.lighteval = lighteval_config else: - local_override_path = "" - lighteval_config = brrr_config.lighteval + lighteval_config = nanotron_config.lighteval parallel_context = ParallelContext( tensor_parallel_size=lighteval_config.parallelism.tp, @@ -123,51 +83,28 @@ def main(args): override_batch_size=None, max_samples=lighteval_config.tasks.max_samples, job_id=os.environ.get("SLURM_JOB_ID", None), - config=brrr_config.as_dict(), + config=nanotron_config.as_dict(), ) with htrack_block("Test all gather"): - hlog("Test gather tensor") - # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading - log_rank( - f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}", - logger=logger, - level=logging.WARNING, - group=parallel_context.dp_pg, - rank=0, - ) - test_tensor = 
torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda")) - test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())] - dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False) - dist.barrier() - log_rank( - f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}", - logger=logger, - level=logging.WARNING, - group=parallel_context.dp_pg, - rank=0, - ) - - del test_tensor_list - del test_tensor + test_all_gather(parallel_context=parallel_context) with htrack_block("Model loading"): # We need to load the model in the main process first to avoid downloading the model multiple times - model = BRRRModel( - checkpoint_path=brrr_config.s3_upload.upload_s3_path / str(brrr_config.general.step), - model_args=brrr_config.model, - tokenizer=brrr_config.tokenizer, + model = NanotronLightevalModel( + checkpoint_path=os.path.dirname(checkpoint_config_path), + model_args=nanotron_config.model, + tokenizer=nanotron_config.tokenizer, parallel_context=parallel_context, parallel_config=lighteval_config.parallelism, lighteval_config=lighteval_config, batch_size=lighteval_config.batch_size, cache_dir=os.environ.get("HF_HOME", "/scratch"), debug_one_layer_model=False, - s5cmd_path=args.s5cmd_path, - s5cmd_numworkers=args.s5cmd_numworkers, - s5cmd_concurrency=args.s5cmd_concurrency, + model_class=model_cls, + env_config=env_config, ) - model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}") + model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}") evaluation_tracker.general_config_logger.log_model_info(model_info) with htrack_block("Tasks loading"): @@ -195,12 +132,13 @@ def main(args): lm=model, max_samples=lighteval_config.tasks.max_samples, evaluation_tracker=evaluation_tracker, + use_chat_template=False, ) with htrack_block("Setting seeds and waiting for all processes"): - hlog(f"setting seed to {1234} for random and numpy") - random.seed(1234) - np.random.seed(1234) + hlog(f"setting seed to {SEED} for random and numpy") + random.seed(SEED) + np.random.seed(SEED) dist.barrier() with htrack_block("Evaluation"): @@ -234,9 +172,3 @@ def main(args): hlog(make_results_table(final_dict)) return final_dict - - -if __name__ == "__main__": - parser = get_parser() - args, unknowns = parser.parse_known_args() - main(args) diff --git a/src/lighteval/models/brrr_models.py b/src/lighteval/models/nanotron_model.py similarity index 86% rename from src/lighteval/models/brrr_models.py rename to src/lighteval/models/nanotron_model.py index 1fd7a85c..38b1bd2a 100644 --- a/src/lighteval/models/brrr_models.py +++ b/src/lighteval/models/nanotron_model.py @@ -1,17 +1,15 @@ -# flake8: noqa: C901 +# ruff: noqa: C901,E120 import os import time -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn.functional as F import transformers -from brrr.config import LightEvalConfig, ModelArgs, ParallelismArgs -from brrr.s3_checkpoints import S3Mover, check_path_is_local from datasets.download.streaming_download_manager import xPath from nanotron import distributed as dist from nanotron import logging -from nanotron.config import TokenizerArgs +from nanotron.config import LightEvalConfig, ModelArgs, ParallelismArgs, TokenizerArgs from nanotron.generation.decode import decode_tokenized from nanotron.logging import human_format, log_rank from nanotron.models import 
build_model @@ -28,13 +26,24 @@ from tqdm import tqdm from transformers import AutoTokenizer, BatchEncoding -from lighteval.data import GenDataset, GenDistributedSampler, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset +from lighteval.data import ( + GenDistributedSampler, + GenerativeTaskDatasetNanotron, + LoglikelihoodDataset, + LoglikelihoodSingleTokenDataset, +) +from lighteval.models.base_model import LightevalModel +from lighteval.models.model_config import EnvConfig from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn -from lighteval.utils import as_list, find_executable_batch_size +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, +) +from lighteval.utils import as_list +from lighteval.utils_parallelism import find_executable_batch_size -# from .brrr_generation import GenerationConfig, GenerationInputs, SamplerType, greedy_search_tokenized - os.environ["TOKENIZERS_PARALLELISM"] = "false" logger = logging.get_logger(__name__) @@ -44,7 +53,7 @@ # _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]]) -class BRRRModel: +class NanotronLightevalModel(LightevalModel): # Default max sequence length setting for when no `max_length` is provided # or no max length config setting is found in the model or tokenizer. _DEFAULT_MAX_LENGTH: int = 2048 @@ -63,16 +72,13 @@ def __init__( add_special_tokens: Optional[bool] = True, dtype: Optional[Union[str, torch.dtype]] = None, trust_remote_code: bool = False, - cache_dir: str = "/scratch", debug_one_layer_model: bool = False, - s5cmd_numworkers: int = 64, - s5cmd_concurrency: int = 10, - s5cmd_path: str = "/admin/home/thomwolf/miniconda/envs/b4r/bin/s5cmd", + model_class: Optional[Type] = None, + env_config: EnvConfig = None, ): - """Initializes a brrr model for evaluation. + """Initializes a nanotron model for evaluation. Args: """ - super().__init__() self._batch_size = batch_size self._max_gen_toks = max_gen_toks @@ -112,14 +118,16 @@ def __init__( self.model_config.num_hidden_layers = 1 self._add_special_tokens = add_special_tokens - self.tokenizer = self._create_auto_tokenizer( + self._tokenizer = self._create_auto_tokenizer( pretrained=tokenizer.tokenizer_name_or_path, - cache_dir=cache_dir, + env_config=env_config, trust_remote_code=trust_remote_code, ) - self.tokenizer.model_max_length = self.max_length + self._tokenizer.model_max_length = self.max_length model_config_cls = self.model_config.__class__.__name__ + if model_class is not None: + CONFIG_TO_MODEL_CLASS[model_config_cls] = model_class if model_config_cls not in CONFIG_TO_MODEL_CLASS: raise ValueError( f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported" @@ -160,7 +168,7 @@ def __init__( ) # Mark some parameters as tied - # TODO @nouamane: this is only needed for training, can we just mark params as BRRRParameter instead? + # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? 
mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) log_rank( @@ -179,49 +187,7 @@ def __init__( level=logging.WARNING, rank=0, ) - if check_path_is_local(checkpoint_path): - load_weights(model=model, parallel_context=parallel_context, root_folder=xPath(checkpoint_path)) - else: - local_path = str(checkpoint_path).replace("s3:/", cache_dir) - loaded = False - if float(os.environ.get("WORLD_SIZE", 1)) // float( - os.environ.get("LOCAL_WORLD_SIZE", 1) - ) == 1 and os.path.exists(local_path): - # If we have only one node and the local folder already exists, we can try to load from there - # Sometimes the checkpoints is already here but since it may be on some nodes nd not others we still need to download everywhere - # so we can only use this pathway when we have a single node (WORLD_SIZE == LOCAL_WORLD_SIZE) - try: - log_rank( - f"Testing loading checkpoint from {local_path}:", - logger=logger, - level=logging.WARNING, - rank=0, - ) - load_weights(model=model, parallel_context=parallel_context, root_folder=xPath(local_path)) - loaded = True - except ValueError: - loaded = False - if not loaded: - log_rank( - f"Downloading checkpoint from S3 in {local_path}. ", - logger=logger, - level=logging.WARNING, - rank=0, - ) - # Download checkpoint from S3 - s3_mover = S3Mover( - os.path.join(local_path, "model"), - os.path.join(checkpoint_path, "model"), - s5cmd_numworkers=s5cmd_numworkers, - s5cmd_concurrency=s5cmd_concurrency, - s5cmd_path=s5cmd_path, - dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0), - ) - s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) - s3_mover.start_downloading() - s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) - load_weights(model=model, parallel_context=parallel_context, root_folder=xPath(local_path)) - + load_weights(model=model, parallel_context=parallel_context, root_folder=xPath(checkpoint_path)) model.eval() # We don't need the loss @@ -232,12 +198,16 @@ def __init__( self.multichoice_continuations_start_space = multichoice_continuations_start_space + @property + def tokenizer(self): + return self._tokenizer + def _create_auto_tokenizer( self, *, pretrained: str, tokenizer: Optional[str] = None, - cache_dir: str = "/scratch", + env_config: EnvConfig = None, trust_remote_code: bool = False, ) -> transformers.PreTrainedTokenizer: """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration.""" @@ -245,16 +215,16 @@ def _create_auto_tokenizer( try: tokenizer = AutoTokenizer.from_pretrained( pretrained if tokenizer is None else tokenizer, - cache_dir=cache_dir, - token=os.getenv("HUGGING_FACE_HUB_TOKEN"), + cache_dir=env_config.cache_dir, + token=env_config.token, trust_remote_code=trust_remote_code, ) except RecursionError: tokenizer = AutoTokenizer.from_pretrained( pretrained if tokenizer is None else tokenizer, - cache_dir=cache_dir, + cache_dir=env_config.cache_dir, + token=env_config.token, unk_token="", - token=os.getenv("HUGGING_FACE_HUB_TOKEN"), trust_remote_code=trust_remote_code, ) tokenizer.pad_token = tokenizer.eos_token @@ -394,7 +364,7 @@ def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | """ max_tokens, stop_sequences = None, None # Filling with input values or default - if isinstance(ending_condition, tuple): + if isinstance(ending_condition, tuple) and len(ending_condition) == 2: stop_sequence_arg, max_gen_tokens_arg = ending_condition stop_sequences = as_list(stop_sequence_arg) max_tokens = 
max_gen_tokens_arg @@ -449,19 +419,16 @@ def loglikelihood_single_token( Returns: List[Tuple[float, bool]]: _description_ """ - tokenized_reqs = [] - - for context, continuations in tqdm( + for request in tqdm( requests, desc="Tokenizing", disable=bool(dist.get_rank(self.parallel_context.world_pg) != 0) ): - if context == "": - # end of text as context - context_enc = [self.eot_token_id] + if request.context == "": + request.tokenized_context = [self.tokenizer.eos_token_id] else: - context_enc = self.tok_encode(context) + request.tokenized_context = self.tok_encode(request.context) # Some models tokenizer want a space at the beginning and other not - continuations = [self._check_continuations_start_space(c) for c in continuations] + continuations = [self._check_continuations_start_space(c) for c in request.choices] # We must not accidentally prepend a continuation with a start of sentence token. continuations_enc = [self.tok_encode(c, add_special_tokens=False) for c in continuations] @@ -470,59 +437,48 @@ def loglikelihood_single_token( f"Trying to do single token multiple choice but one choice has several tokens: {continuations_enc}. " "If the additional pre-token is a space, try to set --no_multichoice_continuations_start_space " ) - - tokenized_reqs.append(((context, continuations), context_enc, continuations_enc)) + request.tokenized_continuation = continuations_enc return self._loglikelihood_single_token( - tokenized_reqs, + requests, override_bs=override_bs, disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]: + def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) -> List[LoglikelihoodReturn]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. - - Args: - requests (List[Tuple[str, dict]]): _description_ - - Returns: - List[Tuple[float, bool]]: _description_ """ - tokenized_reqs = [] - - for req in tqdm(requests, desc="Tokenizing", disable=bool(dist.get_rank(self.parallel_context.world_pg) != 0)): - context, continuation = req.context, req.choice - if context == "": - context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation) + for request in tqdm( + requests, desc="Tokenizing", disable=bool(dist.get_rank(self.parallel_context.world_pg) != 0) + ): + if request.context == "": + request.tokenized_context = [self.tokenizer.eos_token_id] + request.tokenized_continuation = self.tok_encode(request.choice) else: - # DO NOT CHANGE THE FOLLOWING LINE! - # It is mandatory for compatibility with the harness!!! 
- context_enc, continuation_enc = self._encode_pair(context, continuation) - - tokenized_reqs.append(((context, continuation), context_enc, continuation_enc)) + # The following line is mandatory for compatibility with the harness + request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair( + request.context, request.choice + ) return self._loglikelihood_tokens( - tokenized_reqs, + requests, override_bs=override_bs, disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - def loglikelihood_rolling(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]: + def loglikelihood_rolling( + self, requests: List[LoglikelihoodRollingRequest], override_bs=None + ) -> List[LoglikelihoodReturn]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" - tokenized_reqs = [] - - for (context,) in tqdm( + for request in tqdm( requests, desc="Tokenizing", disable=bool(dist.get_rank(self.parallel_context.world_pg) != 0) ): # tuple of one elem - if isinstance(context, dict): # lm_eval.base.PerplexityTask passed the query as such - context = context["query"] - fake_context_enc, context_enc = [self.eot_token_id], self.tok_encode(context) - - tokenized_reqs.append((("", context), fake_context_enc, context_enc)) + request.tokenized_context = [self.tokenizer.eos_token_id] # Fake context + request.tokenized_continuation = self.tok_encode(request.context) results = self._loglikelihood_tokens( - tokenized_reqs, + requests, override_bs=override_bs, disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), return_bool_score=False, @@ -714,7 +670,7 @@ def _loglikelihood_single_token( # automatic (variable) batch size detection for vectorization # pull longest context sample from request - _, context_enc, _ = dataset[0] + context_enc = dataset[0].tokenized_context max_context = len(context_enc[-self.max_length :]) batch_size = self._get_batch_size( override_bs=override_bs, max_input_length=max_context, starting_batch_size=starting_batch_size @@ -757,7 +713,7 @@ def _loglikelihood_single_token( rank=0, ) iteration_start_time = time.time() - inputs = [context_enc for _, context_enc, _ in batch_data] + inputs = [item.tokenized_context for item in batch_data] batch_model = self.prepare_batch( inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True @@ -780,9 +736,9 @@ def _loglikelihood_single_token( batch_probs = [] batch_cont_tokens = [] - for i, ((context, _, cont_toks), logits, inplen) in enumerate( - zip(batch_data, out, batch_model.input_lengths) - ): + for i, (batch, logits, inplen) in enumerate(zip(batch_data, out, batch_model.input_lengths)): + context = batch.context + cont_toks = batch.tokenized_continuation # Get the last token logits = logits[inplen - 1] # [vocab] @@ -911,7 +867,7 @@ def _loglikelihood_single_token( # We are in a process which return no output (beginning/middle of the PP group) return [] - return dataset.ordered.get_original(res) + return dataset.get_original_order(res) @torch.inference_mode() def _loglikelihood_tokens( @@ -922,7 +878,7 @@ def _loglikelihood_tokens( dataset_splits: int = 1, return_bool_score: bool = True, ) -> List[LoglikelihoodReturn]: - dataset = LoglikelihoodDataset(requests=requests) + dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits) res = [] # Dataset is sorted in descending size. 
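Across these hunks, the old tuple plumbing (`((context, continuation), context_enc, continuation_enc)`) is replaced by filling tokenization fields directly on the request objects and handing the same list to `_loglikelihood_tokens`. A rough sketch of the new flow, using a stand-in dataclass rather than the real `lighteval.tasks.requests.LoglikelihoodRequest` (field names here are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyLoglikelihoodRequest:
    # Stand-in for lighteval's request object; only the fields used below.
    context: str
    choice: str
    tokenized_context: Optional[List[int]] = None
    tokenized_continuation: Optional[List[int]] = None


def fake_tokenize(text: str) -> List[int]:
    # Placeholder tokenizer: one "token id" per whitespace-separated word.
    return [hash(word) % 50257 for word in text.split()]


def prepare(requests: List[ToyLoglikelihoodRequest]) -> List[ToyLoglikelihoodRequest]:
    # Mirrors the refactor: tokenize in place instead of building tuples.
    for request in requests:
        request.tokenized_context = fake_tokenize(request.context) or [0]  # eos-only for empty context
        request.tokenized_continuation = fake_tokenize(request.choice)
    return requests


reqs = prepare([ToyLoglikelihoodRequest("The capital of France is", " Paris")])
# Downstream, the dynamic-batching datasets sort by total length, longest first:
reqs.sort(key=lambda r: -(len(r.tokenized_context) + len(r.tokenized_continuation)))
print(len(reqs[0].tokenized_context), len(reqs[0].tokenized_continuation))
```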
@@ -944,7 +900,9 @@ def _loglikelihood_tokens( # automatic (variable) batch size detection for vectorization # pull longest context sample from request - _, context_enc, continuation_enc = dataset[0] + context_enc = dataset[0].tokenized_context + continuation_enc = dataset[0].tokenized_continuation + max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]) batch_size = self._get_batch_size( @@ -987,7 +945,7 @@ def _loglikelihood_tokens( ) iteration_start_time = time.time() inputs = [ - context_enc + continuation_enc[:-1] for _, context_enc, continuation_enc in batch_data + item.tokenized_context + item.tokenized_continuation[:-1] for item in batch_data ] # The last token doesn't need to be input in the model batch_model = self.prepare_batch( inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True @@ -1010,17 +968,16 @@ def _loglikelihood_tokens( logits_sum = [] max_equals = [] batch_cont_tokens = [] - for (_, _, cont_toks), logits, inplen in zip(batch_data, multi_logits, batch_model.input_lengths): + for cur_request, cur_logits, inplen in zip(batch_data, multi_logits, batch_model.input_lengths): + cont_toks = torch.tensor( + cur_request.tokenized_continuation, dtype=torch.long, device=self.device + ) + contlen = cont_toks.shape[0] # We only look at the continuation tokens - contlen = len(cont_toks) if contlen > inplen: - # continuation is longer than the allowed context size, everything is a continuation - logits = logits.unsqueeze(0).to(self.device) # [1, seq, vocab] - cont_toks = ( - torch.tensor(cont_toks, dtype=torch.long, device=self.device)[:inplen] - .unsqueeze(0) - .to(self.device) - ) # [1, seq] + # Continuation is longer than the input size, we are in rolling mode (only continuation) + cur_logits = cur_logits.unsqueeze(0).to(self.device) # [1, seq, vocab] + cont_toks = cont_toks[:inplen].unsqueeze(0).to(self.device) # [1, seq] else: # if contlen == 1: # top_k = torch.topk(logits, 20)[1].tolist() @@ -1031,28 +988,21 @@ def _loglikelihood_tokens( # f"Not all the solutions are in the top 20 most likely tokens on rank {dist.get_rank(self.parallel_context.world_pg)} " # f"top_tokens: {top_toks_str}\ncont_tokens: {cont_toks_str}") - logits = ( - logits[inplen - contlen : inplen] - .unsqueeze(0) - .to(self.device) # Here we remove the last one with our [...:inplen] - ) # [1, contlen, vocab] - cont_toks = ( - torch.tensor(cont_toks, dtype=torch.long, device=self.device) - .unsqueeze(0) - .to(self.device) - ) # [1, contlen] + cur_logits = ( + cur_logits[inplen - contlen : inplen].unsqueeze(0).to(self.device) + ) # [1, seq, voc] + cont_toks = cont_toks.unsqueeze(0).to(self.device) # [1, seq] # Check if per-token argmax is exactly equal to continuation - greedy_tokens = logits.argmax(dim=-1).to(self.device) # [1, contlen] + greedy_tokens = cur_logits.argmax(dim=-1).to(self.device) # Sometimes the continuation is longer than allowed by the model, we only look at the first tokens max_equal = (greedy_tokens == cont_toks).all().squeeze(0).to(self.device) # Obtain log-probs at the corresponding continuation token indices - # last_token_slice = logits[:, -1, :].squeeze(0).tolist() - logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, contlen] + cur_logits = torch.gather(cur_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] # Answer: (log prob, is-exact-match) - logits_sum.append(logits.sum()) + logits_sum.append(cur_logits.sum()) max_equals.append(max_equal) batch_cont_tokens.append(cont_toks) @@ -1151,13 
+1101,12 @@ def _loglikelihood_tokens( # We are in a process which return no output (beginning/middle of the PP group) return [] - return dataset.ordered.get_original(res) + return dataset.get_original_order(res) @torch.inference_mode() def greedy_until( self, - requests: List[Tuple[str, dict]], - task_names: Optional[List[str]] = None, + requests: List[GreedyUntilRequest], returns_logits=False, disable_tqdm: bool = False, override_bs=None, @@ -1166,17 +1115,11 @@ def greedy_until( """Greedy generation until a stop token is generated.""" # automatic (variable) batch size detection for vectorization # pull longest context sample from request - if task_names: - enc_inputs = [ - (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), task_name) - for req, task_name in zip(requests, task_names) - ] - else: - enc_inputs = [ - (self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), None) for req in requests - ] + for request in requests: + request.stop_sequence = request.stop_sequence + (self.tokenizer.eos_token,) + request.tokenized_context = self.tok_encode(request.context) - dataset = GenDataset(requests=enc_inputs) + dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits) res = [] # Dataset is sorted in descending size. @@ -1196,8 +1139,8 @@ def greedy_until( dataset.split_start = subset_start dataset.split_end = min(subset_start + subset_length, total_length) - _, (context_enc, _, _) = dataset[0] - max_gen = max(d[1][1][1] for d in dataset) + context_enc = dataset[0][1].tokenized_context + max_gen = max(item[1].generation_size for item in dataset) max_input_length = min(len(context_enc) + max_gen, self.max_length) batch_size = self._get_batch_size( override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size @@ -1234,11 +1177,10 @@ def greedy_until( ) iteration_start_time = time.time() example_index, batch_data = zip(*all_batch) - context = [c[0] for c in batch_data] - task_names = [c[2] for c in batch_data] + context = [c.tokenized_context for c in batch_data] # we take the longest asked generation in the batch # Multiple request may have different max generation length - max_tokens = max(d[1][1] for d in batch_data) + max_tokens = max(d.generation_size for d in batch_data) # d[1][1] if max_tokens <= 0: raise ValueError("Greedy generation requires a positive value for max generation but we got -1") @@ -1307,7 +1249,7 @@ def greedy_until( ): # Ensure the generated responses do not contain the stop sequences. decoded_response = self.tokenizer.decode(generation, skip_special_tokens=False) - stop_terms = dataset[example_index][1][1][0] + stop_terms = dataset[example_index][1].stop_sequence for stop_term in stop_terms: decoded_response = decoded_response.split(stop_term)[0] # partial caching @@ -1350,7 +1292,7 @@ def greedy_until( # We are in a process which return no output (beginning/middle of the PP group) return [] - return dataset.ordered.get_original(res) + return dataset.get_original_order(res) class MultiTokenEOSCriteria(transformers.StoppingCriteria): diff --git a/src/lighteval/utils.py b/src/lighteval/utils.py index 246510fe..c2a9335d 100644 --- a/src/lighteval/utils.py +++ b/src/lighteval/utils.py @@ -12,11 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import importlib +from dataclasses import asdict, is_dataclass from typing import Any, Union import numpy as np +def flatten_dict(nested: dict, sep="/") -> dict: + """Flatten dictionary, list, tuple and concatenate nested keys with separator.""" + + def clean_markdown(v: str) -> str: + return v.replace("|", "_").replace("\n", "_") if isinstance(v, str) else v # Need this for markdown + + def rec(nest: dict, prefix: str, into: dict): + for k, v in sorted(nest.items()): + # if sep in k: + # raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, dict): + rec(v, prefix + k + sep, into) + elif isinstance(v, (list, tuple)): + for i, vv in enumerate(v): + if isinstance(vv, dict): + rec(vv, prefix + k + sep + str(i) + sep, into) + else: + vv = ( + vv.replace("|", "_").replace("\n", "_") if isinstance(vv, str) else vv + ) # Need this for markdown + into[prefix + k + sep + str(i)] = vv.tolist() if isinstance(vv, np.ndarray) else vv + elif isinstance(v, np.ndarray): + into[prefix + k + sep + str(i)] = v.tolist() + else: + v = clean_markdown(v) + into[prefix + k] = v + + flat = {} + rec(nested, "", flat) + return flat + + +def clean_s3_links(value: str) -> str: + """Cleans and formats s3 bucket links for better display in the result table (nanotron models) + + Args: + value (str): path to clean + + Returns: + str : cleaned path + """ + s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1) + if not s3_prefix.endswith("/"): + s3_prefix += "/" + link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}" + value = f' {value} ' + return value + + +def obj_to_markdown(obj, convert_s3_links: bool = True) -> str: + """Convert a (potentially nested) dataclass object or a dict in a readable markdown string for logging""" + from pytablewriter import MarkdownTableWriter + + if is_dataclass(obj): + obj = asdict(obj) + config_dict = flatten_dict(obj) + + md_writer = MarkdownTableWriter() + md_writer.headers = ["Key", "Value"] + + values = [] + for key, value in config_dict.items(): + if convert_s3_links and "s3://" in str(value): + value = clean_s3_links(value) + values.append([key, value]) + md_writer.value_matrix = values + + return md_writer.dumps() + + def sanitize_numpy(example_dict: dict) -> dict: """ Sanitizes a dictionary by converting any numpy generic types to their corresponding Python types. @@ -92,7 +163,7 @@ def is_nanotron_available() -> bool: return importlib.util.find_spec("nanotron") is not None -NO_NANOTRON_ERROR_MSG = "YYou requested the use of nanotron for this evaluation, but it is not available in your current environement. Please install it using pip." +NO_NANOTRON_ERROR_MSG = "You requested the use of nanotron for this evaluation, but it is not available in your current environement. Please install it using pip." 
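`flatten_dict` and `obj_to_markdown` now live in `lighteval.utils` (they previously came from `brrr.experiment_loggers`). A small usage sketch, based on the implementation above; the config values are placeholders:

```python
from lighteval.utils import flatten_dict, obj_to_markdown

# Placeholder nested config, loosely shaped like a logged run config.
config = {"general": {"run": "demo", "step": 10}, "tasks": ["mmlu", "arc"]}

# Nested keys are joined with "/" and list items get their index appended.
print(flatten_dict(config))
# {'general/run': 'demo', 'general/step': 10, 'tasks/0': 'mmlu', 'tasks/1': 'arc'}

# obj_to_markdown renders the same flattened pairs as a two-column Key | Value
# markdown table (it relies on pytablewriter, which evaluator.py also imports).
print(obj_to_markdown(config))
```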
def is_optimum_available() -> bool: diff --git a/tasks_examples/custom_tasks/lighteval_config_override_template.yaml b/tasks_examples/custom_tasks/lighteval_config_override_template.yaml index da9258ff..6544a88a 100644 --- a/tasks_examples/custom_tasks/lighteval_config_override_template.yaml +++ b/tasks_examples/custom_tasks/lighteval_config_override_template.yaml @@ -12,7 +12,7 @@ lighteval: push_results_to_tensorboard: true tensorboard_metric_prefix: e parallelism: - dp: 8 + dp: 1 pp: 1 pp_engine: 1f1b recompute_granularity: null @@ -20,7 +20,7 @@ lighteval: tp_linear_async_communication: false tp_mode: ALL_REDUCE tasks: - custom_tasks_file: /fsx/thomwolf/github/lighteval-harness/tasks_examples/custom_evaluation_tasks.py + custom_tasks_file: /fsx/thomwolf/github/lighteval/tasks_examples/custom_tasks/custom_evaluation_tasks.py dataset_loading_processes: 8 max_samples: 1000 multichoice_continuations_start_space: null diff --git a/tests/test_main.py b/tests/test_main.py index d87d37db..0a1030f0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,9 +2,10 @@ import os import pytest -from main import get_parser, main # noqa: E402 from pytest import approx +from lighteval.main_accelerate import main # noqa: E402 +from run_evals_accelerate import get_parser from tests.reference_scores.reference_task_scores import RESULTS_FULL, RESULTS_LITE # noqa: E402 from tests.reference_scores.reference_tasks import ( # noqa: E402 HELM_SUBSET,
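With the entry points moved out of `src/`, the accelerate pipeline is driven either from the CLI via `run_evals_accelerate.py` or programmatically, which is what `tests/test_main.py` now does with the relocated imports. A hedged sketch of a programmatic call; the model args and task id are placeholder values (the task id simply echoes the example in the `--tasks` help string):

```python
from lighteval.main_accelerate import main
from run_evals_accelerate import get_parser

# Placeholder arguments; swap in a real model and task of your own.
argv = [
    "--model_args", "pretrained=gpt2",
    "--tasks", "original|mmlu:abstract_algebra|5",
    "--output_dir", "./evals",
    "--override_batch_size", "1",
]

parser = get_parser()
args, _unknown = parser.parse_known_args(argv)
final_dict = main(args)  # returns the aggregated results dict, like the CLI entry point

# The nanotron flow is analogous, but takes plain arguments instead of a namespace:
# from lighteval.main_nanotron import main as nanotron_main
# nanotron_main(checkpoint_config_path, lighteval_config_path, cache_dir)
```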