Some extra review fixes for the nanotron PR (#20)
* manage the HF token and cache dir with an env config

* use a markdown table generator

* document the s3 cleaning function

* fix name (local_config_path → checkpoint_config_path)
clefourrier authored Feb 7, 2024
1 parent aab3f81 commit 0bec0db
Showing 3 changed files with 37 additions and 18 deletions.
12 changes: 7 additions & 5 deletions src/lighteval/main_nanotron.py

@@ -8,6 +8,7 @@
 from lighteval.evaluator import evaluate, make_results_table
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
+from lighteval.models.model_config import EnvConfig
 from lighteval.models.model_loader import ModelInfo
 from lighteval.models.nanotron_model import NanotronLightevalModel
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
@@ -35,7 +36,7 @@
 @htrack()
 def main(
-    local_config_path: str,
+    checkpoint_config_path: str,
     lighteval_config_path: Optional[str] = None,
     cache_dir: str = None,
     config_cls: Type = Config,
@@ -45,16 +46,16 @@ def main(
     if cache_dir is None:
         cache_dir = CACHE_DIR

-    # env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
+    env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)

     dist.initialize_torch_distributed()

     with htrack_block("get config"):
-        if not local_config_path.endswith(".yaml"):
+        if not checkpoint_config_path.endswith(".yaml"):
             raise ValueError("The checkpoint path should point to a YAML file")

         nanotron_config: config_cls = get_config_from_file(
-            local_config_path,
+            checkpoint_config_path,
             config_class=config_cls,
             model_config_class=model_config_cls,
             skip_unused_config_keys=True,
@@ -91,7 +92,7 @@ def main(
     with htrack_block("Model loading"):
         # We need to load the model in the main process first to avoid downloading the model multiple times
         model = NanotronLightevalModel(
-            checkpoint_path=os.path.dirname(local_config_path),
+            checkpoint_path=os.path.dirname(checkpoint_config_path),
             model_args=nanotron_config.model,
             tokenizer=nanotron_config.tokenizer,
             parallel_context=parallel_context,
@@ -101,6 +102,7 @@ def main(
             cache_dir=os.environ.get("HF_HOME", "/scratch"),
             debug_one_layer_model=False,
             model_class=model_cls,
+            env_config=env_config,
         )
         model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}")
         evaluation_tracker.general_config_logger.log_model_info(model_info)
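To make the new plumbing concrete, here is a minimal sketch of the pattern this file adopts. The EnvConfig body below is an assumption inferred from the only two attributes this PR reads (token and cache_dir); the real class lives in lighteval.models.model_config.

import os
from dataclasses import dataclass
from typing import Optional


# Assumed shape of lighteval.models.model_config.EnvConfig, inferred from the
# attributes used in this diff (token, cache_dir); not the actual definition.
@dataclass
class EnvConfig:
    token: Optional[str] = None
    cache_dir: Optional[str] = None


# Mirrors what main() now does: build the config once, then pass the single
# object down instead of threading loose token/cache_dir arguments around.
env_config = EnvConfig(
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir=os.environ.get("HF_HOME", "/scratch"),
)
print(env_config)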
15 changes: 8 additions & 7 deletions src/lighteval/models/nanotron_model.py

@@ -33,6 +33,7 @@
     LoglikelihoodSingleTokenDataset,
 )
 from lighteval.models.base_model import LightevalModel
+from lighteval.models.model_config import EnvConfig
 from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -71,9 +72,9 @@ def __init__(
         add_special_tokens: Optional[bool] = True,
         dtype: Optional[Union[str, torch.dtype]] = None,
         trust_remote_code: bool = False,
-        cache_dir: str = "/scratch",
         debug_one_layer_model: bool = False,
         model_class: Optional[Type] = None,
+        env_config: EnvConfig = None,
     ):
         """Initializes a nanotron model for evaluation.
         Args:
@@ -119,7 +120,7 @@ def __init__(
         self._add_special_tokens = add_special_tokens
         self._tokenizer = self._create_auto_tokenizer(
             pretrained=tokenizer.tokenizer_name_or_path,
-            cache_dir=cache_dir,
+            env_config=env_config,
             trust_remote_code=trust_remote_code,
         )
         self._tokenizer.model_max_length = self.max_length
@@ -206,24 +207,24 @@ def _create_auto_tokenizer(
         *,
         pretrained: str,
         tokenizer: Optional[str] = None,
-        cache_dir: str = "/scratch",
+        env_config: EnvConfig = None,
         trust_remote_code: bool = False,
     ) -> transformers.PreTrainedTokenizer:
         """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""

         try:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=cache_dir,
-                token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
+                cache_dir=env_config.cache_dir,
+                token=env_config.token,
                 trust_remote_code=trust_remote_code,
             )
         except RecursionError:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=cache_dir,
+                cache_dir=env_config.cache_dir,
+                token=env_config.token,
                 unk_token="<unk>",
-                token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
                 trust_remote_code=trust_remote_code,
             )
         tokenizer.pad_token = tokenizer.eos_token
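The tokenizer side of the same change, rewritten as a standalone sketch: the cache location and HF token now come from the env config instead of a raw cache_dir string plus os.getenv() calls. The name create_tokenizer is hypothetical, and EnvConfig is the assumed dataclass sketched earlier.

import transformers
from transformers import AutoTokenizer


def create_tokenizer(
    pretrained: str,
    env_config: EnvConfig,
    trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizer:
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained,
            cache_dir=env_config.cache_dir,
            token=env_config.token,
            trust_remote_code=trust_remote_code,
        )
    except RecursionError:
        # Same fallback as the diff: retry with an explicit unk token when
        # loading recurses on tokenizers that define none.
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained,
            cache_dir=env_config.cache_dir,
            token=env_config.token,
            unk_token="<unk>",
            trust_remote_code=trust_remote_code,
        )
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer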
28 changes: 22 additions & 6 deletions src/lighteval/utils.py

@@ -50,26 +50,42 @@ def rec(nest: dict, prefix: str, into: dict):
     return flat


-def clean_s3_links(key, value):
+def clean_s3_links(value: str) -> str:
+    """Cleans and formats s3 bucket links for better display in the result table (nanotron models)
+
+    Args:
+        value (str): path to clean
+
+    Returns:
+        str: cleaned path
+    """
     s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1)
     if not s3_prefix.endswith("/"):
         s3_prefix += "/"
     link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
     value = f'<a href="{link_str}" target="_blank"> {value} </a>'
-    return key, value
+    return value


 def obj_to_markdown(obj, convert_s3_links: bool = True) -> str:
     """Convert a (potentially nested) dataclass object or a dict in a readable markdown string for logging"""
+    from pytablewriter import MarkdownTableWriter
+
     if is_dataclass(obj):
         obj = asdict(obj)
     config_dict = flatten_dict(obj)
-    config_markdown = "| Key | Value |\n| --- | --- |\n"
+
+    md_writer = MarkdownTableWriter()
+    md_writer.headers = ["Key", "Value"]
+
+    values = []
     for key, value in config_dict.items():
         if convert_s3_links and "s3://" in str(value):
-            key, value = clean_s3_links(key, value)
-        config_markdown += f"| {key} | {value} |\n"
-    return config_markdown
+            value = clean_s3_links(value)
+        values.append([key, value])
+    md_writer.value_matrix = values
+
+    return md_writer.dumps()


 def sanitize_numpy(example_dict: dict) -> dict:
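A quick usage check of the new clean_s3_links signature; the bucket path is hypothetical, and the expected output follows directly from the f-strings in the diff above.

from lighteval.utils import clean_s3_links

cleaned = clean_s3_links("s3://my-bucket/checkpoints/run-1")  # hypothetical path
print(cleaned)
# <a href="https://s3.console.aws.amazon.com/s3/buckets/my-bucket?prefix=checkpoints/run-1/" target="_blank"> s3://my-bucket/checkpoints/run-1 </a>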

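And a usage sketch for the reworked obj_to_markdown, with a hypothetical flat dataclass; the exact column padding and alignment are up to pytablewriter's MarkdownTableWriter, so the output shown in the comment is approximate.

from dataclasses import dataclass

from lighteval.utils import obj_to_markdown


@dataclass
class DummyConfig:  # hypothetical config object, for illustration only
    run: str = "my-run"
    step: int = 10


print(obj_to_markdown(DummyConfig()))
# Approximate output (a two-column markdown table):
# |Key |Value |
# |----|------|
# |run |my-run|
# |step|    10|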