Don't reshuffle eval data each "epoch" (#229)
epwalsh authored Jul 11, 2023
1 parent 87f6a79 commit 952819b
Showing 6 changed files with 22 additions and 33 deletions.
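In short: the trainer used to pull evaluation batches from a persistent cycle_through_epochs generator, which called DistributedSampler.set_epoch every time the eval loader was exhausted; when the sampler shuffles, that reseeds the shuffle, so successive evaluations could see the eval data in a different order (and, with a batch subset configured, a different subset). After this change the trainer builds a fresh iterator over evaluator.eval_loader on every evaluation, so the sampler's epoch, and therefore the data order, stays fixed. The core of the change in olmo/train.py:

    # Before: batches came from a persistent cycling iterator shared across evaluations.
    for eval_step, eval_batch in enumerate(islice(evaluator.eval_batches, num_eval_batches)):
        self.eval_step(eval_batch, evaluator)

    # After: a fresh iterator per evaluation, optionally capped with islice.
    eval_batches = iter(evaluator.eval_loader)
    if num_eval_batches > 0:
        eval_batches = islice(eval_batches, num_eval_batches)
    for eval_step, eval_batch in enumerate(eval_batches):
        self.eval_step(eval_batch, evaluator)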
7 changes: 4 additions & 3 deletions olmo/data/__init__.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from torch.utils.data import DataLoader, DistributedSampler

@@ -15,14 +15,15 @@
 
 def build_memmap_dataset(train_config: TrainConfig, data_config: DataConfig) -> MemMapDataset:
     paths: List[str]
-    metadata: Optional[List[Dict[str, Any]]] = None
+    metadata: List[Dict[str, Any]] = []
     if data_config.paths:
         if data_config.datasets:
             raise OlmoConfigurationError("DataConfig.paths is mutually exclusive with DataConfig.datasets")
         paths = data_config.paths
+        for path in paths:
+            metadata.append({"path": str(path)})
     elif data_config.datasets:
         paths = []
-        metadata = []
         for label in sorted(data_config.datasets.keys()):
             label_paths = data_config.datasets[label]
             paths.extend(label_paths)
4 changes: 1 addition & 3 deletions olmo/eval/__init__.py
@@ -7,7 +7,7 @@
 from ..config import EvaluatorConfig, EvaluatorType, TrainConfig
 from ..exceptions import OlmoConfigurationError
 from ..tokenizer import Tokenizer
-from ..util import cycle_through_epochs, get_global_rank, get_world_size
+from ..util import get_global_rank, get_world_size
 from .downstream import ICLMetric, label_to_task_map
 from .evaluator import Evaluator

@@ -59,7 +59,6 @@ def build_downstream_evaluator(
         label=eval_cfg.label,
         type=eval_cfg.type,
         eval_loader=ds_eval_dataloader,
-        eval_batches=cycle_through_epochs(ds_eval_dataloader),
         eval_metric=metric.to(device),
         subset_num_batches=eval_cfg.subset_num_batches,
     )
@@ -97,7 +96,6 @@ def make_metric():
             label=eval_config.label,
             type=eval_config.type,
             eval_loader=eval_loader,
-            eval_batches=cycle_through_epochs(eval_loader),
             eval_metric=eval_metric,
             subset_num_batches=eval_config.subset_num_batches,
         )
3 changes: 1 addition & 2 deletions olmo/eval/evaluator.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.utils.data import DataLoader
@@ -16,7 +16,6 @@ class Evaluator:
     label: str
     type: EvaluatorType
     eval_loader: DataLoader
-    eval_batches: Iterator[Dict[str, Any]]
     eval_metric: Union[Metric, Dict[str, Metric]]
     subset_num_batches: Optional[int] = None

22 changes: 14 additions & 8 deletions olmo/train.py
@@ -677,17 +677,21 @@ def eval(self) -> Dict[str, Any]:
             # Reset metrics.
             evaluator.reset_metrics()
 
-            # Check how many batches to evaluate on.
-            num_eval_batches = evaluator.subset_num_batches
-            if num_eval_batches is None:
-                num_eval_batches = self.cfg.eval_subset_num_batches
-            if num_eval_batches <= 0:
-                num_eval_batches = max(1, len(evaluator.eval_loader))
-            else:
+            # Initialize data loader iterator.
+            eval_batches = iter(evaluator.eval_loader)
+
+            # Adjust how many batches to evaluate on.
+            num_eval_batches = (
+                evaluator.subset_num_batches
+                if evaluator.subset_num_batches is not None
+                else self.cfg.eval_subset_num_batches
+            )
+            if num_eval_batches > 0:
                 num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
+                eval_batches = islice(eval_batches, num_eval_batches)
 
             # Run model over batches.
-            for eval_step, eval_batch in enumerate(islice(evaluator.eval_batches, num_eval_batches)):
+            for eval_step, eval_batch in enumerate(eval_batches):
                 self.eval_step(eval_batch, evaluator)
 
             # Log to console.
@@ -699,6 +703,8 @@ def eval(self) -> Dict[str, Any]:
             eval_metrics.update(metrics)
             self.log_metrics_to_console(f"{evaluator.label}", metrics)
 
+            del eval_batches
+
         return eval_metrics
 
     def fit(self):
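A self-contained sketch of the new control flow above, using a toy dataset and made-up config values (only the fresh-iterator/islice pattern mirrors the diff; real OLMo batches are dicts, not tensors):

    from itertools import islice

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-ins for evaluator.eval_loader and the two config knobs.
    eval_loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=2)
    subset_num_batches = 3          # evaluator.subset_num_batches; None falls back to the global value
    global_subset_num_batches = -1  # self.cfg.eval_subset_num_batches; non-positive means "use every batch"

    # Fresh iterator on every evaluation, so a DistributedSampler's epoch is never advanced.
    eval_batches = iter(eval_loader)
    num_eval_batches = (
        subset_num_batches if subset_num_batches is not None else global_subset_num_batches
    )
    if num_eval_batches > 0:
        num_eval_batches = min(num_eval_batches, len(eval_loader))
        eval_batches = islice(eval_batches, num_eval_batches)

    for eval_step, (eval_batch,) in enumerate(eval_batches):
        print(eval_step, eval_batch)  # stands in for self.eval_step(eval_batch, evaluator)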
13 changes: 1 addition & 12 deletions olmo/util.py
@@ -4,7 +4,7 @@
 import sys
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Generator, Optional, TypeVar, Union
+from typing import Any, Dict, Optional, TypeVar, Union
 
 import rich
 import torch
@@ -13,7 +13,6 @@
 from rich.highlighter import NullHighlighter
 from rich.text import Text
 from rich.traceback import Traceback
-from torch.utils.data import DataLoader, DistributedSampler
 
 from .config import LogFilterType
 from .exceptions import OlmoCliError, OlmoError
@@ -333,16 +332,6 @@ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
     return peak_mb
 
 
-def cycle_through_epochs(dataloader: DataLoader) -> Generator[Dict[str, Any], None, None]:
-    while True:
-        for batch in dataloader:
-            yield batch
-
-        if isinstance(dataloader.sampler, DistributedSampler):
-            epoch = dataloader.sampler.epoch + 1
-            dataloader.sampler.set_epoch(epoch)
-
-
 def syncronize_flag(flag: bool, device: torch.device) -> bool:
     if dist.is_available() and dist.is_initialized():
         flag_tensor = torch.tensor(flag, device=device)
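For context on why the removed helper reshuffled data: DistributedSampler folds its epoch into the shuffle seed, so the set_epoch call at the end of each pass changed the order of indices the sampler yields. A minimal, self-contained demonstration (toy dataset; num_replicas/rank are passed explicitly so no process group is needed):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(16))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)

    sampler.set_epoch(0)
    order_epoch_0 = list(sampler)
    sampler.set_epoch(1)
    order_epoch_1 = list(sampler)

    # The two permutations differ, which is exactly the reshuffling the eval loop
    # saw between evaluations before this commit.
    print(order_epoch_0)
    print(order_epoch_1)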
6 changes: 1 addition & 5 deletions tests/eval/downstream_test.py
@@ -14,10 +14,6 @@ def test_piqa():
         cfg, cfg.evaluators[1], tokenizer, torch.device("cpu"), is_unit_test=True
     )
     logits = torch.rand(4, 57, 50304)
-    first_batch = next(evaluator.eval_batches)
+    first_batch = next(iter(evaluator.eval_loader))
     evaluator.reset_metrics()
     evaluator.update_metrics(first_batch, logits.sum(), logits)
 
 
 if __name__ == "__main__":
     test_piqa()
