
Add SFT trainer and sft task #284

Open: wants to merge 42 commits into base `main`.

Commits (42)
1cb25b5  ini (jialei777, Jun 5, 2025)
fb44eb9  fix model loading (jialei777, Jun 6, 2025)
64a965e  fix saving and laoding (jialei777, Jun 6, 2025)
36dc58b  enable gsm8k (jialei777, Jun 6, 2025)
d42bd7e  fix (jialei777, Jun 6, 2025)
f40e181  add e2e test (jialei777, Jun 6, 2025)
e209ae5  fix (jialei777, Jun 6, 2025)
84fba0b  merge main (jialei777, Jun 6, 2025)
b6659a6  format (jialei777, Jun 6, 2025)
b682632  fix e2e test (jialei777, Jun 6, 2025)
a441249  fic (jialei777, Jun 6, 2025)
0ee6fb6  fic (jialei777, Jun 6, 2025)
ae981bf  fix e2e test? (jialei777, Jun 6, 2025)
cd11f01  fix e2e test? (jialei777, Jun 6, 2025)
534b4c9  remove model savibng in the end (jialei777, Jun 7, 2025)
a99f438  test (jialei777, Jun 7, 2025)
cd467a5  ? (jialei777, Jun 7, 2025)
9a94268  ? (jialei777, Jun 7, 2025)
4b11837  update (jialei777, Jun 7, 2025)
649da66  merge main (jialei777, Jun 9, 2025)
2c5c04c  fix unittest (jialei777, Jun 9, 2025)
542d5e8  update (jialei777, Jun 9, 2025)
f8ebb2a  minor fix (jialei777, Jun 9, 2025)
7a1637d  fix e2e test (jialei777, Jun 9, 2025)
974d228  update (jialei777, Jun 9, 2025)
539a74e  update (jialei777, Jun 9, 2025)
a4211ed  a (jialei777, Jun 9, 2025)
870ab16  a (jialei777, Jun 9, 2025)
20dd088  remove rendezvous (jialei777, Jun 9, 2025)
901ed44  remove @torch_xla.compile(full_graph=True) (jialei777, Jun 9, 2025)
71040da  fix (jialei777, Jun 9, 2025)
4ed5df2  a (jialei777, Jun 9, 2025)
17a551e  add back torch_xla.compile(full_graph=True) (jialei777, Jun 10, 2025)
2d6df51  remove save in e2e (jialei777, Jun 10, 2025)
ecb5c8d  save but not profile (jialei777, Jun 10, 2025)
da6d248  update saving (jialei777, Jun 10, 2025)
abc9290  format (jialei777, Jun 10, 2025)
e1919ae  magic (jialei777, Jun 10, 2025)
4336c15  update (jialei777, Jun 10, 2025)
7daa9f4  merge main (jialei777, Jun 10, 2025)
3443898  update (jialei777, Jun 10, 2025)
9ce7458  make saving faster (jialei777, Jun 11, 2025)
21 changes: 21 additions & 0 deletions .github/workflows/e2e_test.yml
@@ -26,6 +26,7 @@ jobs:
llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }}
mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
steps:
@@ -187,6 +188,25 @@ jobs:
task.max_steps=15 \
dcn_mesh.fsdp=2 \
ici_mesh.fsdp=4 \
profile_step=3

- name: Run Llama 3.0 8B SFT
id: run-llama-3-8b-sft
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
run: |
name=$(e2e_testing/gen_name.py llama-3-8b-sft)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
--name $name \
torchprime/torch_xla_models/train.py \
--config-name sft_w_gsm8k \
ici_mesh.fsdp=4 \
task.max_steps=20 \
task.global_batch_size=16 \
task.convert_to_safetensors=False \
profile_start_step=3

# Load reference step times
@@ -232,6 +252,7 @@ jobs:
matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name ||
matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name ||
matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name
}}
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
9 changes: 9 additions & 0 deletions README.md
@@ -130,6 +130,15 @@ python3 torchprime/torch_xla_models/train.py \

You may refer to the hydra docs for other ways to specify configs.

To fine-tune a pretrained model on the GSM8K (Grade School Math) question-answer dataset, run

```sh
python3 torchprime/torch_xla_models/train.py --config-name sft_w_gsm8k
```

This uses the `sft_w_gsm8k.yaml` config, which selects the SFT trainer and
dataset automatically.

### Multi-VM distributed training

`torchprime` uses [xpk][xpk] as the standard path for iterating on distributed
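For context on the README addition above, the composed config can also be inspected or overridden without launching training. The sketch below is illustrative only and not part of the diff; it assumes Hydra's `compose` API and the config path used in this repository, with override keys taken from the configs in this PR.

```python
# Sketch (assumption): resolve the sft_w_gsm8k config with Hydra's compose API,
# overriding a couple of task fields, without starting a training run.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="torchprime/torch_xla_models/configs"):
  cfg = compose(
    config_name="sft_w_gsm8k",
    overrides=["task.max_steps=20", "task.global_batch_size=16"],
  )

print(cfg.dataset.hf_dataset_name)  # gsm8k
print(OmegaConf.to_yaml(cfg.task))  # resolved SFT task settings
```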
7 changes: 7 additions & 0 deletions e2e_testing/step_time_bounds.yaml
@@ -41,6 +41,13 @@ benchmarks:
confidence_interval: 0.12888
average: 3.9587
sample_size: 416
llama-3-8b-sft:
name: Llama 3.0 8B SFT
step_time_lower_bound: 0 # placeholder; replace with measured values once available
step_time_upper_bound: 1 # placeholder
confidence_interval: 0.5 # placeholder
average: 0.5 # placeholder
sample_size: 123 # placeholder
metadata:
query_start: '2025-05-26T18:37:58.674556-07:00'
query_end: '2025-06-05T18:37:58-07:00'
11 changes: 11 additions & 0 deletions e2e_testing/update_step_time.py
@@ -73,13 +73,23 @@ def match_llama_3_8b_2_slice(row):
)


def match_llama_3_8b_sft(row):
config = json.loads(row.configs_framework)
return (
row.run_id.startswith("llama-3-8b-sft")
and config["dcn_mesh"]["fsdp"] == 1
and config["ici_mesh"]["tensor"] == 1
)


BENCHMARKS = {
"Llama 3.0 8B": match_llama3_8b,
"Llama 3.1 8B (Splash Attention)": match_llama3_1_8b_sa,
"Llama 3.1 8B (Scan + Offload)": match_llama3_1_8b_scan_offload,
"Llama 3.0 8B (2D sharding)": match_llama3_8b_2d,
"Mixtral 8x7B": match_mixtral,
"Llama 3.0 8B (2 Slice)": match_llama_3_8b_2_slice,
"Llama 3.0 8B SFT": match_llama_3_8b_sft,
}

STEP_ID_MAPPING = {
@@ -89,6 +99,7 @@ def match_llama_3_8b_2_slice(row):
"Llama 3.0 8B (2D sharding)": "llama-3-8b-2d",
"Mixtral 8x7B": "mixtral-8x7b",
"Llama 3.0 8B (2 Slice)": "llama-3-8b-2-slice",
"Llama 3.0 8B SFT": "llama-3-8b-sft",
}
"""Mapping from the benchmark name to the ID of the E2E test step used in GitHub Actions."""

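To show what the new matcher checks, here is a small illustration run against a made-up row; the row object and its fields are assumptions modeled on the other matchers in this file, not real benchmark data.

```python
# Illustration only (hypothetical row): what match_llama_3_8b_sft evaluates.
import json
from types import SimpleNamespace

row = SimpleNamespace(
  run_id="llama-3-8b-sft-20250611-001",  # made-up run id
  configs_framework=json.dumps({"dcn_mesh": {"fsdp": 1}, "ici_mesh": {"tensor": 1}}),
)

config = json.loads(row.configs_framework)
matched = (
  row.run_id.startswith("llama-3-8b-sft")
  and config["dcn_mesh"]["fsdp"] == 1
  and config["ici_mesh"]["tensor"] == 1
)
print(matched)  # True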
6 changes: 6 additions & 0 deletions torchprime/data/__init__.py
@@ -5,7 +5,13 @@
from .dataset import make_train_dataset
from .sft_dataset import make_sft_dataset

DATASET_BUILDERS = {
"train": make_train_dataset,
"sft": make_sft_dataset,
}

__all__ = [
"DATASET_BUILDERS",
"make_train_dataset",
"make_sft_dataset",
]
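A minimal sketch of how a trainer could dispatch on this new registry; only `DATASET_BUILDERS` and the two builder names come from the diff above, while the wrapper function and its keyword arguments are hypothetical.

```python
# Hypothetical dispatch helper (not part of the diff): pick a dataset builder
# by task name and forward dataset-config kwargs to it.
from torchprime.data import DATASET_BUILDERS


def build_dataset(task_name: str, **dataset_kwargs):
  """Return the dataset produced by the builder registered for `task_name`."""
  try:
    builder = DATASET_BUILDERS[task_name]  # "train" or "sft"
  except KeyError as err:
    raise ValueError(
      f"Unknown task {task_name!r}; expected one of {sorted(DATASET_BUILDERS)}"
    ) from err
  return builder(**dataset_kwargs)
```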
8 changes: 6 additions & 2 deletions torchprime/data/sft_dataset.py
@@ -52,15 +52,19 @@ def _tokenize_prompt_completion(
     Mapping with ``input_ids`` and ``labels`` suitable for training.
   """
 
-  if "prompt" in example or "question" in example:
+  if "prompt" in example and "completion" in example:
     prompt = example.get("prompt", "")
     completion = example.get("completion", "")
-  elif "question" in example or "answer" in example:
+  elif "question" in example and "answer" in example:
     prompt = example.get("question", "")
+    prompt = f"Question:\n{prompt}\n\n\nAnswer:\n"  # Add format for q-a pair
     completion = example.get("answer", "")
   elif "text" in example:
     prompt = ""
     completion = example["text"]
+  elif "completion" in example:
+    prompt = ""
+    completion = example["completion"]
   else:
     raise ValueError(
       "Invalid input format: must contain 'prompt'/'completion' or 'question'/'answer' or 'text' fields."
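To make the new branching concrete, here is a standalone illustration (not part of the diff) of how the supported record formats are split into prompt and completion; it mirrors the logic above but omits the real tokenization path, and the example records are made up.

```python
# Illustration of the prompt/completion routing added above (tokenization omitted).
examples = [
  {"question": "Natalia sold 48 clips...", "answer": "72"},     # GSM8K-style record
  {"prompt": "Translate to French: cat", "completion": "chat"},  # prompt/completion record
  {"text": "Plain pretraining text."},                           # text-only record
]

for example in examples:
  if "prompt" in example and "completion" in example:
    prompt, completion = example["prompt"], example["completion"]
  elif "question" in example and "answer" in example:
    prompt = f"Question:\n{example['question']}\n\n\nAnswer:\n"
    completion = example["answer"]
  elif "text" in example:
    prompt, completion = "", example["text"]
  elif "completion" in example:
    prompt, completion = "", example["completion"]
  else:
    raise ValueError("Unsupported record format")
  print(repr(prompt), "->", repr(completion))
```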
11 changes: 9 additions & 2 deletions torchprime/metrics/step_duration.py
@@ -3,12 +3,15 @@
"""

import glob
import logging
import os
import statistics
import sys

from torchprime.metrics.xplane_pb2 import XSpace # type: ignore

logger = logging.getLogger(__name__)


def step_duration_from_latest_profile(profile_dir: str) -> float:
profile_dir = os.path.abspath(profile_dir)
@@ -66,9 +69,13 @@ def analyze_step_duration_from_pb(xspace: XSpace) -> float:
 
   # Confirm we have exactly one unique event name
   if len(unique_names) > 1:
-    raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}")
+    logger.warning(
+      f"Multiple event names found in XSpace: {unique_names}.\n"
+      "Using the one with max graph nodes for duration calculation."
+    )
+
+  inferred_event_name = max(unique_names)
 
-  inferred_event_name = list(unique_names)[0]
   # Sort offsets to compute consecutive differences
   offsets.sort()
 
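The change above only affects how the event name is inferred; the duration itself still comes from consecutive profile-event offsets, as the surrounding comments indicate. A toy illustration of that idea follows; the numbers and the use of the median are assumptions, not the exact implementation.

```python
# Toy example: derive a step time from sorted event start offsets by taking
# consecutive differences. Values are made up for illustration.
import statistics

offsets_s = sorted([0.00, 1.01, 2.02, 2.99, 4.00])  # hypothetical step starts (seconds)
durations = [later - earlier for earlier, later in zip(offsets_s, offsets_s[1:])]
print(statistics.median(durations))  # ~1.0 s per step
```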
7 changes: 5 additions & 2 deletions torchprime/metrics/tests/test_step_duration.py
@@ -88,8 +88,11 @@ def test_conflicting_step_names():
       event.duration_ps = int(2e12)
     temp.write(xspace.SerializeToString())
     temp.flush()
-    with pytest.raises(ValueError, match="Ambiguous"):
-      analyze_step_duration(temp.name)
+    # with pytest.raises(ValueError, match="Ambiguous"):
+    #   analyze_step_duration(temp.name)
+
+    # Temporarily allow multiple profile names; see issue #260
+    assert analyze_step_duration(temp.name) == 1.0
 
 
 def test_real_profile():
10 changes: 10 additions & 0 deletions torchprime/torch_xla_models/configs/dataset/alpaca.yaml
@@ -0,0 +1,10 @@
# Dataset configuration for supervised fine-tuning using the Alpaca dataset
hf_dataset_name: tatsu-lab/alpaca
hf_dataset_config_name: null
split: train
block_size: 8192
cache_dir: /tmp/
format: prompt_completion
compute_loss_on: completion
pack_samples: true
truncation: right
10 changes: 10 additions & 0 deletions torchprime/torch_xla_models/configs/dataset/gsm8k.yaml
@@ -0,0 +1,10 @@
# Dataset configuration for supervised fine-tuning using the GSM8K dataset
hf_dataset_name: gsm8k
hf_dataset_config_name: main
split: train
block_size: 256
cache_dir: /tmp/
format: prompt_completion
compute_loss_on: completion
pack_samples: false
truncation: drop
@@ -1,5 +1,6 @@
model_id: llama-1b-random-for-test
model_class: llama.LlamaForCausalLM # Used to import the model from this class
pretrained_model: hf-internal-testing/tiny-random-LlamaForCausalLM
vocab_size: 32000
hidden_size: 16
intermediate_size: 64
1 change: 1 addition & 0 deletions torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
@@ -5,6 +5,7 @@ defaults:

model_id: llama-3-8b
model_class: llama.LlamaForCausalLM # Used to import the model from this class
pretrained_model: null
vocab_size: 128256
hidden_size: 4096
intermediate_size: 14336
80 changes: 80 additions & 0 deletions torchprime/torch_xla_models/configs/sft_w_gsm8k.yaml
@@ -0,0 +1,80 @@
# Configuration for supervised fine-tuning using the GSM8K dataset
# Overrides the default dataset and task while reusing the default model

defaults:
- model: llama-3-8b
- dataset: gsm8k
- task: sft
- _self_

task:
global_batch_size: 64
max_steps: 100

seed: 42
logging_steps: 5
torch_dtype: bfloat16

# Set profile_start_step to a positive integer to enable profiling, starting at
# that step. If profile_end_step is not set, profiling continues for
# num_profile_steps (default 20) training steps, or until (total steps - 5)
# to avoid issue #260. Also, try to keep the number of profiled steps >= 10.
profile_start_step: 3
profile_end_step: null

# The directory where profiling data will be stored. This might be overwritten
# when using tp run to launch the run using XPK
profile_dir: profile

# This might be overwritten when using tp run to launch the run using XPK
output_dir: outputs

# The name of the training run as it shows up on tensorboard.
# If unspecified, defaults to the current date and time.
run_name: null

# The virtual device mesh shape to use within a TPU slice. This is also called
# the "ICI mesh", since devices within a slice enjoy a faster network called
# "Inter-Chip Interconnect".
ici_mesh:
data: 1
fsdp: 4
tensor: 1
expert: 1

# Shape of the logical mesh where each element is a TPU slice. This is called
# "Data Center Network (DCN) mesh" because TPU slices are usually connected
# together with slower data center networking, with the faster ICI network
# used within a slice.
#
# As an example, to enable 2-way data parallelism across 2 TPU slices, you may
# specify `dcn_mesh.data=2`.
dcn_mesh:
data: 1
fsdp: 1
tensor: 1
expert: 1

# These are default values for model activation rematerialization configuration.
# They can be overridden on the command line or by importing one of the presets
# in the `model/remat` directory.
model:
pretrained_model: meta-llama/Meta-Llama-3-8B
remat:
# The class names of model layers whose intermediate activations should be
# recomputed during the backward pass (i.e. activation checkpointing).
activation_checkpoint_layers: []

# If not null, compile a module of type `HomogeneousSequential` located at the
# given path in the module tree using `torch_xla.experimental.scan_layers`.
scan_layers: null

# If specified, offload these tensors to host RAM during the forward pass and
# move them back during the backward pass.
#
# The tensors to be offloaded should be given a name by wrapping them with the
# `torchprime.torch_xla_models.offloading.offload_name` call. Then the same
# name could be specified here to offload that tensor.
#
# Currently in order to offload tensors, `scan_layers` must also be enabled.
offload_tensors: []
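As a rough mental model for the two mesh sections above, each logical axis spans the product of its ICI and DCN entries. The snippet below illustrates that bookkeeping for this config's defaults; the product rule is an assumption about how torchprime composes the meshes, not code from this PR.

```python
# Rough bookkeeping sketch: combine ICI and DCN mesh axes into a global shape.
ici_mesh = {"data": 1, "fsdp": 4, "tensor": 1, "expert": 1}
dcn_mesh = {"data": 1, "fsdp": 1, "tensor": 1, "expert": 1}

global_mesh = {axis: ici_mesh[axis] * dcn_mesh[axis] for axis in ici_mesh}
total_devices = 1
for size in global_mesh.values():
  total_devices *= size

print(global_mesh)    # {'data': 1, 'fsdp': 4, 'tensor': 1, 'expert': 1}
print(total_devices)  # 4 chips for the e2e test's ici_mesh.fsdp=4 setting
```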
14 changes: 14 additions & 0 deletions torchprime/torch_xla_models/configs/task/sft.yaml
@@ -0,0 +1,14 @@
# Task configuration for supervised fine-tuning
name: sft
global_batch_size: 16
max_steps: 20
export_checkpoint_path: export
convert_to_safetensors: True
max_grad_norm: 1.0
max_grad_value: null
optimizer:
learning_rate: 4.e-5
type: adafactor
lr_scheduler:
type: linear
warmup_steps: 10
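To illustrate what the `linear` scheduler with `warmup_steps: 10` implies for this task's default 20-step run, here is a sketch of a linear warmup followed by linear decay; the exact scheduler implementation in torchprime may differ.

```python
# Sketch of a linear warmup + linear decay schedule matching the fields above
# (learning_rate=4e-5, warmup_steps=10, and the task default max_steps=20).
def linear_lr(step, base_lr=4e-5, warmup_steps=10, max_steps=20):
  if step < warmup_steps:
    return base_lr * (step + 1) / warmup_steps      # ramp up
  span = max(max_steps - warmup_steps, 1)
  return base_lr * max(max_steps - step, 0) / span  # decay toward 0

for step in (0, 5, 10, 15, 19):
  print(step, f"{linear_lr(step):.2e}")
```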