From 0bec0db621c7685a0ea56c009c452cedd187fa7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:21:24 +0100 Subject: [PATCH] Some extra review fixes for the nanotron PR (#20) * add management with env config * use markdown table generator * doc of the s3 cleaning function * fix name --- src/lighteval/main_nanotron.py | 12 ++++++----- src/lighteval/models/nanotron_model.py | 15 +++++++------- src/lighteval/utils.py | 28 ++++++++++++++++++++------ 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 95b897c2e..d8d87b863 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -8,6 +8,7 @@ from lighteval.evaluator import evaluate, make_results_table from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block +from lighteval.models.model_config import EnvConfig from lighteval.models.model_loader import ModelInfo from lighteval.models.nanotron_model import NanotronLightevalModel from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks @@ -35,7 +36,7 @@ @htrack() def main( - local_config_path: str, + checkpoint_config_path: str, lighteval_config_path: Optional[str] = None, cache_dir: str = None, config_cls: Type = Config, @@ -45,16 +46,16 @@ def main( if cache_dir is None: cache_dir = CACHE_DIR - # env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) dist.initialize_torch_distributed() with htrack_block("get config"): - if not local_config_path.endswith(".yaml"): + if not checkpoint_config_path.endswith(".yaml"): raise ValueError("The checkpoint path should point to a YAML file") nanotron_config: config_cls = get_config_from_file( - local_config_path, + checkpoint_config_path, config_class=config_cls, model_config_class=model_config_cls, skip_unused_config_keys=True, @@ -91,7 +92,7 @@ def main( with htrack_block("Model loading"): # We need to load the model in the main process first to avoid downloading the model multiple times model = NanotronLightevalModel( - checkpoint_path=os.path.dirname(local_config_path), + checkpoint_path=os.path.dirname(checkpoint_config_path), model_args=nanotron_config.model, tokenizer=nanotron_config.tokenizer, parallel_context=parallel_context, @@ -101,6 +102,7 @@ def main( cache_dir=os.environ.get("HF_HOME", "/scratch"), debug_one_layer_model=False, model_class=model_cls, + env_config=env_config, ) model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}") evaluation_tracker.general_config_logger.log_model_info(model_info) diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py index e6a3223d7..38b1bd2a2 100644 --- a/src/lighteval/models/nanotron_model.py +++ b/src/lighteval/models/nanotron_model.py @@ -33,6 +33,7 @@ LoglikelihoodSingleTokenDataset, ) from lighteval.models.base_model import LightevalModel +from lighteval.models.model_config import EnvConfig from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn from lighteval.tasks.requests import ( GreedyUntilRequest, @@ -71,9 +72,9 @@ def __init__( add_special_tokens: Optional[bool] = True, dtype: Optional[Union[str, torch.dtype]] = None, trust_remote_code: bool = False, - cache_dir: str = "/scratch", debug_one_layer_model: bool = False, model_class: Optional[Type] = None, + env_config: EnvConfig = None, ): """Initializes a nanotron model for evaluation. Args: @@ -119,7 +120,7 @@ def __init__( self._add_special_tokens = add_special_tokens self._tokenizer = self._create_auto_tokenizer( pretrained=tokenizer.tokenizer_name_or_path, - cache_dir=cache_dir, + env_config=env_config, trust_remote_code=trust_remote_code, ) self._tokenizer.model_max_length = self.max_length @@ -206,7 +207,7 @@ def _create_auto_tokenizer( *, pretrained: str, tokenizer: Optional[str] = None, - cache_dir: str = "/scratch", + env_config: EnvConfig = None, trust_remote_code: bool = False, ) -> transformers.PreTrainedTokenizer: """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration.""" @@ -214,16 +215,16 @@ def _create_auto_tokenizer( try: tokenizer = AutoTokenizer.from_pretrained( pretrained if tokenizer is None else tokenizer, - cache_dir=cache_dir, - token=os.getenv("HUGGING_FACE_HUB_TOKEN"), + cache_dir=env_config.cache_dir, + token=env_config.token, trust_remote_code=trust_remote_code, ) except RecursionError: tokenizer = AutoTokenizer.from_pretrained( pretrained if tokenizer is None else tokenizer, - cache_dir=cache_dir, + cache_dir=env_config.cache_dir, + token=env_config.token, unk_token="", - token=os.getenv("HUGGING_FACE_HUB_TOKEN"), trust_remote_code=trust_remote_code, ) tokenizer.pad_token = tokenizer.eos_token diff --git a/src/lighteval/utils.py b/src/lighteval/utils.py index a4f64e035..c2a9335d5 100644 --- a/src/lighteval/utils.py +++ b/src/lighteval/utils.py @@ -50,26 +50,42 @@ def rec(nest: dict, prefix: str, into: dict): return flat -def clean_s3_links(key, value): +def clean_s3_links(value: str) -> str: + """Cleans and formats s3 bucket links for better display in the result table (nanotron models) + + Args: + value (str): path to clean + + Returns: + str : cleaned path + """ s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1) if not s3_prefix.endswith("/"): s3_prefix += "/" link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}" value = f' {value} ' - return key, value + return value def obj_to_markdown(obj, convert_s3_links: bool = True) -> str: """Convert a (potentially nested) dataclass object or a dict in a readable markdown string for logging""" + from pytablewriter import MarkdownTableWriter + if is_dataclass(obj): obj = asdict(obj) config_dict = flatten_dict(obj) - config_markdown = "| Key | Value |\n| --- | --- |\n" + + md_writer = MarkdownTableWriter() + md_writer.headers = ["Key", "Value"] + + values = [] for key, value in config_dict.items(): if convert_s3_links and "s3://" in str(value): - key, value = clean_s3_links(key, value) - config_markdown += f"| {key} | {value} |\n" - return config_markdown + value = clean_s3_links(value) + values.append([key, value]) + md_writer.value_matrix = values + + return md_writer.dumps() def sanitize_numpy(example_dict: dict) -> dict: