From 952819b63c4f0c40c2108691ee7a105def016bc4 Mon Sep 17 00:00:00 2001
From: Pete
Date: Tue, 11 Jul 2023 15:46:37 -0700
Subject: [PATCH] Don't reshuffle eval data each "epoch" (#229)

---
 olmo/data/__init__.py         |  7 ++++---
 olmo/eval/__init__.py         |  4 +---
 olmo/eval/evaluator.py        |  3 +--
 olmo/train.py                 | 22 ++++++++++++++--------
 olmo/util.py                  | 13 +------------
 tests/eval/downstream_test.py |  6 +-----
 6 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py
index abf18e568..df367c6a9 100644
--- a/olmo/data/__init__.py
+++ b/olmo/data/__init__.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -15,14 +15,15 @@
 
 def build_memmap_dataset(train_config: TrainConfig, data_config: DataConfig) -> MemMapDataset:
     paths: List[str]
-    metadata: Optional[List[Dict[str, Any]]] = None
+    metadata: List[Dict[str, Any]] = []
     if data_config.paths:
         if data_config.datasets:
             raise OlmoConfigurationError("DataConfig.paths is mutually exclusive with DataConfig.datasets")
         paths = data_config.paths
+        for path in paths:
+            metadata.append({"path": str(path)})
     elif data_config.datasets:
         paths = []
-        metadata = []
         for label in sorted(data_config.datasets.keys()):
             label_paths = data_config.datasets[label]
             paths.extend(label_paths)
diff --git a/olmo/eval/__init__.py b/olmo/eval/__init__.py
index 96dc0e5eb..3520f8c31 100644
--- a/olmo/eval/__init__.py
+++ b/olmo/eval/__init__.py
@@ -7,7 +7,7 @@
 from ..config import EvaluatorConfig, EvaluatorType, TrainConfig
 from ..exceptions import OlmoConfigurationError
 from ..tokenizer import Tokenizer
-from ..util import cycle_through_epochs, get_global_rank, get_world_size
+from ..util import get_global_rank, get_world_size
 from .downstream import ICLMetric, label_to_task_map
 from .evaluator import Evaluator
@@ -59,7 +59,6 @@ def build_downstream_evaluator(
         label=eval_cfg.label,
         type=eval_cfg.type,
         eval_loader=ds_eval_dataloader,
-        eval_batches=cycle_through_epochs(ds_eval_dataloader),
         eval_metric=metric.to(device),
         subset_num_batches=eval_cfg.subset_num_batches,
     )
@@ -97,7 +96,6 @@ def make_metric():
             label=eval_config.label,
             type=eval_config.type,
             eval_loader=eval_loader,
-            eval_batches=cycle_through_epochs(eval_loader),
             eval_metric=eval_metric,
             subset_num_batches=eval_config.subset_num_batches,
         )
diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py
index e7fe8c17c..ddc85a603 100644
--- a/olmo/eval/evaluator.py
+++ b/olmo/eval/evaluator.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.utils.data import DataLoader
@@ -16,7 +16,6 @@ class Evaluator:
     label: str
     type: EvaluatorType
     eval_loader: DataLoader
-    eval_batches: Iterator[Dict[str, Any]]
     eval_metric: Union[Metric, Dict[str, Metric]]
     subset_num_batches: Optional[int] = None
 
diff --git a/olmo/train.py b/olmo/train.py
index 1431d38d2..e7eb93455 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -677,17 +677,21 @@ def eval(self) -> Dict[str, Any]:
             # Reset metrics.
             evaluator.reset_metrics()
 
-            # Check how many batches to evaluate on.
-            num_eval_batches = evaluator.subset_num_batches
-            if num_eval_batches is None:
-                num_eval_batches = self.cfg.eval_subset_num_batches
-            if num_eval_batches <= 0:
-                num_eval_batches = max(1, len(evaluator.eval_loader))
-            else:
+            # Initialize data loader iterator.
+            eval_batches = iter(evaluator.eval_loader)
+
+            # Adjust how many batches to evaluate on.
+            num_eval_batches = (
+                evaluator.subset_num_batches
+                if evaluator.subset_num_batches is not None
+                else self.cfg.eval_subset_num_batches
+            )
+            if num_eval_batches > 0:
                 num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
+                eval_batches = islice(eval_batches, num_eval_batches)
 
             # Run model over batches.
-            for eval_step, eval_batch in enumerate(islice(evaluator.eval_batches, num_eval_batches)):
+            for eval_step, eval_batch in enumerate(eval_batches):
                 self.eval_step(eval_batch, evaluator)
 
                 # Log to console.
@@ -699,6 +703,8 @@ def eval(self) -> Dict[str, Any]:
             eval_metrics.update(metrics)
             self.log_metrics_to_console(f"{evaluator.label}", metrics)
 
+            del eval_batches
+
         return eval_metrics
 
     def fit(self):
diff --git a/olmo/util.py b/olmo/util.py
index b4ce920e0..3f33e66c9 100644
--- a/olmo/util.py
+++ b/olmo/util.py
@@ -4,7 +4,7 @@
 import sys
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Generator, Optional, TypeVar, Union
+from typing import Any, Dict, Optional, TypeVar, Union
 
 import rich
 import torch
@@ -13,7 +13,6 @@
 from rich.highlighter import NullHighlighter
 from rich.text import Text
 from rich.traceback import Traceback
-from torch.utils.data import DataLoader, DistributedSampler
 
 from .config import LogFilterType
 from .exceptions import OlmoCliError, OlmoError
@@ -333,16 +332,6 @@ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
     return peak_mb
 
 
-def cycle_through_epochs(dataloader: DataLoader) -> Generator[Dict[str, Any], None, None]:
-    while True:
-        for batch in dataloader:
-            yield batch
-
-        if isinstance(dataloader.sampler, DistributedSampler):
-            epoch = dataloader.sampler.epoch + 1
-            dataloader.sampler.set_epoch(epoch)
-
-
 def syncronize_flag(flag: bool, device: torch.device) -> bool:
     if dist.is_available() and dist.is_initialized():
         flag_tensor = torch.tensor(flag, device=device)
diff --git a/tests/eval/downstream_test.py b/tests/eval/downstream_test.py
index 60e23d3e5..c88f84c19 100644
--- a/tests/eval/downstream_test.py
+++ b/tests/eval/downstream_test.py
@@ -14,10 +14,6 @@ def test_piqa():
         cfg, cfg.evaluators[1], tokenizer, torch.device("cpu"), is_unit_test=True
     )
     logits = torch.rand(4, 57, 50304)
-    first_batch = next(evaluator.eval_batches)
+    first_batch = next(iter(evaluator.eval_loader))
     evaluator.reset_metrics()
     evaluator.update_metrics(first_batch, logits.sum(), logits)
-
-
-if __name__ == "__main__":
-    test_piqa()
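
Note (illustration only, not part of the patch): the removed cycle_through_epochs helper called DistributedSampler.set_epoch() after every full pass, and set_epoch() reseeds the sampler's shuffle, so each evaluation could see a different batch order whenever the eval sampler was built with shuffle enabled. The standalone sketch below shows that effect using only toy data and standard torch.utils.data APIs; the batch_order helper is made up for the demo, and shuffle=True on the eval sampler is an assumption.

# Standalone sketch: why bumping the sampler epoch reshuffled eval data.
# Assumes the eval sampler was created with shuffle=True.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# Explicit num_replicas/rank so no distributed process group is required.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)


def batch_order() -> list:
    # Flatten one full pass over the loader into a list of sample ids.
    return [int(x) for (batch,) in loader for x in batch]


sampler.set_epoch(0)
first = batch_order()
sampler.set_epoch(1)  # what cycle_through_epochs did after each pass
second = batch_order()
print(first, second)  # orders will differ: the shuffle is seeded with seed + epoch

With this patch, Trainer.eval() builds a fresh iter(evaluator.eval_loader) for each evaluation and never touches set_epoch on the eval sampler, so every evaluation sees the batches in the same order.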