diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index fc2a7e05..ef9a86cb 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -26,6 +26,7 @@ jobs: llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }} llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }} llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }} + llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }} llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }} mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }} artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }} @@ -188,6 +189,25 @@ jobs: task.max_steps=15 \ dcn_mesh.fsdp=2 \ ici_mesh.fsdp=4 \ + profile_step=3 + + - name: Run Llama 3.0 8B SFT + id: run-llama-3-8b-sft + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + XLA_IR_DEBUG: 1 + XLA_HLO_DEBUG: 1 + run: | + name=$(e2e_testing/gen_name.py llama-3-8b-sft) + echo "name=$name" >> "$GITHUB_OUTPUT" + tp run ${{ steps.docker-url-option.outputs.value }} \ + --name $name \ + torchprime/torch_xla_models/train.py \ + --config-name sft_w_gsm8k \ + ici_mesh.fsdp=4 \ + task.max_steps=20 \ + task.global_batch_size=16 \ + task.convert_to_safetensors=False \ profile_start_step=3 - name: Run Llama 3.0 8B (ddp + fsdp) @@ -256,6 +276,7 @@ jobs: matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name || matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name || matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name || + matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name || matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name || matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name }} diff --git a/README.md b/README.md index 2f8c17cf..3dc9e924 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,15 @@ python3 torchprime/torch_xla_models/train.py \ You may refer to the hydra docs for other ways to specify configs. +To fine-tune a pretrained model using the gsm8k (Grade School Math question-answer) dataset, run + +```sh +python3 torchprime/torch_xla_models/train.py --config-name sft_w_gsm8k +``` + +This uses the `sft_w_gsm8k.yaml` config which selects the SFT trainer and +dataset automatically. 
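+
+As with pretraining, individual options can be overridden on the command line
+via hydra. For example, the run below shortens training and skips converting
+the final checkpoint to safetensors, mirroring the overrides used by the new
+E2E test:
+
+```sh
+python3 torchprime/torch_xla_models/train.py \
+  --config-name sft_w_gsm8k \
+  task.max_steps=20 \
+  task.global_batch_size=16 \
+  task.convert_to_safetensors=False
+```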
+ ### Multi-VM distributed training `torchprime` uses [xpk][xpk] as the standard path for iterating on distributed diff --git a/e2e_testing/step_time_bounds.yaml b/e2e_testing/step_time_bounds.yaml index c58da512..016ce794 100644 --- a/e2e_testing/step_time_bounds.yaml +++ b/e2e_testing/step_time_bounds.yaml @@ -41,6 +41,13 @@ benchmarks: confidence_interval: 0.12888 average: 3.9587 sample_size: 416 + llama-3-8b-sft: + name: Llama 3.0 8B SFT + step_time_lower_bound: 0 # some random number, will be replaced by actual values later + step_time_upper_bound: 1 # some random number + confidence_interval: 0.5 # some random number + average: 0.5 # some random number + sample_size: 123 # some random number llama-3-8b-ddp-fsdp: name: Llama 3.0 8B (ddp + fsdp) step_time_lower_bound: 3.82985294 # copyed from llama-3-8b-2-slice diff --git a/e2e_testing/update_step_time.py b/e2e_testing/update_step_time.py index e425b940..0b819b23 100755 --- a/e2e_testing/update_step_time.py +++ b/e2e_testing/update_step_time.py @@ -73,6 +73,15 @@ def match_llama_3_8b_2_slice(row): ) +def match_llama_3_8b_sft(row): + config = json.loads(row.configs_framework) + return ( + row.run_id.startswith("llama-3-8b-sft") + and config["dcn_mesh"]["fsdp"] == 1 + and config["ici_mesh"]["tensor"] == 1 + ) + + def match_llama_3_8b_ddp_fsdp(row): config = json.loads(row.configs_framework) return ( @@ -89,6 +98,7 @@ def match_llama_3_8b_ddp_fsdp(row): "Llama 3.0 8B (2D sharding)": match_llama3_8b_2d, "Mixtral 8x7B": match_mixtral, "Llama 3.0 8B (2 Slice)": match_llama_3_8b_2_slice, + "Llama 3.0 8B SFT": match_llama_3_8b_sft, "Llama 3.0 8B (ddp + fsdp)": match_llama_3_8b_ddp_fsdp, } @@ -99,6 +109,7 @@ def match_llama_3_8b_ddp_fsdp(row): "Llama 3.0 8B (2D sharding)": "llama-3-8b-2d", "Mixtral 8x7B": "mixtral-8x7b", "Llama 3.0 8B (2 Slice)": "llama-3-8b-2-slice", + "Llama 3.0 8B SFT": "llama-3-8b-sft", "Llama 3.0 8B (ddp + fsdp)": "llama-3-8b-ddp-fsdp", } """Mapping from the benchmark name to the ID of the E2E test step used in GitHub Actions.""" diff --git a/torchprime/data/__init__.py b/torchprime/data/__init__.py index eb396910..b9e16168 100644 --- a/torchprime/data/__init__.py +++ b/torchprime/data/__init__.py @@ -5,7 +5,13 @@ from .dataset import make_train_dataset from .sft_dataset import make_sft_dataset +DATASET_BUILDERS = { + "train": make_train_dataset, + "sft": make_sft_dataset, +} + __all__ = [ + "DATASET_BUILDERS", "make_train_dataset", "make_sft_dataset", ] diff --git a/torchprime/data/sft_dataset.py b/torchprime/data/sft_dataset.py index fc9c8f5c..5c68bec4 100644 --- a/torchprime/data/sft_dataset.py +++ b/torchprime/data/sft_dataset.py @@ -52,15 +52,19 @@ def _tokenize_prompt_completion( Mapping with ``input_ids`` and ``labels`` suitable for training. """ - if "prompt" in example or "question" in example: + if "prompt" in example and "completion" in example: prompt = example.get("prompt", "") completion = example.get("completion", "") - elif "question" in example or "answer" in example: + elif "question" in example and "answer" in example: prompt = example.get("question", "") + prompt = f"Question:\n{prompt}\n\n\nAnswer:\n" # Add format for q-a pair completion = example.get("answer", "") elif "text" in example: prompt = "" completion = example["text"] + elif "completion" in example: + prompt = "" + completion = example["completion"] else: raise ValueError( "Invalid input format: must contain 'prompt'/'completion' or 'question'/'answer' or 'text' fields." 
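For reference, the record formats accepted by `_tokenize_prompt_completion` map onto prompt/completion pairs as sketched below (the example records are illustrative; records carrying only a `completion` field are handled like plain text, with an empty prompt):

```python
# Records with explicit prompt/completion fields are used as-is.
example = {"prompt": "Translate to French: cat", "completion": "chat"}
# -> prompt = "Translate to French: cat", completion = "chat"

# GSM8K-style question/answer records receive the template added above.
example = {"question": "What is 2 + 2?", "answer": "4"}
# -> prompt = "Question:\nWhat is 2 + 2?\n\n\nAnswer:\n", completion = "4"

# Plain-text records train on the whole text with an empty prompt.
example = {"text": "The quick brown fox jumps over the lazy dog."}
# -> prompt = "", completion = "The quick brown fox jumps over the lazy dog."
```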
diff --git a/torchprime/metrics/step_duration.py b/torchprime/metrics/step_duration.py index 460773e3..f37cdab3 100644 --- a/torchprime/metrics/step_duration.py +++ b/torchprime/metrics/step_duration.py @@ -3,12 +3,15 @@ """ import glob +import logging import os import statistics import sys from torchprime.metrics.xplane_pb2 import XSpace # type: ignore +logger = logging.getLogger(__name__) + def step_duration_from_latest_profile(profile_dir: str) -> float: profile_dir = os.path.abspath(profile_dir) @@ -66,9 +69,13 @@ def analyze_step_duration_from_pb(xspace: XSpace) -> float: # Confirm we have exactly one unique event name if len(unique_names) > 1: - raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}") + logger.warning( + f"Multiple event names found in XSpace: {unique_names}.\n" + "Using the one with max graph nodes for duration calculation." + ) + + inferred_event_name = max(unique_names) - inferred_event_name = list(unique_names)[0] # Sort offsets to compute consecutive differences offsets.sort() diff --git a/torchprime/metrics/tests/test_step_duration.py b/torchprime/metrics/tests/test_step_duration.py index 074d4ca1..1346ddd2 100644 --- a/torchprime/metrics/tests/test_step_duration.py +++ b/torchprime/metrics/tests/test_step_duration.py @@ -88,8 +88,11 @@ def test_conflicting_step_names(): event.duration_ps = int(2e12) temp.write(xspace.SerializeToString()) temp.flush() - with pytest.raises(ValueError, match="Ambiguous"): - analyze_step_duration(temp.name) + # with pytest.raises(ValueError, match="Ambiguous"): + # analyze_step_duration(temp.name) + + # Temperarily allow multiple profile names, checkout issue #260 + assert analyze_step_duration(temp.name) == 1.0 def test_real_profile(): diff --git a/torchprime/torch_xla_models/configs/dataset/alpaca.yaml b/torchprime/torch_xla_models/configs/dataset/alpaca.yaml new file mode 100644 index 00000000..1f1c6741 --- /dev/null +++ b/torchprime/torch_xla_models/configs/dataset/alpaca.yaml @@ -0,0 +1,10 @@ +# Dataset configuration for supervised fine-tuning using the Alpaca dataset +hf_dataset_name: tatsu-lab/alpaca +hf_dataset_config_name: null +split: train +block_size: 8192 +cache_dir: /tmp/ +format: prompt_completion +compute_loss_on: completion +pack_samples: true +truncation: right diff --git a/torchprime/torch_xla_models/configs/dataset/gsm8k.yaml b/torchprime/torch_xla_models/configs/dataset/gsm8k.yaml new file mode 100644 index 00000000..cbff1825 --- /dev/null +++ b/torchprime/torch_xla_models/configs/dataset/gsm8k.yaml @@ -0,0 +1,10 @@ +# Dataset configuration for supervised fine-tuning using the GSM8k dataset +hf_dataset_name: gsm8k +hf_dataset_config_name: main +split: train +block_size: 256 +cache_dir: /tmp/ +format: prompt_completion +compute_loss_on: completion +pack_samples: false +truncation: drop diff --git a/torchprime/torch_xla_models/configs/model/llama-1b-random-for-test.yaml b/torchprime/torch_xla_models/configs/model/llama-1b-random-for-test.yaml index 68156cf6..2c82b810 100644 --- a/torchprime/torch_xla_models/configs/model/llama-1b-random-for-test.yaml +++ b/torchprime/torch_xla_models/configs/model/llama-1b-random-for-test.yaml @@ -1,5 +1,6 @@ model_id: llama-1b-random-for-test model_class: llama.LlamaForCausalLM # Used to import the model from this class +pretrained_model: hf-internal-testing/tiny-random-LlamaForCausalLM vocab_size: 32000 hidden_size: 16 intermediate_size: 64 diff --git a/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml 
b/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml index 51272561..47e5a5a7 100644 --- a/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml +++ b/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml @@ -5,6 +5,7 @@ defaults: model_id: llama-3-8b model_class: llama.LlamaForCausalLM # Used to import the model from this class +pretrained_model: null vocab_size: 128256 hidden_size: 4096 intermediate_size: 14336 diff --git a/torchprime/torch_xla_models/configs/sft_w_gsm8k.yaml b/torchprime/torch_xla_models/configs/sft_w_gsm8k.yaml new file mode 100644 index 00000000..a662ab9f --- /dev/null +++ b/torchprime/torch_xla_models/configs/sft_w_gsm8k.yaml @@ -0,0 +1,17 @@ +# Configuration for supervised fine-tuning using the GSM8k dataset +# Overrides the default dataset and task while reusing the default model + +defaults: + - default # Refers to configs/default.yaml + - override model: llama-3-8b + - override dataset: gsm8k + - override task: sft + - _self_ + +task: + # don't convert the checkpoint to safetensors to save space/time + convert_to_safetensors: False + +model: + # pretrained checkpoint to use for supervised fine-tuning + pretrained_model: meta-llama/Meta-Llama-3-8B \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/task/sft.yaml b/torchprime/torch_xla_models/configs/task/sft.yaml new file mode 100644 index 00000000..49aa7cd2 --- /dev/null +++ b/torchprime/torch_xla_models/configs/task/sft.yaml @@ -0,0 +1,14 @@ +# Task configuration for supervised fine-tuning +name: sft +global_batch_size: 64 +max_steps: 100 +export_checkpoint_path: export +convert_to_safetensors: True +max_grad_norm: 1.0 +max_grad_value: null +optimizer: + learning_rate: 4.e-5 + type: adafactor +lr_scheduler: + type: linear + warmup_steps: 10 diff --git a/torchprime/torch_xla_models/model/base_causal_lm.py b/torchprime/torch_xla_models/model/base_causal_lm.py index cd8906c6..5c9293db 100644 --- a/torchprime/torch_xla_models/model/base_causal_lm.py +++ b/torchprime/torch_xla_models/model/base_causal_lm.py @@ -7,6 +7,11 @@ import json import os +import shutil +import subprocess +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path import torch import torch.nn as nn @@ -57,7 +62,13 @@ def load_safetensors_to_state_dict(model_dir: str) -> dict: return state_dict -def save_sharded_safetensors_by_layer(state_dict: dict, save_dir: str): +def save_sharded_safetensors_by_layer( + state_dict: dict[str, "torch.Tensor"], + save_dir: str | os.PathLike, + *, + max_workers: int = 24, + tmp_dir: str | os.PathLike | None = None, +): """Save a model state dict to sharded safetensors by layer prefix. This function saves the model's state dictionary into separate sharded files, @@ -67,21 +78,82 @@ def save_sharded_safetensors_by_layer(state_dict: dict, save_dir: str): Args: state_dict (dict): The model's state dictionary to be saved. save_dir (str): Directory where the sharded safetensors and index file will be saved. + max_workers: Parallel writer threads. 24 saturates v6e-4 dual NVMe; tune as needed. + tmp_dir: If given, write shards to this *local* directory first, then + `gsutil -m cp` the results to ``save_dir``. Handy when `save_dir` is a + Cloud-Storage mount and you want full NVMe speed. 
""" + save_dir = Path(save_dir) + if tmp_dir: + tmp_dir = Path(tmp_dir) + tmp_dir.mkdir(parents=True, exist_ok=True) + work_dir = tmp_dir + else: + save_dir.mkdir(parents=True, exist_ok=True) + work_dir = save_dir + work_dir.mkdir(parents=True, exist_ok=True) + + def _shard_key(param_name: str) -> str: + """Determine the shard key for a parameter based on its name.""" + parts = param_name.split(".") + if parts[:2] == ["model", "layers"] and parts[2].isdigit(): + return "_".join(parts[:3]) # model_layers_0, model_layers_1, etc. + elif parts[0] == "model": + return "model" + elif parts[0] == "lm_head": + return "lm_head" + else: + return "other" - os.makedirs(save_dir, exist_ok=True) - grouped = {} + grouped: dict[str, dict[str, torch.Tensor]] = defaultdict(dict) + sizes: dict[str, int] = {} for k, v in state_dict.items(): - prefix = k.split(".")[0] - grouped.setdefault(prefix, {})[k] = v - weight_map = {} - for prefix, group in grouped.items(): - shard_file = f"{prefix}.safetensors" - shard_path = os.path.join(save_dir, shard_file) - save_file(group, shard_path) - weight_map.update({k: shard_file for k in group}) - with open(os.path.join(save_dir, "model.safetensors.index.json"), "w") as f: - json.dump({"weight_map": weight_map}, f, indent=2) + p = _shard_key(k) + grouped[p][k] = v + sizes[p] = sizes.get(p, 0) + v.numel() * v.element_size() + + def _write_one(item: tuple[str, dict[str, "torch.Tensor"]]) -> dict[str, str]: + prefix, group = item + fname = f"{prefix}.safetensors" + save_file(group, str(work_dir / fname)) + # strip FSDP suffix for HF compatibility + return {k.replace("._orig_mod", ""): fname for k in group} + + # sort largest → smallest so threads finish together + items = sorted(grouped.items(), key=lambda kv: sizes[kv[0]], reverse=True) + + weight_map: dict[str, str] = {} + with ThreadPoolExecutor(max_workers=max_workers) as pool: + for mapping in pool.map(_write_one, items): + weight_map.update(mapping) + + # ---------- dump index -------------------------------------------- + (work_dir / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": weight_map}, indent=2) + ) + + # ---------- optional gsutil sync --------------------------------- + if tmp_dir: + save_dir.mkdir(parents=True, exist_ok=True) + cmd = [ + "gsutil", + "-m", + "-q", + "cp", + "-n", # don't clobber if file exists + *(str(p) for p in work_dir.glob("*.safetensors")), + str(save_dir) + "/", + ] + cmd_idx = [ + "gsutil", + "-q", + "cp", + str(work_dir / "model.safetensors.index.json"), + str(save_dir) + "/", + ] + subprocess.check_call(cmd) + subprocess.check_call(cmd_idx) + shutil.rmtree(work_dir, ignore_errors=True) class BaseCausalLM(nn.Module): diff --git a/torchprime/torch_xla_models/model/model_utils.py b/torchprime/torch_xla_models/model/model_utils.py index e066ffd6..46d01d39 100644 --- a/torchprime/torch_xla_models/model/model_utils.py +++ b/torchprime/torch_xla_models/model/model_utils.py @@ -1,6 +1,7 @@ """Utility function(s) for model initialization.""" import importlib +import logging import re import sys from contextlib import contextmanager @@ -78,3 +79,41 @@ def extract_model_size_from_model_name(model_name: str) -> int | float: except ValueError: return -1 return -1 + + +def log_parameter_breakdown(model: torch.nn.Module, logger: logging.Logger) -> None: + """Logs the number of parameters in different components of the model. + + Args: + model: The PyTorch model. + logger: A logger instance to write the output to. 
+ """ + total_params = sum(p.numel() for p in model.parameters()) + logger.info("Model total size: {} parameters".format(f"{total_params:,}")) + + param_groups = { + "mlp": 0, + "attention": 0, + "embedding": 0, + "lm_head": 0, + "norm": 0, + "other": 0, + } + + for name, param in model.named_parameters(): + if "mlp" in name: + param_groups["mlp"] += param.numel() + elif "self_attn" in name or "attention" in name: + param_groups["attention"] += param.numel() + elif "embed" in name: + param_groups["embedding"] += param.numel() + elif "lm_head" in name: + param_groups["lm_head"] += param.numel() + elif "norm" in name or "layernorm" in name: + param_groups["norm"] += param.numel() + else: + param_groups["other"] += param.numel() + + for k, v in param_groups.items(): + percentage = (v / total_params) * 100 + logger.info(" {:10s}: {} params ({:.2f}%)".format(k, f"{v:,}", percentage)) diff --git a/torchprime/torch_xla_models/tests/test_sft_trainer.py b/torchprime/torch_xla_models/tests/test_sft_trainer.py new file mode 100644 index 00000000..eb84b113 --- /dev/null +++ b/torchprime/torch_xla_models/tests/test_sft_trainer.py @@ -0,0 +1,142 @@ +"""Tests for the :class:`SFTTrainer` class.""" + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from torchprime.torch_xla_models.trainer.sft_trainer import SFTTrainer + + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 2) + self.loaded = False + self.saved = False + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + logits = self.linear(input_ids) + loss = logits.mean() + return logits, loss + + def from_pretrained(self, path): + self.loaded = True + + def export(self, path): + self.saved = True + + +class DummyDataset(Dataset): + def __init__(self): + self.device = xm.xla_device() + + def __getitem__(self, idx): + return { + "input_ids": torch.ones(4, device=self.device), + "attention_mask": torch.ones(4, device=self.device), + } + + def __len__(self): + return 8 + + +class FakeMesh: + def __init__(self): + self.device_ids = [0] + self.axis_names = ("data", "fsdp") + self.mesh_shape = (1, 1) + + def shape(self): + return {"data": 1, "fsdp": 1} + + def get_axis_name_idx(self, axis_name): + return self.axis_names.index(axis_name) + + def get_logical_mesh(self): + return np.array(self.device_ids).reshape(self.mesh_shape) + + +@pytest.fixture +def dummy_config(): + return OmegaConf.create( + { + "model": { + "remat": { + "activation_checkpoint_layers": [], + "optimization_barrier_layers": [], + "scan_layers": None, + "offload_tensors": [], + }, + "sharding": {"type": "spmd"}, + "pretrained_model": "dummy", + }, + "data": {"name": "dummy_dataset", "block_size": 4}, + "task": { + "name": "sft", + "global_batch_size": 4, + "max_steps": 1, + "max_grad_norm": None, + "max_grad_value": None, + "export_checkpoint_path": "dummy_export_path", + "optimizer": {"type": "adafactor", "learning_rate": 1e-3}, + "lr_scheduler": {"type": "constant", "warmup_steps": 0}, + }, + "run_name": None, + "output_dir": "/tmp/test_output", + "logging_steps": 1, + "profile_step": -1, + "profile_dir": "/tmp/profile", + "profile_duration": 5, + "ici_mesh": {"data": 1, "fsdp": 1, "tensor": 1}, + "dcn_mesh": {}, + } + ) + + +def test_load_and_save(monkeypatch, dummy_config): + from torchprime.torch_xla_models.model_rewriting import sharding_initialization + + # Patch mesh setup + 
monkeypatch.setattr( + sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh() + ) + monkeypatch.setattr( + sharding_initialization, + "shard_torch_xla_model_from_config", + lambda model, *args, **kwargs: model, + ) + + # Patch process index and count + monkeypatch.setattr("torch_xla.runtime.process_index", lambda: 0) + monkeypatch.setattr("torch_xla.runtime.process_count", lambda: 1) + + # Patch xm.save to track call + saved = {} + + def fake_save(state_dict, *args, **kwargs): + saved["called"] = True + saved["state_dict"] = state_dict + + monkeypatch.setattr( + "torchprime.torch_xla_models.trainer.sft_trainer.dist_cp.save", + fake_save, + ) + # Initialize + device = xm.xla_device() + model = DummyModel().to(device) + dataset = DummyDataset() + trainer = SFTTrainer(model, dummy_config, dataset) + + # from_pretrained should mark model as loaded + assert model.loaded is True + + # Train (1 step), should trigger save + trainer.train_loop() + + # Save should have occurred + assert "called" in saved and saved["called"] is True + assert isinstance(saved["state_dict"], dict) diff --git a/torchprime/torch_xla_models/tests/test_trainer.py b/torchprime/torch_xla_models/tests/test_trainer.py index a97d9b2a..dcb1ae90 100644 --- a/torchprime/torch_xla_models/tests/test_trainer.py +++ b/torchprime/torch_xla_models/tests/test_trainer.py @@ -19,7 +19,6 @@ from omegaconf import OmegaConf from torch.utils.data import Dataset -from torchprime.metrics.metrics import MetricsLogger from torchprime.torch_xla_models.trainer.base_trainer import Trainer @@ -151,7 +150,7 @@ def counting_train_step(self, batch): monkeypatch.setattr(Trainer, "train_step", counting_train_step) - trainer.train_loop(metrics_logger=MetricsLogger()) + trainer.train_loop() assert call_counter["steps"] == dummy_config.task.max_steps @@ -341,7 +340,7 @@ def fake_stop(): dataset = DummyDataset() trainer = Trainer(model, dummy_config, dataset) - trainer.train_loop(metrics_logger=MetricsLogger()) + trainer.train_loop() assert dummy_config.profile_start_step == 0 assert dummy_config.profile_end_step == 1 diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 72f02ac3..e37c28b4 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -15,13 +15,14 @@ set_seed, ) -from torchprime.data.dataset import make_train_dataset +from torchprime.data import DATASET_BUILDERS, make_train_dataset from torchprime.metrics.metrics import MetricsLogger from torchprime.torch_xla_models.model.model_utils import ( initialize_model_class, + log_parameter_breakdown, set_default_dtype, ) -from torchprime.torch_xla_models.trainer.base_trainer import Trainer +from torchprime.torch_xla_models.trainer import TRAINERS, Trainer from torchprime.utils.retry import retry transformers.utils.check_min_version("4.39.3") @@ -63,12 +64,19 @@ def main(config: DictConfig): with set_default_dtype(model_dtype), torch_xla.device(): model = initialize_model_class(config.model) - n_params = sum([p.numel() for p in model.parameters()]) - logger.info(f"Training new model from scratch - Total size={n_params} params") + log_parameter_breakdown(model, logger) - # Downloading and loading a dataset from the hub. - data = retry(lambda: make_train_dataset(**config.dataset, tokenizer=tokenizer)) - trainer = Trainer( + # Select dataset builder and trainer based on the task name. 
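+  # `task.name` is "train" for pretraining and "sft" for supervised fine-tuning;
+  # DATASET_BUILDERS maps these to make_train_dataset / make_sft_dataset and
+  # TRAINERS maps them to Trainer / SFTTrainer. An unrecognized task name falls
+  # back to the pretraining defaults.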
+ dataset_fn = DATASET_BUILDERS.get(config.task.name, make_train_dataset) + trainer_cls = TRAINERS.get(config.task.name, Trainer) + data = retry(lambda: dataset_fn(**config.dataset, tokenizer=tokenizer)) + + dataset_name = getattr(config.dataset, "hf_dataset_name", None) or getattr( + config.dataset, "file_dataset_path", "unknown" + ) + logger.info("Loaded dataset `%s`, size=%d (packed) samples", dataset_name, len(data)) + + trainer = trainer_cls( model=model, config=config, train_dataset=data, @@ -76,7 +84,8 @@ def main(config: DictConfig): # TODO(https://github.com/pytorch/xla/issues/8954): Remove `jax_env_context`. with jax_env_context(): - trainer.train_loop(metrics_logger) + trainer.train_loop() + trainer.finalize_training(metrics_logger) if __name__ == "__main__": diff --git a/torchprime/torch_xla_models/trainer/__init__.py b/torchprime/torch_xla_models/trainer/__init__.py new file mode 100644 index 00000000..4ed60df2 --- /dev/null +++ b/torchprime/torch_xla_models/trainer/__init__.py @@ -0,0 +1,15 @@ +"""Trainer module for Torch XLA models.""" + +from .base_trainer import Trainer +from .sft_trainer import SFTTrainer + +TRAINERS = { + "train": Trainer, + "sft": SFTTrainer, +} + +__all__ = [ + "TRAINERS", + "Trainer", + "SFTTrainer", +] diff --git a/torchprime/torch_xla_models/trainer/base_trainer.py b/torchprime/torch_xla_models/trainer/base_trainer.py index 1d060f29..19109cdf 100644 --- a/torchprime/torch_xla_models/trainer/base_trainer.py +++ b/torchprime/torch_xla_models/trainer/base_trainer.py @@ -201,13 +201,14 @@ def _log_to_tensorboard( self.summary_writer.add_scalar("train/grad_norm", grad_norm, step) self.summary_writer.flush() - def train_loop(self, metrics_logger) -> None: + def train_loop(self) -> None: self.model.train() self.model.zero_grad() # For now we assume that we will never train for more than one epoch max_step = self.config.task.max_steps train_loader = self._get_train_dataloader() + steps_per_epoch = len(train_loader) train_iterator = iter(train_loader) logger.info("Starting training") @@ -236,8 +237,8 @@ def step_closure( loss = loss.detach().item() grad_norm = grad_norm.detach().item() logger.info( - "Epoch: %d, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, trace time: %.2f ms", - epoch, + "Epoch: %.4f, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, trace time: %.2f ms", + step / steps_per_epoch, step, loss, grad_norm, @@ -278,6 +279,9 @@ def step_closure( xm.wait_device_ops() logger.info("Finished training run") + def finalize_training(self, metrics_logger) -> None: + """Finalize training by processing profiling output and logging metrics.""" + if self.config.profile_start_step >= 0 and self.config.profile_end_step >= 0: # Analyze the step duration from the latest profile step_duration = step_duration_from_latest_profile(self.config.profile_dir) diff --git a/torchprime/torch_xla_models/trainer/sft_trainer.py b/torchprime/torch_xla_models/trainer/sft_trainer.py new file mode 100644 index 00000000..471b55c7 --- /dev/null +++ b/torchprime/torch_xla_models/trainer/sft_trainer.py @@ -0,0 +1,155 @@ +"""Trainer for supervised fine-tuning (SFT) tasks.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +import tempfile +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dist_cp +import torch_xla.core.xla_model as xm +import torch_xla.experimental.distributed_checkpoint as xc +import torch_xla.runtime as xr +from omegaconf import DictConfig +from 
torch import nn +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter + +from torchprime.torch_xla_models.model.base_causal_lm import ( + save_sharded_safetensors_by_layer, +) + +from .base_trainer import Trainer + +logger = logging.getLogger(__name__) + + +class SFTTrainer(Trainer): + """Trainer with pretrained weight loading and saving support.""" + + def __init__( + self, + model: nn.Module, + config: DictConfig, + train_dataset, + ) -> None: + """Initialize trainer and optionally load pretrained weights. + + Args: + model: Model instance to train. + config: Hydra configuration object. + train_dataset: Dataset used for training. + """ + + self.pretrained_model = getattr(config.model, "pretrained_model", None) + + if self.pretrained_model: + if xr.process_index() == 0: + logger.info("Loading model weights from %s", self.pretrained_model) + model.from_pretrained(self.pretrained_model) + xm.mark_step() + else: + logger.info( + "No pretrained model specified; training from scratch. \n\nIs this what you intended?\n" + ) + + super().__init__(model, config, train_dataset) + + def train_loop(self) -> None: + """Run the base training loop and export the model. + + Args: + metrics_logger: Instance used to record metrics during training. + """ + super().train_loop() + + t0 = time.perf_counter() + logger.info("[SAVING] Starting distributed checkpoint …") + self._maybe_save_model() + dt = time.perf_counter() - t0 + logger.info("[SAVING] Finished in %.2f s", dt) + + def _maybe_save_model(self) -> None: + """Save a sharded checkpoint with torch.distributed.checkpoint. + + Call **once** on all TPU ranks at the end of training. + + • All ranks write a sharded *Distributed Checkpoint* (fast) to + ``//`` + • Optionally, Rank-0 immediately reloads that checkpoint on CPU and emits + Hugging-Face-compatible `*.safetensors` shards + index. 
+ """ + folder_name = getattr(self.config.task, "export_checkpoint_path", None) + if folder_name is None: + logger.info("Skipping model export, no export_checkpoint_path provided.") + return + + save_dir = Path(self.config.output_dir) / folder_name + save_dir.mkdir(parents=True, exist_ok=True) + + # Make sure pending device ops are flushed + xm.mark_step() + xm.wait_device_ops() + + # torch.distributed.checkpoint requires a torch.distributed process group + # even when used on TPUs + # Ensure a torch.distributed PG exists + if not dist.is_initialized(): + xr.use_spmd() + + dist.init_process_group("gloo", init_method="xla://") + + # -------------------------- 1 · fast distributed save ------------- + state_dict = {"model": self.model.state_dict()} + + dist_cp.save( + state_dict=state_dict, + storage_writer=FileSystemWriter( + str(save_dir), thread_count=max(2, min(8, mp.cpu_count())) + ), + planner=xc.SPMDSavePlanner(), + ) + logger.info("DCP checkpoint written to %s", save_dir) + + # -------------------------- 2 · CPU safetensor conversion (rank-0) - + convert_to_safetensors = getattr(self.config.task, "convert_to_safetensors", False) + + if convert_to_safetensors and xr.process_index() == 0: + logger.info("Rank-0: reloading checkpoint for safetensors export …") + + # build placeholder dict purely from names (no device copies) + reload_sd = { + "model": { + name: torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu") + for name, tensor in state_dict["model"].items() + } + } + + dist_cp.load( + state_dict=reload_sd, + storage_reader=FileSystemReader(str(save_dir)), + planner=xc.SPMDLoadPlanner(), + ) + logger.info("Checkpoint fully materialised on CPU") + + cpu_state = { + k.replace("._orig_mod", ""): v for k, v in reload_sd["model"].items() + } + + try: + tmp_dir = tempfile.mkdtemp(dir="/mnt/localssd") + logger.info("Using local SSD for safetensors shards: %s", tmp_dir) + except (FileNotFoundError, PermissionError): + tmp_dir = tempfile.mkdtemp() + logger.info("Using default temp directory for safetensors shards: %s", tmp_dir) + + save_sharded_safetensors_by_layer(cpu_state, str(save_dir), tmp_dir=tmp_dir) + + logger.info("Safetensors shards + index written to %s", save_dir) + + # -------------------------- 3 · barrier so other ranks wait -------- + if xr.process_count() > 1: + xm.rendezvous("sft_save")