diff --git a/README.md b/README.md
index dde1f8e68..48a37ba72 100644
--- a/README.md
+++ b/README.md
@@ -10,3 +10,27 @@
 ```
 pip install ai2-olmo
 ```
+
+## Fine-tuning
+
+To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
+
+Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line:
+
+- Update `load_path` to point to the checkpoint you want to start from.
+- Set `reset_trainer_state` to `true`.
+- Update `data.paths` to point to the `input_ids.npy` file you generated.
+- Optionally update `data.label_mask_paths` to point to the `label_mask.npy` file you generated, if you need special masking for the loss.
+- Update `evaluators` to add/remove in-loop evaluations.
+
+Once you're satisfied with your training config, you can launch the training job via `torchrun`. For example:
+
+```
+torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
+  --data.paths=[{path_to_data}/input_ids.npy] \
+  --data.label_mask_paths=[{path_to_data}/label_mask.npy] \
+  --load_path={path_to_checkpoint} \
+  --reset_trainer_state
+```
+
+Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config.
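The data-preparation step described in the README text above boils down to writing one long, flat array of token IDs (plus, optionally, a parallel boolean label mask) to disk. A minimal sketch of that on-disk format, assuming you already have fixed-length tokenized instances; the file names and dtypes mirror what `scripts/prepare_tulu_data.py` writes, but everything else here is illustrative:

```
import numpy as np

# Two toy instances, already tokenized and padded to the same length.
instances = [[3, 4, 5, 1], [6, 7, 8, 1]]
masks = [[False, True, True, False], [False, True, True, False]]
total_tokens = sum(len(x) for x in instances)

# Token IDs and label masks are stored as flat arrays, one instance after another.
input_ids = np.memmap("input_ids.npy", dtype=np.uint16, mode="w+", shape=(total_tokens,))
label_mask = np.memmap("label_mask.npy", dtype=np.bool_, mode="w+", shape=(total_tokens,))
offset = 0
for ids, mask in zip(instances, masks):
    input_ids[offset : offset + len(ids)] = ids
    label_mask[offset : offset + len(ids)] = mask
    offset += len(ids)
input_ids.flush()
label_mask.flush()
```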
diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 66c8d3bd7..59ee0c270 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -1,17 +1,21 @@
-run_name: olmo-7b-instruct
+name: olmo-7b-instruct
 image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
-gpu_num: 64
-#gpu_num: 8
-#cluster: r12z3
-cluster: r7z2
-gpu_type: a100_40gb
+compute:
+  #cluster: r12z3
+  cluster: r7z2
+  gpus: 64
+  gpu_type: a100_40gb
 integrations:
   - integration_type: git_repo
     git_repo: allenai/LLM
-    git_branch: epwalsh/tulu-fine-tune
+    git_branch: main
     pip_install: -e .
     ssh_clone: true
 command: |-
+  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
+  learning_rate=2e-6
+  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-5ep-v2
+
   # NOTE: For some reason getting S3 and R2 authentication working both from the command line and
   # from Python proved to be challenging, maybe because Mosaic's server are in Australia.
   # In the end I had to use separate methods to get everything working:
...
@@ -34,7 +38,6 @@ command: |-
   # Prepare environment including AWS config files for both S3 and R2 access.
   mkdir -p /root/.cache/torch
   mkdir /root/checkpoint-unsharded
-  mkdir /root/data
   mkdir /root/.aws
   touch /root/.aws/credentials /root/.aws/config
   echo '[s3]' >> /root/.aws/credentials
@@ -54,7 +57,6 @@ command: |-
   export LOG_FILTER_TYPE=local_rank0_only

   # Download checkpoint (everything except optimizer state).
-  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
   echo "Downloading checkpoint '${checkpoint}'..."

   # Download config.
@@ -72,15 +74,14 @@ command: |-
     --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
     "${checkpoint}/model.pt" /root/checkpoint-unsharded/

+  # Download optimizer state.
+  #aws s3 cp --profile=r2 --region=auto \
+  #  --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+  #  "${checkpoint}/optim.pt" /root/checkpoint-unsharded/
+
   # Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
   rm -rf /root/.aws

-  # Download data (it's small enough so might as well).
-  echo "Downloading data..."
-  aws s3 cp \
-    s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
-    /root/data/data.npy
-
   torchrun \
     --master_addr "$MASTER_ADDR" \
     --master_port "$MASTER_PORT" \
@@ -88,15 +89,15 @@ command: |-
     --node_rank "$NODE_RANK" \
     --nproc_per_node 8 \
     scripts/train.py configs/mitchish-instruct.yaml \
-      --run_name=mitchish-mcli-2.5T-instruct-2e-6 \
-      --optimizer.learning_rate=2e-6 \
+      --run_name=${run_name} \
+      --optimizer.learning_rate=${learning_rate} \
+      --scheduler.grad_clip_warmup_steps=400 \
       --save_overwrite \
-      --time_limit=169200 \
-      --data.paths=[/root/data/data.npy] \
-      --save_interval_unsharded=10000 \
+      --save_interval_unsharded=100000 \
       --load_path=/root/checkpoint-unsharded \
-      --reset_optimizer_state \
       --reset_trainer_state \
+      --reset_optimizer_state \
       --compile=null \
-      --activation_checkpointing=fine_grained \
-      --fsdp.wrapping_strategy=size_based
+      --activation_checkpointing=whole_layer \
+      --fsdp.wrapping_strategy=size_based \
+      --max_duration=5ep
diff --git a/configs/mitchish-instruct.yaml b/configs/mitchish-instruct.yaml
index a21f1ade6..ad247e7bc 100644
--- a/configs/mitchish-instruct.yaml
+++ b/configs/mitchish-instruct.yaml
@@ -43,15 +43,15 @@ compile:
 optimizer:
   name: adamw
   learning_rate: 2e-5
-  weight_decay: 0.0
+  weight_decay: 0.1
   betas:
   - 0.9
-  - 0.999
+  - 0.95
   metrics_log_interval: 10

 scheduler:
   name: linear_with_warmup
-  t_warmup: 100
+  t_warmup: 200
   alpha_f: 0.001

 tokenizer:
@@ -91,42 +91,6 @@ eval_interval: ${save_interval}
 eval_subset_num_batches: -1
 device_eval_batch_size: ${device_train_microbatch_size}
 evaluators:
-  - label: all-small-ppl-validation
-    data:
-      num_workers: 0
-      drop_last: true
-      # pin_memory: true
-      # prefetch_factor: 1
-      # persistent_workers: false
-      # timeout: 0
-      datasets:
-        4chan-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
-        c4_100_domains-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
-        c4_en-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
-        gab-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
-        ice-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
-        m2d2_s2orc-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
-        m2d2_wiki-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
-        manosphere-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
-        mc4_en-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
-        pile-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
-        ptb-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
-        twitterAEE-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
-        wikitext_103-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy
-
 ##########################
 # Downstream evaluations #
 ##########################
@@ -179,4 +143,6 @@ data:
   timeout: 0
   generate_attention_mask: true
   paths:
-    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
+    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/input_ids.npy
+  label_mask_paths:
+    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/label_mask.npy
diff --git a/olmo/config.py b/olmo/config.py
index f3768d18d..17c463f04 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -515,6 +515,7 @@ class PaddingDirection(StrEnum):
 class DataConfig(BaseConfig):
     paths: Optional[List[str]] = None
     datasets: Optional[Dict[str, List[str]]] = None
+    label_mask_paths: Optional[List[str]] = None
     pad_direction: PaddingDirection = PaddingDirection.right
     generate_attention_mask: bool = False
     num_workers: int = 0
diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py
index bc08ff863..52421b57a 100644
--- a/olmo/data/__init__.py
+++ b/olmo/data/__init__.py
@@ -1,8 +1,9 @@
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, cast

 from torch.utils.data import DataLoader, DistributedSampler

+from ..aliases import PathOrStr
 from ..config import DataConfig, TrainConfig
 from ..exceptions import OlmoConfigurationError
 from ..torch_util import barrier, get_global_rank, get_world_size
@@ -39,6 +40,7 @@ def build_memmap_dataset(
         include_instance_metadata=include_instance_metadata,
         pad_token_id=train_config.model.pad_token_id,
         generate_attention_mask=data_config.generate_attention_mask,
+        label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
     )
diff --git a/olmo/data/collator.py b/olmo/data/collator.py
index 2d81d271e..d86a0b9af 100644
--- a/olmo/data/collator.py
+++ b/olmo/data/collator.py
@@ -26,6 +26,7 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
         all_input_ids = []
         all_attention_mask = []
         all_attention_bias = []
+        all_label_mask = []
         all_indices = []
         all_metadata = []
         for x in items:
@@ -78,6 +79,19 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
                 )
             )

+            # Pad label mask.
+            label_mask = x.get("label_mask") if isinstance(x, dict) else None
+            if label_mask is not None:
+                if not isinstance(label_mask, torch.Tensor):
+                    label_mask = torch.tensor(label_mask)
+                all_label_mask.append(
+                    F.pad(
+                        label_mask.to(dtype=torch.bool),
+                        pad_shape,
+                        value=False,
+                    )
+                )
+
             # Indices.
             index = x.get("index") if isinstance(x, dict) else None
             if index is not None:
@@ -93,8 +107,11 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
             out["attention_mask"] = torch.stack(all_attention_mask)
         if all_attention_bias:
             out["attention_bias"] = torch.stack(all_attention_bias)
+        if all_label_mask:
+            out["label_mask"] = torch.stack(all_label_mask)
         if all_indices:
             out["index"] = torch.stack(all_indices)
         if all_metadata:
             out["metadata"] = all_metadata
+
         return out
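For reference, the label-mask handling added to the collator above just pads each instance's boolean mask out to the longest sequence in the batch and stacks the results. A rough, stand-alone sketch of the right-padding case, with illustrative tensors:

```
import torch
import torch.nn.functional as F

# Per-instance label masks of different lengths.
masks = [torch.tensor([True, False, True, True]), torch.tensor([True, True, False])]
max_len = max(len(m) for m in masks)

# Right-pad each mask with False, then stack into a (batch, seq_len) tensor.
batch_label_mask = torch.stack(
    [F.pad(m.to(dtype=torch.bool), (0, max_len - len(m)), value=False) for m in masks]
)
print(batch_label_mask)
# tensor([[ True, False,  True,  True],
#         [ True,  True, False, False]])
```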
diff --git a/olmo/data/memmap_dataset.py b/olmo/data/memmap_dataset.py
index 69f2d85b9..5af73c277 100644
--- a/olmo/data/memmap_dataset.py
+++ b/olmo/data/memmap_dataset.py
@@ -35,6 +35,10 @@ class MemMapDataset(Dataset[Dict[str, Any]]):
         with the same number of items as there are paths.
     :param include_instance_metadata: If ``True`` (the default), each instance returned from `__getitem__` will
         include the metadata from its source.
+    :param generate_attention_mask: If ``True``, each instance returned from ``__getitem__`` will include an
+        attention mask generated by masking each padding token.
+    :param pad_token_id: The ID of the padding token. Required if ``generate_attention_mask`` is ``True``.
+    :param label_mask_paths: Optional paths to ``np.bool_`` memory-mapped arrays of label masks.
     """

     def __init__(
@@ -46,21 +50,30 @@ def __init__(
         include_instance_metadata: bool = True,
         generate_attention_mask: bool = False,
         pad_token_id: Optional[int] = None,
+        label_mask_paths: Optional[List[PathOrStr]] = None,
     ):
         if not paths:
             raise ValueError("At least one path is required")
+
+        if generate_attention_mask and not pad_token_id:
+            raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")
+
+        if label_mask_paths and len(label_mask_paths) != len(paths):
+            raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")
+
         if isinstance(metadata, list):
             if len(metadata) != len(paths):
                 raise ValueError("'metadata' should have the same length as the number of file paths")
         else:
             metadata = [metadata or {}] * len(paths)
+
         self._memmap_paths = paths
         self._metadata = metadata
+        self._label_mask_paths = label_mask_paths
         self._chunk_size = chunk_size
         self._mmap_offsets: Optional[List[Tuple[int, int]]] = None
         self._num_instances: Optional[int] = None
         self.dtype = memmap_dtype
-        self._item_size = self.dtype(0).itemsize
         self._include_instance_metadata = include_instance_metadata
         self._generate_attention_mask = generate_attention_mask
         self._pad_token_id = pad_token_id
@@ -89,34 +102,57 @@ def offsets(self) -> List[Tuple[int, int]]:
             import concurrent.futures

             self._mmap_offsets = []
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                futures = []
-                for path in self._memmap_paths:
-                    future = executor.submit(self._get_file_length, path)
-                    futures.append(future)

-                path_to_length: Dict[PathOrStr, int] = {}
-                for future in concurrent.futures.as_completed(futures):
+            path_to_length: Dict[PathOrStr, int] = {}
+            path_to_mask_path: Dict[PathOrStr, PathOrStr] = {}
+            mask_path_to_length: Dict[PathOrStr, int] = {}
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                path_futures = []
+                mask_path_futures = []
+                for i, path in enumerate(self._memmap_paths):
+                    path_futures.append(executor.submit(self._get_file_length, path))
+                    if self._label_mask_paths is not None:
+                        mask_path = self._label_mask_paths[i]
+                        path_to_mask_path[path] = mask_path
+                        mask_path_futures.append(executor.submit(self._get_file_length, mask_path, np.bool_))
+
+                for future in concurrent.futures.as_completed(path_futures):
                     path, length = future.result()
                     path_to_length[path] = length

+                for future in concurrent.futures.as_completed(mask_path_futures):
+                    path, length = future.result()
+                    mask_path_to_length[path] = length
+
             start_offset = 0
             for path in self._memmap_paths:
                 length = path_to_length[path]
+                if mask_path_to_length:
+                    mask_path = path_to_mask_path[path]
+                    if length != mask_path_to_length[mask_path]:
+                        raise ValueError(f"masking file '{mask_path}' should be the same size as '{path}'")
                 end_offset = start_offset + length
                 self._mmap_offsets.append((start_offset, end_offset))
                 start_offset += length
         return self._mmap_offsets

-    def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
-        bytes_start = index * self._item_size * self._chunk_size
-        num_bytes = self._item_size * self._chunk_size
+    def _read_chunk_from_memmap(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor:
+        dtype = dtype or self.dtype
+        item_size = dtype(0).itemsize
+        bytes_start = index * item_size * self._chunk_size
+        num_bytes = item_size * self._chunk_size
         buffer = get_bytes_range(path, bytes_start, num_bytes)
-        array = np.frombuffer(buffer, dtype=self.dtype)
-        return torch.tensor(array.astype(np.int_), dtype=torch.long)
+        array = np.frombuffer(buffer, dtype=dtype)
+        if dtype == np.bool_:
+            return torch.tensor(array)
+        else:
+            return torch.tensor(array.astype(np.int_), dtype=torch.long)

-    def _get_file_length(self, path) -> Tuple[PathOrStr, int]:
-        return path, file_size(path) // (self._item_size * self._chunk_size)
+    def _get_file_length(self, path, dtype=None) -> Tuple[PathOrStr, int]:
+        dtype = dtype or self.dtype
+        item_size = dtype(0).itemsize
+        return path, file_size(path) // (item_size * self._chunk_size)

     def __len__(self) -> int:
         if self._num_instances is None:
@@ -141,8 +177,14 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         # Read the data from file.
         input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index)
-        out: Dict[str, Any] = {"input_ids": input_ids}
+        out: Dict[str, Any] = {"input_ids": input_ids}
+
+        if self._label_mask_paths is not None:
+            label_mask = self._read_chunk_from_memmap(
+                self._label_mask_paths[memmap_index], memmap_local_index, dtype=np.bool_
+            )
+            out["label_mask"] = label_mask
+
         if self._include_instance_metadata:
             metadata = self._metadata[memmap_index]
             out["metadata"] = deepcopy(metadata)
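The chunk-reading logic above addresses instances purely by byte offset into the flat files: instance `i` of a file is the `i`-th run of `chunk_size` items. A simplified sketch of the same idea for a local file; the real `_read_chunk_from_memmap` goes through `get_bytes_range`, so it also works for remote (e.g. S3) paths:

```
import numpy as np
import torch

def read_chunk(path: str, index: int, chunk_size: int, dtype=np.uint16) -> torch.Tensor:
    # Each instance occupies a fixed number of bytes, so we can seek straight to it.
    item_size = dtype(0).itemsize
    bytes_start = index * item_size * chunk_size
    with open(path, "rb") as f:
        f.seek(bytes_start)
        buffer = f.read(item_size * chunk_size)
    array = np.frombuffer(buffer, dtype=dtype)
    if dtype == np.bool_:
        return torch.tensor(array)          # label masks stay boolean
    return torch.tensor(array.astype(np.int_), dtype=torch.long)  # token IDs become int64
```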
diff --git a/olmo/tokenizer.py b/olmo/tokenizer.py
index b6b934839..a833d3c21 100644
--- a/olmo/tokenizer.py
+++ b/olmo/tokenizer.py
@@ -44,6 +44,14 @@ def __init__(
     def vocab_size(self) -> int:
         return self.base_tokenizer.get_vocab_size()

+    @property
+    def eos_token(self) -> str:
+        return self.decode([self.eos_token_id], skip_special_tokens=False)
+
+    @property
+    def pad_token(self) -> str:
+        return self.decode([self.pad_token_id], skip_special_tokens=False)
+
     @classmethod
     def from_train_config(cls, config: TrainConfig) -> Tokenizer:
         tokenizer_identifier = config.tokenizer.identifier
diff --git a/olmo/train.py b/olmo/train.py
index d9710f453..f459ad88d 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -517,9 +517,15 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
     def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
         # Labels are just input IDs shifted to the left (first item is ignored).
-        labels, attention_mask = batch["input_ids"], batch.get("attention_mask")
+        labels, label_mask, attention_mask = (
+            batch["input_ids"].clone(),
+            batch.get("label_mask"),
+            batch.get("attention_mask"),
+        )
+        if label_mask is not None:
+            labels.masked_fill_(~label_mask, -100)
         if attention_mask is not None:
-            labels = labels.masked_fill(attention_mask == 0.0, -100)
+            labels.masked_fill_(attention_mask == 0.0, -100)
         return labels[..., 1:].contiguous()

     def model_forward(
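The effect of the new `get_labels` logic is that the loss is only computed where `label_mask` is true (and never on padding), since -100 is the default `ignore_index` for PyTorch's cross-entropy. A toy illustration with made-up values:

```
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
label_mask = torch.tensor([[False, True, True, False]])  # e.g. only the assistant turn is True

labels = input_ids.clone()
labels.masked_fill_(~label_mask, -100)   # masked positions are ignored by the loss
labels = labels[..., 1:].contiguous()    # shift left: position t predicts token t+1
print(labels)  # tensor([[   6,    7, -100]])
```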
diff --git a/scripts/prepare_tulu_data.py b/scripts/prepare_tulu_data.py
new file mode 100644
index 000000000..4eba35945
--- /dev/null
+++ b/scripts/prepare_tulu_data.py
@@ -0,0 +1,131 @@
+"""
+Script for preparing the Tulu V2 data for fine-tuning an OLMo model.
+"""
+
+import logging
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import datasets as ds
+import numpy as np
+from rich.progress import track
+
+from olmo.tokenizer import Tokenizer
+from olmo.util import prepare_cli_environment
+
+log = logging.getLogger(__name__)
+
+
+def main(opts) -> None:
+    tokenizer: Tokenizer
+    if Path(opts.tokenizer).is_file():
+        tokenizer = Tokenizer.from_file(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
+    else:
+        tokenizer = Tokenizer.from_pretrained(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
+
+    dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")
+
+    log.info("Tokenizing dataset...")
+    dataset = dataset.map(
+        partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
+        batched=False,
+        remove_columns=["dataset", "id", "messages"],
+        num_proc=opts.num_proc,  # type: ignore
+    )
+
+    log.info("Filtering dataset...")
+    n = len(dataset)  # type: ignore
+    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)  # type: ignore
+    log.info(f"Filtered out {n - len(dataset):,d} examples")
+
+    log.info("Counting tokens...")
+    total_tokens = 0
+    for ex in track(dataset):
+        assert len(ex["input_ids"]) == opts.seq_len  # type: ignore
+        total_tokens += len(ex["input_ids"])  # type: ignore
+    log.info(f"Total tokens: {total_tokens:,d}")
+
+    log.info(f"Saving results to '{opts.output_dir}'...")
+    output_dir = Path(opts.output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    input_ids_file = np.memmap(
+        str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,)
+    )
+    label_mask_file = np.memmap(
+        str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
+    )
+    offset = 0
+    for ex in track(dataset):
+        ex_len = len(ex["input_ids"])  # type: ignore
+        input_ids_file[offset : offset + ex_len] = ex["input_ids"]  # type: ignore
+        label_mask_file[offset : offset + ex_len] = ex["label_mask"]  # type: ignore
+        offset += ex_len
+    input_ids_file.flush()
+    label_mask_file.flush()
+
+    log.info("Done!")
+
+
+def filter(example):
+    return example["n_labels"] > 0
+
+
+def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
+    input_ids = [tokenizer.eos_token_id]
+    label_mask = [False]
+
+    for msg in example["messages"]:
+        role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
+        label_mask += [False] * len(role_tokens)
+        input_ids += role_tokens
+
+        if msg["role"] == "assistant":
+            content_tokens = tokenizer.encode(
+                msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
+            )
+            label_mask += [True] * len(content_tokens)
+            # mask out the last '\n'
+            assert content_tokens[-2] == tokenizer.eos_token_id
+            label_mask[-1] = False
+        else:
+            content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
+            label_mask += [False] * len(content_tokens)
+        input_ids += content_tokens
+
+    input_ids = input_ids[:max_seq_len]
+    label_mask = label_mask[:max_seq_len]
+
+    if len(input_ids) < max_seq_len:
+        pad_len = max_seq_len - len(input_ids)
+        input_ids += [tokenizer.pad_token_id] * pad_len
+        label_mask += [False] * pad_len
+
+    assert len(input_ids) == len(label_mask)
+    n_labels = sum(label_mask)
+
+    return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}
+
+
+def get_parser() -> ArgumentParser:
+    parser = ArgumentParser(description="Prepare Tulu V2 dataset")
+    parser.add_argument("output_dir", type=str, help="""Directory to save the results to.""")
+    parser.add_argument(
+        "-t",
+        "--tokenizer",
+        type=str,
+        help="""Tokenizer path or identifier.""",
+        default="tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
+    )
+    parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048)
+    parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=50279)
+    parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1)
+    parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8)
+    return parser
+
+
+if __name__ == "__main__":
+    prepare_cli_environment()
+    opts = get_parser().parse_args()
+    main(opts)
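One quick way to sanity-check the script's output is to map the two files back in and confirm that they line up; the paths and sequence length below are illustrative:

```
import numpy as np

seq_len = 2048  # must match the --seq-len the script was run with
input_ids = np.memmap("output/input_ids.npy", dtype=np.uint16, mode="r")
label_mask = np.memmap("output/label_mask.npy", dtype=np.bool_, mode="r")

# Both files are flat arrays of the same length, and the token count is a
# whole number of fixed-length instances.
assert input_ids.shape == label_mask.shape
assert input_ids.size % seq_len == 0
print(f"{input_ids.size // seq_len:,d} instances of {seq_len} tokens each")
```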
diff --git a/tests/data/collator_test.py b/tests/data/collator_test.py
index 2279570d5..e94451313 100644
--- a/tests/data/collator_test.py
+++ b/tests/data/collator_test.py
@@ -92,3 +92,40 @@ def test_collate_with_attention_bias(train_config, pad_direction):
             ]
         )
     ).all()
+
+
+@pytest.mark.parametrize(
+    "pad_direction",
+    [pytest.param(PaddingDirection.right, id="pad-right"), pytest.param(PaddingDirection.left, id="pad-left")],
+)
+def test_collate_with_label_mask(train_config, pad_direction):
+    train_config.data.pad_direction = pad_direction
+    collator = DataCollator.from_train_config(train_config)
+
+    inputs = [
+        {
+            "input_ids": torch.tensor([0, 1, 2, 3]),
+            "label_mask": torch.tensor([True, False, True, True]),
+        },
+        {
+            "input_ids": torch.tensor([4, 5, 6]),
+            "label_mask": torch.tensor([True, True, False]),
+        },
+    ]
+    batch = collator(inputs)  # type: ignore
+    assert batch["label_mask"] is not None
+    assert batch["label_mask"].shape == (2, 4)
+    if pad_direction == "right":
+        assert (
+            batch["label_mask"]
+            == torch.tensor(
+                [[True, False, True, True], [True, True, False, False]],
+            )
+        ).all()
+    else:
+        assert (
+            batch["label_mask"]
+            == torch.tensor(
+                [[True, False, True, True], [False, True, True, False]],
+            )
+        ).all()
diff --git a/tests/data/memmap_dataset_test.py b/tests/data/memmap_dataset_test.py
index 85a3fb0cc..e267043ee 100644
--- a/tests/data/memmap_dataset_test.py
+++ b/tests/data/memmap_dataset_test.py
@@ -22,6 +22,40 @@ def test_mmap_dataset(tmp_path: Path):
     assert ds[7]["input_ids"].tolist() == [28, 29, 30, 31]


+def test_mmap_dataset_with_label_mask(tmp_path: Path):
+    mmap1 = np.memmap(tmp_path / "mmap1.npy", mode="w+", dtype=np.uint16, shape=(16,))
+    mmap1[:] = np.array(list(range(16)), dtype=np.uint16)
+    mmap1.flush()
+
+    mask1 = [True] * 16
+    mask1[1] = False
+    mask_mmap1 = np.memmap(tmp_path / "mask_mmap1.npy", mode="w+", dtype=np.bool_, shape=(16,))
+    mask_mmap1[:] = np.array(mask1, dtype=np.bool_)
+    mask_mmap1.flush()
+
+    mmap2 = np.memmap(tmp_path / "mmap2.npy", mode="w+", dtype=np.uint16, shape=(16,))
+    mmap2[:] = np.array(list(range(16, 32)), dtype=np.uint16)
+    mmap2.flush()
+
+    mask2 = [True] * 16
+    mask2[-1] = False
+    mask_mmap2 = np.memmap(tmp_path / "mask_mmap2.npy", mode="w+", dtype=np.bool_, shape=(16,))
+    mask_mmap2[:] = np.array(mask2, dtype=np.bool_)
+    mask_mmap2.flush()
+
+    ds = MemMapDataset(
+        tmp_path / "mmap1.npy",
+        tmp_path / "mmap2.npy",
+        chunk_size=4,
+        label_mask_paths=[tmp_path / "mask_mmap1.npy", tmp_path / "mask_mmap2.npy"],
+    )
+    assert ds[0]["input_ids"].tolist() == [0, 1, 2, 3]
+    assert ds[0]["label_mask"].tolist() == [True, False, True, True]
+    assert ds[1]["input_ids"].tolist() == [4, 5, 6, 7]
+    assert ds[7]["input_ids"].tolist() == [28, 29, 30, 31]
+    assert ds[7]["label_mask"].tolist() == [True, True, True, False]
+
+
 def test_mmap_dataset_with_metadata(tokenizer: Tokenizer, tmp_path: Path, lorem_ipsum_docs: List[str]):
     chunk_size = 24