diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index b52be757b..05c82a4a7 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -21,7 +21,7 @@ raise ImportError(NO_NANOTRON_ERROR_MSG) from nanotron import distributed as dist -from nanotron.config import Config, get_config_from_file +from nanotron.config import Config, LightEvalConfig, get_config_from_file from nanotron.logging import get_logger from nanotron.parallel.context import ParallelContext from nanotron.utils import local_ranks_zero_first @@ -63,10 +63,7 @@ def main( ) if lighteval_config_path: - lighteval_nanotron_config: config_cls = get_config_from_file( - lighteval_config_path, config_class=config_cls - ) - lighteval_config = lighteval_nanotron_config.lighteval + lighteval_config: config_cls = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) nanotron_config.lighteval = lighteval_config else: lighteval_config = nanotron_config.lighteval