Commit

scan and apply_layers
Add the lowering of scan to the HLO While op.

Introduce apply_layers, which can sequentially apply a stack of layers
using scan underneath.

Beef up unit tests, including linear layers and decoders.

Add a regression test for parameter_id_tensor_mapping.

Add test_apply_layers.py to the test shell scripts.

Correctly import the decoder model from examples.
tengyifei committed Nov 8, 2024
1 parent 81c4caa commit e6188d8
Showing 12 changed files with 1,048 additions and 47 deletions.
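
As described in the commit message above, `apply_layers` sequentially applies a list of identically-structured layers, and the loop is lowered to a single HLO While op via `scan` instead of being unrolled in the traced graph. Below is a minimal usage sketch pieced together from the tests added in this commit (the import path, `torch_xla.device()`, `torch_xla.sync()`, and the HLO check all mirror `test/test_apply_layers.py`); it is illustrative only and not part of the diff itself.

# Illustrative sketch only, based on test/test_apply_layers.py from this commit.
import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.apply_layers import apply_layers

device = torch_xla.device()

# Ten layers with identical parameter shapes, applied in sequence.
# The loop is lowered to one HLO While op via `scan` rather than
# being unrolled layer by layer in the traced graph.
layers = [nn.Linear(64, 64).to(device) for _ in range(10)]
x = torch.randn(64).to(device)
y = apply_layers(layers, x)

# Before materialization, the lazily traced HLO should contain the While op.
hlo = torch_xla._XLAC._get_xla_tensors_hlo([y])
assert "while(" in hlo

# Gradients flow through the scanned layers just as with a plain Python loop.
y.sum().backward()
torch_xla.sync()

The new tests also pin down the constraint this relies on: every layer must expose the same set of parameters with matching shapes, otherwise `apply_layers` raises a `ValueError`.
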
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
@@ -7,16 +7,16 @@
 from torch import nn


-# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
+# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
 @dataclass
 class DecoderOnlyConfig:
   hidden_size: int = 1024
   num_hidden_layers: int = 2
   num_attention_heads: int = 8
   num_key_value_heads: int = 4
-  intermediate_size = 32 * 1024
-  vocab_size = 3200
-  use_flash_attention = False
+  intermediate_size: int = 32 * 1024
+  vocab_size: int = 3200
+  use_flash_attention: bool = False


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
   run_test "$CDIR/test_scan.py"
+  run_test "$CDIR/test_apply_layers.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
208 changes: 208 additions & 0 deletions test/test_apply_layers.py
@@ -0,0 +1,208 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(
sys.argv[0]))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore

import unittest
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.apply_layers import apply_layers

from test_utils import XlaTestCase # type:ignore


class ApplyLayersTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor):
assert a is not b, f"Expected {a} and {b} to be different tensors"
assert a.data is not b.data, f"Expected {a} and {b} to have different storage"

def assert_while_found_in_hlo(self, tensor: torch.Tensor):
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
assert "while(" in hlo_text
assert "condition=" in hlo_text
assert "body=" in hlo_text

def test_empty_layers(self):
layers = []
input_data = torch.randn(64).to(self.device)
output = apply_layers(layers, input_data.clone())
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.001)

def test_linear_layers(self):
# Fix the random seed to avoid flakes.
with torch.random.fork_rng():
with torch_xla.xm.fork_rng():
torch.random.manual_seed(42)
torch_xla.xm.set_rng_state(42)
# We want to apply these layers sequentially
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)
torch_xla.sync(wait=True)

scan_layers = deepcopy(layers)
loop_layers = deepcopy(layers)
torch_xla.sync()

output = apply_layers(scan_layers, input_data.clone())
output.sum().backward()

# Test that the result is the same as for loop.
loop_output = input_data.clone()
for layer in loop_layers:
loop_output = layer(loop_output)
torch_xla.sync()

super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.001)
self.assert_different_tensor(loop_output, output)

loop_output.sum().backward()
torch_xla.sync()

# Test that the gradients are the same too.
for layer_scan, layer_loop in zip(scan_layers, loop_layers):
assert layer_scan.weight.grad is not None
assert layer_loop.weight.grad is not None
assert layer_scan.bias.grad is not None
assert layer_loop.bias.grad is not None
super().compareResults(
layer_scan.weight.grad,
layer_loop.weight.grad,
abs_err=0.0001,
rel_err=0.001)
super().compareResults(
layer_scan.bias.grad,
layer_loop.bias.grad,
abs_err=0.0001,
rel_err=0.001)
self.assert_different_tensor(layer_scan.weight.grad,
layer_loop.weight.grad)
self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad)

def test_decoder_model(self):
# Define a decoder model that composes the decoder model in the example,
# but adds the ability to run the layers with the `scan` operator.
class DecoderOnlyModelWithScan(torch.nn.Module):

def __init__(self, **kwargs):
super(DecoderOnlyModelWithScan, self).__init__()
self.decoder = DecoderOnlyModel(**kwargs)

@property
def layers(self) -> Iterable[torch.nn.Module]:
return self.decoder.layers

def forward(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.decoder.forward(input_ids)

def forward_scan(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
inputs_embeds = self.decoder.embed_tokens(input_ids)
# embed positions
assert isinstance(inputs_embeds, torch.Tensor)
# decoder layers
hidden_states = apply_layers(self.decoder.layers, inputs_embeds)
hidden_states = self.decoder.norm(hidden_states)
# [B, S, H] -> [B, S, V]
return self.decoder.output(hidden_states)

# Fix the random seed to avoid flakes.
with torch.random.fork_rng():
with torch_xla.xm.fork_rng():
torch.random.manual_seed(42)
torch_xla.xm.set_rng_state(42)

# Make it smaller for fast model run and comparisons.
config = DecoderOnlyConfig(
hidden_size=128, intermediate_size=8 * 128, vocab_size=256)
model = DecoderOnlyModelWithScan(config=config).to(self.device)
batch_size = 2
sequence_length = 8

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size,
(batch_size, sequence_length)).to(self.device)

loop_model = deepcopy(model)
scan_model = deepcopy(model)
torch_xla.sync(wait=True)

# Run the loop-based model.
loop_output = loop_model(input_ids.clone())
loop_output.sum().backward()
torch_xla.sync()

# Run again, this time using `scan`
scan_output = scan_model.forward_scan(input_ids.clone())
scan_output.sum().backward()

# Before materializing the tensors, check that tensor HLO has `While` in it.
self.assert_while_found_in_hlo(scan_output)
for layer_scan in scan_model.layers:
for (name, param_scan) in layer_scan.named_parameters():
if param_scan.grad is not None:
self.assert_while_found_in_hlo(param_scan.grad)

torch_xla.sync()

# Compare results
super().compareResults(
scan_output, loop_output, abs_err=0.0001, rel_err=0.0001)

# Check gradients
checks = 0
for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers):
for (name,
param_scan), (name2,
param_loop) in zip(layer_scan.named_parameters(),
layer_loop.named_parameters()):
assert name == name2
# Either the parameter should have gradient in both, or it should not
# have gradient in both.
assert (param_scan.grad is not None) == (param_loop.grad is not None)
# Check gradients
if param_scan.grad is not None and param_loop.grad is not None:
# Check that they are not the same tensor
assert id(param_scan.grad) != id(param_loop.grad)
assert id(param_scan.grad.untyped_storage()) != id(
param_loop.grad.untyped_storage())
super().compareResults(
param_scan.grad, param_loop.grad, abs_err=0.0001, rel_err=0.0001)
checks = checks + 1
assert checks > 0

def test_heterogenous_layers(self):
layer1 = nn.Linear(128, 128).to(torch_xla.device())
layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device()))
with self.assertRaisesRegex(ValueError, "mismatched set of parameters"):
apply_layers([layer1, layer2],
torch.zeros((128,), device=torch_xla.device()))

def test_mismatched_shapes(self):
layer1 = nn.Linear(128, 128).to(torch_xla.device())
layer2 = nn.Linear(128, 129).to(torch_xla.device())
with self.assertRaisesRegex(ValueError, "Shape mismatch"):
apply_layers([layer1, layer2],
torch.zeros((128,), device=torch_xla.device()))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
24 changes: 24 additions & 0 deletions test/test_operations.py
@@ -21,6 +21,7 @@
 import itertools
 import math
 from numbers import Number
+from functools import reduce
 import numpy
 import random
 import re
@@ -2597,6 +2598,29 @@ def test_api(self):
     mapping = ctx.parameter_id_tensor_mapping()
     self.assertEqual(len(mapping), 2)

+  def test_get_parameters_scalar(self):
+    """Scalar tensor parameters may be shared in the HLO graph if their
+    numerical values are equal. `parameter_id_tensor_mapping` needs to handle
+    that appropriately.
+    """
+
+    device = torch_xla.device()
+    tensors = []
+    for i in range(10):
+      # Add three copies of the same value.
+      tensors.append(torch.tensor(i, device=device))
+      tensors.append(torch.tensor(i, device=device))
+      tensors.append(torch.tensor(i, device=device))
+    result = reduce(lambda a, b: a + b, tensors)
+    ctx = torch_xla._XLAC.lowering.LoweringContext()
+    ctx.build([result])
+    mapping = ctx.parameter_id_tensor_mapping()
+
+    import json
+    hlo_json = json.loads(ctx.hlo_json())
+    num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
+    self.assertEqual(len(mapping), num_parameters)
+
 
 class TestGeneric(test_utils.XlaTestCase):

(The remaining 8 changed files are not shown here.)