diff --git a/example_text_completion.py b/example_text_completion.py
index fc084b019..077921430 100755
--- a/example_text_completion.py
+++ b/example_text_completion.py
@@ -25,17 +25,22 @@ def main(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if not USE_CUDA:
-        server = xp.start_server(9012, only_on_master=False)
+        # server = xp.start_server(9012, only_on_master=False)
+        pass
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
         tokenizer_path=tokenizer_path,
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
         dynamo=dynamo,
+        spmd=spmd,
     )
 
+    print(f'[WONJOO] max_batch_size={max_batch_size}')
+
     prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt
         "I believe the meaning of life is",
@@ -53,6 +58,13 @@ def main(
         # plush girafe => girafe peluche
         # cheese =>""",
     ]
+
+    import time
+    print("About to start in 15 seconds")
+    server = xp.start_server(9012, only_on_master=False)
+    time.sleep(15)
+    print("Starting!")
+
     for _ in range(2):
         with torch.no_grad():
             results = generator.text_completion(
@@ -66,6 +78,8 @@ def main(
             print(f"> {result['generation']}")
             print("\n==================================\n")
 
+    print("Finished!")
+
 
 def _fn(
     idx,
@@ -77,12 +91,13 @@ def _fn(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if USE_CUDA:
         os.environ['WORLD_SIZE'] = torch.cuda.device_count()
         os.environ['RANK'] = idx
         os.environ['LOCAL_RANK'] = idx
-    main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
+    main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)
 
 
 def mp_main(
@@ -95,6 +110,7 @@ def mp_main(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if mp:
         if USE_CUDA:
@@ -103,9 +119,9 @@ def mp_main(
         else:
             kwargs = {}
         xmp.spawn(_fn,
-                  args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo), **kwargs)
+                  args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd), **kwargs)
     else:
-        main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
+        main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)
 
 
 if __name__ == "__main__":
diff --git a/llama/generation.py b/llama/generation.py
index 043f188c7..2c00de05c 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -20,6 +20,9 @@
 # Some how xla init will slow down the CUDA speed.
 if not USE_CUDA:
     import torch_xla.core.xla_model as xm
+    import torch_xla.experimental.xla_sharding as xs
+    from torch_xla import runtime as xr
+    import numpy as np
 
 Role = Literal["system", "user", "assistant"]
 
@@ -60,6 +63,7 @@ def build(
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
         dynamo: bool = True,
+        spmd: bool = True,
     ) -> "Llama":
         # if not model_parallel_is_initialized():
         #     if model_parallel_size is None:
@@ -106,6 +110,9 @@ def build(
             max_batch_size=max_batch_size,
             **params,
         )
+
+        model_args.print_values()
+
         tokenizer = Tokenizer(model_path=tokenizer_path)
         model_args.vocab_size = tokenizer.n_words
         if USE_CUDA:
@@ -118,14 +125,65 @@ def build(
             model = model.to(device)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
-        return Llama(model, tokenizer, device, dynamo)
+        return Llama(model, tokenizer, device, dynamo, spmd)
 
-    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True):
+    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True, spmd: bool = True):
         self.model = model
         self.tokenizer = tokenizer
         self.device = device
-        self._generate_one_token_fn = self._generate_one_token
+
+        if spmd:
+            num_devices = xr.global_runtime_device_count()  # updated way to get device count
+            # num_devices = 8  # hard-coded for v5-8
+            device_ids = np.arange(num_devices)
+            x_dim = 2  # hard-coded for v5-8
+            yz_dim = 4  # hard-coded for v5-8
+
+            # manually shard the kv cache
+            four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim))
+            for layer in model.layers:
+                xs.mark_sharding(layer.attention.cache_k, four_d_mesh, (0, 1, 2, None))
+                xs.mark_sharding(layer.attention.cache_v, four_d_mesh, (0, 1, 2, None))
+
+            col_mesh = xs.Mesh(device_ids, (1, num_devices))
+            row_mesh = xs.Mesh(device_ids, (num_devices, 1))
+
+            for name, layer in model.named_modules():
+                if 'tok_embeddings' in name:
+                    xs.mark_sharding(layer.weight, row_mesh, (0, 1))
+                if 'attention.' in name:
+                    if 'wo' in name:
+                        xs.mark_sharding(layer.weight, row_mesh, (0, 1))
+                    else:
+                        xs.mark_sharding(layer.weight, col_mesh, (0, 1))
+                if 'feed_forward.' in name:
+                    if 'w2' in name:
+                        xs.mark_sharding(layer.weight, row_mesh, (0, 1))
+                    else:
+                        xs.mark_sharding(layer.weight, col_mesh, (0, 1))
+                if 'output' in name:
+                    xs.mark_sharding(layer.weight, col_mesh, (0, 1))
+
+            # Sharding strategy for 2D sharding
+            # two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim))
+            # two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim))
+            # for name, layer in model.named_modules():
+            #     if 'tok_embeddings' in name:
+            #         xs.mark_sharding(layer.weight, row_mesh, (0, 1))
+            #     if 'attention.' in name:
+            #         if 'wo' in name:
+            #             xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
+            #         else:
+            #             xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
+            #     if 'feed_forward.' in name:
+            #         if 'w2' in name:
+            #             xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
+            #         else:
+            #             xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
+            #     if 'output' in name:
+            #         xs.mark_sharding(layer.weight, col_mesh, (0, 1))
+
 
         if dynamo:
             if USE_CUDA:
                 # Inductor errors out when compiles _generate_one_token_fn.
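The `__init__` changes above annotate the KV caches and the projection weights with SPMD sharding specs through `torch_xla`'s `xla_sharding` module. The sketch below is a minimal, self-contained illustration of that `Mesh`/`mark_sharding` API on a toy two-layer block, assuming an SPMD-capable TPU host; `ToyBlock`, its dimensions, and the `xr.use_spmd()` call are assumptions for the example and not part of the patch, which instead hard-codes mesh dimensions for a v5-8 slice.

```python
# Minimal SPMD sharding sketch (illustrative only, not the patch's model).
# Assumes a recent torch_xla on a TPU host; on older builds SPMD is enabled via
# the XLA_USE_SPMD=1 environment variable instead of xr.use_spmd().
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla import runtime as xr

xr.use_spmd()  # turn on SPMD mode before any sharding annotations


class ToyBlock(nn.Module):
    """Stand-in for one transformer layer with two projections."""

    def __init__(self, dim: int = 1024, hidden: int = 4096):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


device = xm.xla_device()
model = ToyBlock().to(device)

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# Two 1D meshes over the same devices, mirroring col_mesh/row_mesh in the diff.
# The partition spec maps tensor dims to mesh axes, so the tensor dim mapped to
# the axis of size num_devices is the one that gets split.
col_mesh = xs.Mesh(device_ids, (1, num_devices))   # splits dim 1 of a 2D weight
row_mesh = xs.Mesh(device_ids, (num_devices, 1))   # splits dim 0 of a 2D weight

# nn.Linear stores weight as (out_features, in_features).
xs.mark_sharding(model.w1.weight, row_mesh, (0, 1))  # shard w1 along out_features
xs.mark_sharding(model.w2.weight, col_mesh, (0, 1))  # shard w2 along in_features

x = torch.randn(4, 1024, device=device)
out = model(x)   # XLA partitions the matmuls according to the annotations
xm.mark_step()   # materialize the sharded computation
print(out.shape)
```

The patch applies the same mechanism to the 4D KV cache tensors, where a `None` entry in the partition spec leaves that tensor dimension unsharded.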
diff --git a/llama/model.py b/llama/model.py
index 88cc7ca1d..e9482c9eb 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -35,6 +35,15 @@ class ModelArgs:
     max_seq_len: int = 2048
     quant: bool = False
 
+    def print_values(self):
+        print(f'[WONJOO] ModelArgs')
+        print(f'[WONJOO] dim={self.dim}')
+        print(f'[WONJOO] n_layers={self.n_layers}')
+        print(f'[WONJOO] n_heads={self.n_heads}')
+        print(f'[WONJOO] max_batch_size={self.max_batch_size}')
+        print(f'[WONJOO] max_seq_len={self.max_seq_len}')
+        print(f'[WONJOO] quant={self.quant}')
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
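For the profiling change in the example_text_completion.py hunk above, the `xp.start_server` call is deferred until just before generation and followed by a 15-second sleep so an external profiler client has time to attach. A rough sketch of that capture pattern, assuming `torch_xla.debug.profiler` is imported as `xp`; `run_generation()` and its toy workload are placeholders for the script's actual `text_completion` loop.

```python
# Sketch of the delayed profiler-attach pattern, with a placeholder workload.
import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp


def run_generation():
    # Placeholder for generator.text_completion(...); any XLA work can be traced.
    device = xm.xla_device()
    x = torch.randn(4, 4096, device=device)
    for _ in range(10):
        x = torch.relu(x @ x.T) @ x
        xm.mark_step()


if __name__ == "__main__":
    # Bring the profiler server up only once the model is built, then pause so a
    # client (e.g. xp.trace(...) from another process, or TensorBoard's capture
    # button pointed at localhost:9012) can connect before the traced region runs.
    server = xp.start_server(9012, only_on_master=False)
    print("About to start in 15 seconds")
    time.sleep(15)
    print("Starting!")
    run_generation()
    print("Finished!")
```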