Don't reshuffle eval data each "epoch" #229

Merged: 4 commits, Jul 11, 2023
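The change removes the `cycle_through_epochs` helper (see olmo/util.py below), which advanced the sampler epoch after every full pass over an eval loader; a shuffling `DistributedSampler` reorders its indices whenever its epoch changes, so eval batches were being reshuffled between evaluations. A minimal, self-contained sketch of that reshuffling (the toy dataset, batch size, and sampler arguments are illustrative assumptions, not taken from the repo):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy stand-in for an eval dataset; num_replicas/rank are passed explicitly so
# no distributed process group is required.
dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

print([batch[0].tolist() for batch in loader])  # order for epoch 0
sampler.set_epoch(1)  # what cycle_through_epochs did after each full pass
print([batch[0].tolist() for batch in loader])  # typically a different order

With the helper gone, the trainer simply builds a fresh iterator over the eval loader for each evaluation (see olmo/train.py below), so the evaluation order stays fixed.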
olmo/data/__init__.py (7 changes: 4 additions & 3 deletions)
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 from torch.utils.data import DataLoader, DistributedSampler

@@ -15,14 +15,15 @@

 def build_memmap_dataset(train_config: TrainConfig, data_config: DataConfig) -> MemMapDataset:
     paths: List[str]
-    metadata: Optional[List[Dict[str, Any]]] = None
+    metadata: List[Dict[str, Any]] = []
     if data_config.paths:
         if data_config.datasets:
             raise OlmoConfigurationError("DataConfig.paths is mutually exclusive with DataConfig.datasets")
         paths = data_config.paths
+        for path in paths:
+            metadata.append({"path": str(path)})
     elif data_config.datasets:
         paths = []
-        metadata = []
         for label in sorted(data_config.datasets.keys()):
             label_paths = data_config.datasets[label]
             paths.extend(label_paths)
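The net effect of the hunk above is that `metadata` is always a list, and the `paths` branch now records an entry per path instead of leaving `metadata` as `None`. A small sketch of the resulting entries for a paths-only config (the file paths are made up):

from typing import Any, Dict, List

paths = ["/data/part-000.npy", "/data/part-001.npy"]  # hypothetical DataConfig.paths

metadata: List[Dict[str, Any]] = []  # previously Optional[...] and left as None in this branch
for path in paths:
    metadata.append({"path": str(path)})

print(metadata)  # [{'path': '/data/part-000.npy'}, {'path': '/data/part-001.npy'}]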
olmo/eval/__init__.py (4 changes: 1 addition & 3 deletions)
@@ -7,7 +7,7 @@
 from ..config import EvaluatorConfig, EvaluatorType, TrainConfig
 from ..exceptions import OlmoConfigurationError
 from ..tokenizer import Tokenizer
-from ..util import cycle_through_epochs, get_global_rank, get_world_size
+from ..util import get_global_rank, get_world_size
 from .downstream import ICLMetric, label_to_task_map
 from .evaluator import Evaluator

@@ -59,7 +59,6 @@ def build_downstream_evaluator(
         label=eval_cfg.label,
         type=eval_cfg.type,
         eval_loader=ds_eval_dataloader,
-        eval_batches=cycle_through_epochs(ds_eval_dataloader),
         eval_metric=metric.to(device),
         subset_num_batches=eval_cfg.subset_num_batches,
     )
@@ -97,7 +96,6 @@ def make_metric():
         label=eval_config.label,
         type=eval_config.type,
         eval_loader=eval_loader,
-        eval_batches=cycle_through_epochs(eval_loader),
         eval_metric=eval_metric,
         subset_num_batches=eval_config.subset_num_batches,
     )
olmo/eval/evaluator.py (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, Optional, Union
+from typing import Any, Dict, Optional, Union

 import torch
 from torch.utils.data import DataLoader
@@ -16,7 +16,6 @@ class Evaluator:
     label: str
     type: EvaluatorType
     eval_loader: DataLoader
-    eval_batches: Iterator[Dict[str, Any]]
     eval_metric: Union[Metric, Dict[str, Metric]]
     subset_num_batches: Optional[int] = None

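With the `eval_batches` field gone from `Evaluator`, callers take batches straight from `eval_loader`, as the updated test at the end of this diff does with `next(iter(evaluator.eval_loader))`. A rough, runnable sketch with a toy loader standing in for the real one (which yields dict batches):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for Evaluator.eval_loader.
eval_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)

# Build a fresh iterator per use; nothing advances a sampler epoch, so the
# same batches come back in the same order every time.
first_batch = next(iter(eval_loader))
print(first_batch)  # [tensor([0, 1])]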
olmo/train.py (22 changes: 14 additions & 8 deletions)
@@ -677,17 +677,21 @@ def eval(self) -> Dict[str, Any]:
             # Reset metrics.
             evaluator.reset_metrics()

-            # Check how many batches to evaluate on.
-            num_eval_batches = evaluator.subset_num_batches
-            if num_eval_batches is None:
-                num_eval_batches = self.cfg.eval_subset_num_batches
-            if num_eval_batches <= 0:
-                num_eval_batches = max(1, len(evaluator.eval_loader))
-            else:
+            # Initialize data loader iterator.
+            eval_batches = iter(evaluator.eval_loader)
+
+            # Adjust how many batches to evaluate on.
+            num_eval_batches = (
+                evaluator.subset_num_batches
+                if evaluator.subset_num_batches is not None
+                else self.cfg.eval_subset_num_batches
+            )
+            if num_eval_batches > 0:
                 num_eval_batches = min(num_eval_batches, len(evaluator.eval_loader))
+                eval_batches = islice(eval_batches, num_eval_batches)

             # Run model over batches.
-            for eval_step, eval_batch in enumerate(islice(evaluator.eval_batches, num_eval_batches)):
+            for eval_step, eval_batch in enumerate(eval_batches):
                 self.eval_step(eval_batch, evaluator)

                 # Log to console.
@@ -699,6 +703,8 @@ def eval(self) -> Dict[str, Any]:
             eval_metrics.update(metrics)
             self.log_metrics_to_console(f"{evaluator.label}", metrics)

+            del eval_batches
+
         return eval_metrics

     def fit(self):
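The new eval loop builds a fresh iterator over the loader on every call and, when a positive batch budget is configured, trims it with `itertools.islice`. A self-contained sketch of that pattern (the toy loader and budget are assumptions):

from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for evaluator.eval_loader and the subset_num_batches setting.
eval_loader = DataLoader(TensorDataset(torch.arange(16)), batch_size=4)
num_eval_batches = 2  # a non-positive value would mean "use the whole loader"

eval_batches = iter(eval_loader)
if num_eval_batches > 0:
    num_eval_batches = min(num_eval_batches, len(eval_loader))
    eval_batches = islice(eval_batches, num_eval_batches)

for eval_step, (eval_batch,) in enumerate(eval_batches):
    print(eval_step, eval_batch.tolist())  # stops after two batches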
olmo/util.py (13 changes: 1 addition & 12 deletions)
@@ -4,7 +4,7 @@
 import sys
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Generator, Optional, TypeVar, Union
+from typing import Any, Dict, Optional, TypeVar, Union

 import rich
 import torch
@@ -13,7 +13,6 @@
 from rich.highlighter import NullHighlighter
 from rich.text import Text
 from rich.traceback import Traceback
-from torch.utils.data import DataLoader, DistributedSampler

 from .config import LogFilterType
 from .exceptions import OlmoCliError, OlmoError
@@ -333,16 +332,6 @@ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
     return peak_mb


-def cycle_through_epochs(dataloader: DataLoader) -> Generator[Dict[str, Any], None, None]:
-    while True:
-        for batch in dataloader:
-            yield batch
-
-        if isinstance(dataloader.sampler, DistributedSampler):
-            epoch = dataloader.sampler.epoch + 1
-            dataloader.sampler.set_epoch(epoch)
-
-
 def syncronize_flag(flag: bool, device: torch.device) -> bool:
     if dist.is_available() and dist.is_initialized():
         flag_tensor = torch.tensor(flag, device=device)
tests/eval/downstream_test.py (6 changes: 1 addition & 5 deletions)
@@ -14,10 +14,6 @@ def test_piqa():
         cfg, cfg.evaluators[1], tokenizer, torch.device("cpu"), is_unit_test=True
     )
     logits = torch.rand(4, 57, 50304)
-    first_batch = next(evaluator.eval_batches)
+    first_batch = next(iter(evaluator.eval_loader))
     evaluator.reset_metrics()
     evaluator.update_metrics(first_batch, logits.sum(), logits)
-
-
-if __name__ == "__main__":
-    test_piqa()