|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# This software may be used and distributed according to the terms of the GNU General Public License version 3. |
| 3 | + |
| 4 | +from typing import Tuple |
| 5 | +import os |
| 6 | +import sys |
| 7 | +import torch |
| 8 | +import fire |
| 9 | +import time |
| 10 | +import math |
| 11 | +import torch_xla.core.xla_model as xm |
| 12 | +import torch_xla.debug.metrics as met |
| 13 | +import torch_xla.debug.profiler as xp |
| 14 | +import json |
| 15 | + |
| 16 | +from pathlib import Path |
| 17 | + |
| 18 | +from llama import ModelArgs, Transformer, Tokenizer, Llama |
| 19 | + |
| 20 | +# TODO(yeounoh) import packages for PyTorch/XLA GSPMD |
| 21 | +import numpy as np |
| 22 | +import torch_xla.experimental.xla_sharding as xs |
| 23 | +import torch_xla.experimental.pjrt as pjrt |
| 24 | + |
| 25 | +# For xr.global_runtime_device_count() |
| 26 | +from torch_xla import runtime as xr |
| 27 | + |
| 28 | +def init( |
| 29 | + tokenizer_path: str, |
| 30 | + max_seq_len: int, |
| 31 | + max_batch_size: int, |
| 32 | + dim: int = 4096, |
| 33 | + n_layers: int = 32, |
| 34 | + n_heads: int = 32, |
| 35 | +) -> Llama: |
| 36 | + start_time = time.time() |
| 37 | + # checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) |
| 38 | + # TODO the checkpoint for large models seems to be sharded as well |
| 39 | + # assert world_size == len( |
| 40 | + # checkpoints |
| 41 | + # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" |
| 42 | + # ckpt_path = checkpoints[rank] |
| 43 | + print("Loading") |
| 44 | + # checkpoint = torch.load(ckpt_path, map_location="cpu") |
| 45 | + # with open(Path(ckpt_dir) / "params.json", "r") as f: |
| 46 | + # params = json.loads(f.read()) |
| 47 | + params = {"dim": dim, |
| 48 | + "n_layers": n_layers, |
| 49 | + "n_heads": n_heads, |
| 50 | + } |
| 51 | + model_args: ModelArgs = ModelArgs( |
| 52 | + max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params |
| 53 | + ) |
| 54 | + tokenizer = Tokenizer(model_path=tokenizer_path) |
| 55 | + model_args.vocab_size = tokenizer.n_words |
| 56 | + # torch.set_default_tensor_type(torch.cuda.HalfTensor) # TODO: this line puts the model to cuda device |
| 57 | + torch.set_default_tensor_type(torch.BFloat16Tensor) |
| 58 | + model = Transformer(model_args) |
| 59 | + device = xm.xla_device() |
| 60 | + model = model.to(device) |
| 61 | + |
| 62 | + # for i in range(len(model.cache_kvs)): |
| 63 | + # model.cache_kvs[i] = tuple(t.to(device) for t in model.cache_kvs[i]) |
| 64 | + # torch.set_default_tensor_type(torch.FloatTensor) |
| 65 | + |
| 66 | + # model.load_state_dict(checkpoint, strict=False) |
| 67 | + |
| 68 | + # num_devices = pjrt.global_device_count() |
| 69 | + num_devices = xr.global_runtime_device_count() # updated way to get device count |
| 70 | + device_ids = np.arange(num_devices) |
| 71 | + |
| 72 | + x_dim = math.isqrt(num_devices) // 2 |
| 73 | + yz_dim = 2 * math.isqrt(num_devices) |
| 74 | + |
| 75 | + col_mesh = xs.Mesh(device_ids, (1, num_devices)) |
| 76 | + row_mesh = xs.Mesh(device_ids, (num_devices, 1)) |
| 77 | + |
| 78 | + print(f'[WONJOO] device_ids={device_ids}') |
| 79 | + print(f'[WONJOO] x_dim={x_dim}') |
| 80 | + print(f'[WONJOO] yz_dim={yz_dim}') |
| 81 | + two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) |
| 82 | + two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim)) |
| 83 | + |
| 84 | + for name, layer in model.named_modules(): |
| 85 | + if 'tok_embeddings' in name: |
| 86 | + xs.mark_sharding(layer.weight, row_mesh, (0, 1)) |
| 87 | + if 'attention.' in name: |
| 88 | + if 'wo' in name: |
| 89 | + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) |
| 90 | + else: |
| 91 | + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) |
| 92 | + if 'feed_forward.' in name: |
| 93 | + if 'w2' in name: |
| 94 | + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) |
| 95 | + else: |
| 96 | + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) |
| 97 | + if 'output' in name: |
| 98 | + xs.mark_sharding(layer.weight, col_mesh, (0, 1)) |
| 99 | + |
| 100 | + # TODO(yeounoh) shard cache_kvs before LLaMA init |
| 101 | + # col_mesh = xs.Mesh(device_ids, (1, 1, num_devices, 1)) |
| 102 | + # for i in range(len(model.cache_kvs)): |
| 103 | + # for t in model.cache_kvs[i]: |
| 104 | + # xs.mark_sharding(t, col_mesh, (0,1,2,3)) |
| 105 | + |
| 106 | + generator = LLaMA(model, tokenizer) |
| 107 | + print(generator) |
| 108 | + print(f"Loaded in {time.time() - start_time:.2f} seconds") |
| 109 | + return generator |
| 110 | + |
| 111 | + |
| 112 | +def main( |
| 113 | + tokenizer_path: str, |
| 114 | + temperature: float = 0.8, |
| 115 | + top_p: float = 0.95, |
| 116 | + max_seq_len: int = 512, |
| 117 | + max_batch_size: int = 32, |
| 118 | + dim: int = 4096, |
| 119 | + n_layers: int = 32, |
| 120 | + n_heads: int = 32, |
| 121 | +): |
| 122 | + server = xp.start_server(9012, only_on_master=False) |
| 123 | + torch.manual_seed(1) |
| 124 | + generator = init( |
| 125 | + tokenizer_path, max_seq_len, max_batch_size, dim, n_layers, n_heads |
| 126 | + ) |
| 127 | + |
| 128 | + prompts = [ |
| 129 | + # For these prompts, the expected answer is the natural continuation of the prompt |
| 130 | + "I believe the meaning of life is", |
| 131 | + #"Simply put, the theory of relativity states that ", |
| 132 | + #"Building a website can be done in 10 simple steps:\n", |
| 133 | + # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api |
| 134 | +# """Tweet: "I hate it when my phone battery dies." |
| 135 | +#Sentiment: Negative |
| 136 | +#### |
| 137 | +#Tweet: "My day has been 👍" |
| 138 | +#Sentiment: Positive |
| 139 | +#### |
| 140 | +#Tweet: "This is the link to the article" |
| 141 | +#Sentiment: Neutral |
| 142 | +#### |
| 143 | +#Tweet: "This new music video was incredibile" |
| 144 | +#Sentiment:""", |
| 145 | +# """Translate English to French: |
| 146 | +# |
| 147 | +#sea otter => loutre de mer |
| 148 | +# |
| 149 | +#peppermint => menthe poivrée |
| 150 | +# |
| 151 | +#plush girafe => girafe peluche |
| 152 | +# |
| 153 | +#cheese =>""", |
| 154 | + ] |
| 155 | + with torch.no_grad(): |
| 156 | + results = generator.generate( |
| 157 | + prompts, max_gen_len=1, temperature=temperature, top_p=top_p |
| 158 | + ) |
| 159 | + with torch.no_grad(): |
| 160 | + results = generator.generate( |
| 161 | + prompts, max_gen_len=256, temperature=temperature, top_p=top_p |
| 162 | + ) |
| 163 | + if xm.is_master_ordinal(local=False): |
| 164 | + for result in results: |
| 165 | + print(result) |
| 166 | + print("\n==================================\n") |
| 167 | + |
| 168 | + |
| 169 | +if __name__ == "__main__": |
| 170 | + fire.Fire(main) |
| 171 | + # print(met.metrics_report()) |
0 commit comments