fix: refactor call to not use accelerate log
NanoCode012 committed Jan 10, 2025 · 1 parent 9f54b8a · commit 0eaa579
Showing 3 changed files with 43 additions and 37 deletions.
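
Context for the change (not part of the diff): drop_long_seq and drop_long_seq_in_dataset move out of axolotl.utils.trainer, whose LOG object supports accelerate-style keyword arguments such as main_process_only, and into axolotl.utils.data.utils, which uses a plain standard-library logger, so those keyword arguments are dropped from the debug calls. A minimal sketch of the difference, assuming trainer.py's logger comes from accelerate.logging.get_logger as the removed kwarg and the commit title imply:

```python
# Sketch only, not part of the diff. Assumes trainer.py's LOG comes from
# accelerate.logging.get_logger, which is what the removed main_process_only
# keyword argument (and the commit title) imply.
import logging

from accelerate import PartialState
from accelerate.logging import get_logger

PartialState()  # accelerate's logger requires the distributed state to be initialized

accelerate_log = get_logger("axolotl")    # MultiProcessAdapter around a stdlib logger
stdlib_log = logging.getLogger(__name__)  # plain logging.Logger

# The adapter accepts main_process_only so only rank 0 emits the record...
accelerate_log.debug("max_input_len: 2048", main_process_only=True)

# ...while a stdlib logger takes no such kwarg, which is why the relocated
# helpers in data/utils.py drop it from their LOG.debug calls.
stdlib_log.debug("max_input_len: 2048")
```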
src/axolotl/utils/data/sft.py (2 changes: 1 addition & 1 deletion)

@@ -43,14 +43,14 @@
 from axolotl.utils.data.shared import load_dataset_w_config
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
+    drop_long_seq_in_dataset,
     md5,
     retry_on_request_exceptions,
 )
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_local_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
-    drop_long_seq_in_dataset,
     process_datasets_for_packing,
 )

src/axolotl/utils/data/utils.py (42 changes: 41 additions & 1 deletion)

@@ -1,15 +1,20 @@
"""data handling helpers"""

import functools
import hashlib
import logging
import time
from enum import Enum

import huggingface_hub
import numpy as np
import requests
from datasets import Dataset

LOG = logging.getLogger("axolotl")
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths

LOG = logging.getLogger(__name__)


class RetryStrategy(Enum):
@@ -150,3 +155,38 @@ def deduplicate_and_log_datasets(
         )
 
     return train_dataset, eval_dataset, dataset
+
+
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=None):
+    min_sequence_len = min_sequence_len or 2
+
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )
+
+
+def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
+    drop_long = functools.partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len,
+    )
+
+    min_input_len = np.min(get_dataset_lengths(dataset))
+    LOG.debug(f"min_input_len: {min_input_len}")
+    max_input_len = np.max(get_dataset_lengths(dataset))
+    LOG.debug(f"max_input_len: {max_input_len}")
+
+    prior_len = len(dataset)
+    dataset = dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Long Sequences",
+    )
+    dropped = prior_len - len(dataset)
+    if dropped:
+        LOG.warning(f"Dropped {dropped} long samples from dataset")
+
+    return dataset
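
For reference, a quick illustration of the relocated predicate (not part of the commit): a sample is kept only when its input_ids length is at most sequence_len and at least min_sequence_len, which defaults to 2. The lengths below are arbitrary:

```python
# Illustrative only; exercises drop_long_seq as added above.
from axolotl.utils.data.utils import drop_long_seq

print(drop_long_seq({"input_ids": [1] * 100}, sequence_len=2048))   # True: within bounds
print(drop_long_seq({"input_ids": [1] * 3000}, sequence_len=2048))  # False: exceeds sequence_len
print(drop_long_seq({"input_ids": [1]}, sequence_len=2048))         # False: below the default minimum of 2
```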
src/axolotl/utils/trainer.py (36 changes: 1 addition & 35 deletions)

@@ -17,6 +17,7 @@
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.utils.data.utils import drop_long_seq
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -172,41 +173,6 @@ def add_length(sample):
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048, min_sequence_len=None):
-    min_sequence_len = min_sequence_len or 2
-
-    return (
-        len(sample["input_ids"]) <= sequence_len
-        and len(sample["input_ids"]) >= min_sequence_len
-    )
-
-
-def drop_long_seq_in_dataset(dataset, cfg):
-    drop_long = partial(
-        drop_long_seq,
-        sequence_len=cfg.sequence_len,
-        min_sequence_len=cfg.min_sample_len,
-    )
-
-    min_input_len = np.min(get_dataset_lengths(dataset))
-    LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-    max_input_len = np.max(get_dataset_lengths(dataset))
-    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-
-    prior_len = len(dataset)
-    dataset = dataset.filter(
-        drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences",
-    )
-    dropped = prior_len - len(dataset)
-    if dropped:
-        LOG.warning(f"Dropped {dropped} long samples from dataset")
-
-    return dataset
-
-
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
