From 589e6b0d8b61dbd962b2ce288769bb569bbbf025 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Tue, 27 Feb 2024 11:02:37 +0100 Subject: [PATCH] update (#58) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> --- src/lighteval/logging/evaluation_tracker.py | 13 ++++++++++--- src/lighteval/main_nanotron.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 68ac95f23..f8b7540d2 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -1,3 +1,4 @@ +import copy import json import os import re @@ -22,7 +23,7 @@ if is_nanotron_available(): - from nanotron.config import Config, get_config_from_dict + from nanotron.config import Config class EnhancedJSONEncoder(json.JSONEncoder): @@ -116,8 +117,14 @@ def save( hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}") + config_general = copy.deepcopy(self.general_config_logger) + config_general.config = ( + config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config + ) + config_general = asdict(config_general) + to_dump = { - "config_general": asdict(self.general_config_logger), + "config_general": config_general, "results": self.metrics_logger.metric_aggregated, "versions": self.versions_logger.versions, "config_tasks": self.task_config_logger.tasks_configs, @@ -485,7 +492,7 @@ def push_results_to_tensorboard( # noqa: C901 if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard with having nanotron installed. Skipping") return - config: Config = get_config_from_dict(self.general_config_logger.config, config_class=Config) + config: Config = self.general_config_logger.config lighteval_config = config.lighteval try: global_step = config.general.step diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 8b71640f8..f657f1280 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -80,7 +80,7 @@ def main( override_batch_size=None, max_samples=lighteval_config.tasks.max_samples, job_id=os.environ.get("SLURM_JOB_ID", None), - config=nanotron_config.as_dict(), + config=nanotron_config, ) with htrack_block("Test all gather"):