
Commit e362707

Rebase and add example_2d_sharding.py
1 parent e795c68 commit e362707

File tree

1 file changed: +171 -0 lines changed

example_2d_sharding.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import torch
import fire
import time
import math
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import json

from pathlib import Path

from llama import ModelArgs, Transformer, Tokenizer, Llama

# TODO(yeounoh) import packages for PyTorch/XLA GSPMD
import numpy as np
import torch_xla.experimental.xla_sharding as xs
import torch_xla.experimental.pjrt as pjrt

# For xr.global_runtime_device_count()
from torch_xla import runtime as xr

def init(
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
) -> Llama:
    start_time = time.time()
    # checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    # TODO the checkpoint for large models seems to be sharded as well
    # assert world_size == len(
    #     checkpoints
    # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    # ckpt_path = checkpoints[rank]
    print("Loading")
    # checkpoint = torch.load(ckpt_path, map_location="cpu")
    # with open(Path(ckpt_dir) / "params.json", "r") as f:
    #     params = json.loads(f.read())
    params = {"dim": dim,
              "n_layers": n_layers,
              "n_heads": n_heads,
              }
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # torch.set_default_tensor_type(torch.cuda.HalfTensor)  # TODO: this line puts the model on the CUDA device
    torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)
    device = xm.xla_device()
    model = model.to(device)

    # for i in range(len(model.cache_kvs)):
    #     model.cache_kvs[i] = tuple(t.to(device) for t in model.cache_kvs[i])
    # torch.set_default_tensor_type(torch.FloatTensor)

    # model.load_state_dict(checkpoint, strict=False)

    # num_devices = pjrt.global_device_count()
    num_devices = xr.global_runtime_device_count()  # updated way to get the device count
    device_ids = np.arange(num_devices)

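    # Factor the device count into a 2D logical mesh. As written this assumes
    # num_devices is a perfect square with an even square root (e.g. 16 -> 2x8,
    # 64 -> 4x16), so that x_dim * yz_dim == num_devices.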
    x_dim = math.isqrt(num_devices) // 2
    yz_dim = 2 * math.isqrt(num_devices)

    col_mesh = xs.Mesh(device_ids, (1, num_devices))
    row_mesh = xs.Mesh(device_ids, (num_devices, 1))

    print(f'[WONJOO] device_ids={device_ids}')
    print(f'[WONJOO] x_dim={x_dim}')
    print(f'[WONJOO] yz_dim={yz_dim}')
    two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim))
    two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim))

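    # xs.mark_sharding pairs each tensor dimension with a mesh axis via the
    # partition spec: with spec (0, 1), weight dim 0 is split across mesh axis 0
    # and dim 1 across mesh axis 1. The output projections (attention.wo and
    # feed_forward.w2) use the transposed (yz, x) mesh, mirroring the (x, yz)
    # sharding of the projections that feed them.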
    for name, layer in model.named_modules():
        if 'tok_embeddings' in name:
            xs.mark_sharding(layer.weight, row_mesh, (0, 1))
        if 'attention.' in name:
            if 'wo' in name:
                xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
            else:
                xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
        if 'feed_forward.' in name:
            if 'w2' in name:
                xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
            else:
                xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
        if 'output' in name:
            xs.mark_sharding(layer.weight, col_mesh, (0, 1))

    # TODO(yeounoh) shard cache_kvs before LLaMA init
    # col_mesh = xs.Mesh(device_ids, (1, 1, num_devices, 1))
    # for i in range(len(model.cache_kvs)):
    #     for t in model.cache_kvs[i]:
    #         xs.mark_sharding(t, col_mesh, (0, 1, 2, 3))

    generator = Llama(model, tokenizer)
    print(generator)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def main(
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
):
    server = xp.start_server(9012, only_on_master=False)
    torch.manual_seed(1)
    generator = init(
        tokenizer_path, max_seq_len, max_batch_size, dim, n_layers, n_heads
    )

    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        #"Simply put, the theory of relativity states that ",
        #"Building a website can be done in 10 simple steps:\n",
        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
        # """Tweet: "I hate it when my phone battery dies."
        #Sentiment: Negative
        ####
        #Tweet: "My day has been 👍"
        #Sentiment: Positive
        ####
        #Tweet: "This is the link to the article"
        #Sentiment: Neutral
        ####
        #Tweet: "This new music video was incredibile"
        #Sentiment:""",
        # """Translate English to French:
        #
        #sea otter => loutre de mer
        #
        #peppermint => menthe poivrée
        #
        #plush girafe => girafe peluche
        #
        #cheese =>""",
    ]
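    # A max_gen_len=1 pass runs first, presumably as a warm-up to trigger XLA
    # compilation before the full 256-token generation below.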
    with torch.no_grad():
        results = generator.generate(
            prompts, max_gen_len=1, temperature=temperature, top_p=top_p
        )
    with torch.no_grad():
        results = generator.generate(
            prompts, max_gen_len=256, temperature=temperature, top_p=top_p
        )
    if xm.is_master_ordinal(local=False):
        for result in results:
            print(result)
            print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)
    # print(met.metrics_report())
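
Usage note: because the script hands main to fire.Fire, the keyword arguments
of main double as command-line flags. A minimal invocation sketch (the
tokenizer path is a placeholder):

    python example_2d_sharding.py --tokenizer_path=/path/to/tokenizer.model --max_seq_len=512 --max_batch_size=32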

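For reference, a self-contained sketch of the same GSPMD annotation pattern on
a single tensor, assuming a torch_xla build of the same vintage where
xla_sharding still lives under torch_xla.experimental (the 4096x4096 weight
and the 2-row mesh shape are illustrative only):

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs
    from torch_xla import runtime as xr

    num_devices = xr.global_runtime_device_count()
    device_ids = np.arange(num_devices)

    # A 2D logical mesh, e.g. (2, 4) on 8 devices; assumes num_devices is even.
    mesh = xs.Mesh(device_ids, (2, num_devices // 2))

    # Shard dim 0 across mesh axis 0 and dim 1 across mesh axis 1.
    w = torch.randn(4096, 4096).to(xm.xla_device())
    xs.mark_sharding(w, mesh, (0, 1))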