Skip to content

Commit

Permalink
update (#58)
Browse files Browse the repository at this point in the history
Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
thomwolf and clefourrier authored Feb 27, 2024
1 parent 6a3e3b9 commit 589e6b0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
import os
import re
Expand All @@ -22,7 +23,7 @@


if is_nanotron_available():
from nanotron.config import Config, get_config_from_dict
from nanotron.config import Config


class EnhancedJSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -116,8 +117,14 @@ def save(

hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}")

config_general = copy.deepcopy(self.general_config_logger)
config_general.config = (
config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config
)
config_general = asdict(config_general)

to_dump = {
"config_general": asdict(self.general_config_logger),
"config_general": config_general,
"results": self.metrics_logger.metric_aggregated,
"versions": self.versions_logger.versions,
"config_tasks": self.task_config_logger.tasks_configs,
Expand Down Expand Up @@ -485,7 +492,7 @@ def push_results_to_tensorboard( # noqa: C901
if not is_nanotron_available():
hlog_warn("You cannot push results to tensorboard with having nanotron installed. Skipping")
return
config: Config = get_config_from_dict(self.general_config_logger.config, config_class=Config)
config: Config = self.general_config_logger.config
lighteval_config = config.lighteval
try:
global_step = config.general.step
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/main_nanotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main(
override_batch_size=None,
max_samples=lighteval_config.tasks.max_samples,
job_id=os.environ.get("SLURM_JOB_ID", None),
config=nanotron_config.as_dict(),
config=nanotron_config,
)

with htrack_block("Test all gather"):
Expand Down

0 comments on commit 589e6b0

Please sign in to comment.