From 8307bf308b2ad8beaf69cd2ee5d6967b8aea00e1 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Fri, 28 Jul 2023 20:54:53 +0000 Subject: [PATCH 1/5] Rebase and add example_2d_sharding.py --- example_2d_sharding.py | 171 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 example_2d_sharding.py diff --git a/example_2d_sharding.py b/example_2d_sharding.py new file mode 100644 index 000000000..a4cb7e689 --- /dev/null +++ b/example_2d_sharding.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. + +from typing import Tuple +import os +import sys +import torch +import fire +import time +import math +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met +import torch_xla.debug.profiler as xp +import json + +from pathlib import Path + +from llama import ModelArgs, Transformer, Tokenizer, Llama + +# TODO(yeounoh) import packages for PyTorch/XLA GSPMD +import numpy as np +import torch_xla.experimental.xla_sharding as xs +import torch_xla.experimental.pjrt as pjrt + +# For xr.global_runtime_device_count() +from torch_xla import runtime as xr + +def init( + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + dim: int = 4096, + n_layers: int = 32, + n_heads: int = 32, +) -> Llama: + start_time = time.time() + # checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + # TODO the checkpoint for large models seems to be sharded as well + # assert world_size == len( + # checkpoints + # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" + # ckpt_path = checkpoints[rank] + print("Loading") + # checkpoint = torch.load(ckpt_path, map_location="cpu") + # with open(Path(ckpt_dir) / "params.json", "r") as f: + # params = json.loads(f.read()) + params = {"dim": dim, + "n_layers": n_layers, + "n_heads": n_heads, + } + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + # torch.set_default_tensor_type(torch.cuda.HalfTensor) # TODO: this line puts the model to cuda device + torch.set_default_tensor_type(torch.BFloat16Tensor) + model = Transformer(model_args) + device = xm.xla_device() + model = model.to(device) + + # for i in range(len(model.cache_kvs)): + # model.cache_kvs[i] = tuple(t.to(device) for t in model.cache_kvs[i]) + # torch.set_default_tensor_type(torch.FloatTensor) + + # model.load_state_dict(checkpoint, strict=False) + + # num_devices = pjrt.global_device_count() + num_devices = xr.global_runtime_device_count() # updated way to get device count + device_ids = np.arange(num_devices) + + x_dim = math.isqrt(num_devices) // 2 + yz_dim = 2 * math.isqrt(num_devices) + + col_mesh = xs.Mesh(device_ids, (1, num_devices)) + row_mesh = xs.Mesh(device_ids, (num_devices, 1)) + + print(f'[WONJOO] device_ids={device_ids}') + print(f'[WONJOO] x_dim={x_dim}') + print(f'[WONJOO] yz_dim={yz_dim}') + two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) + two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim)) + + for name, layer in model.named_modules(): + if 'tok_embeddings' in name: + xs.mark_sharding(layer.weight, row_mesh, (0, 1)) + if 'attention.' in name: + if 'wo' in name: + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + else: + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + if 'feed_forward.' 
in name: + if 'w2' in name: + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + else: + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + if 'output' in name: + xs.mark_sharding(layer.weight, col_mesh, (0, 1)) + + # TODO(yeounoh) shard cache_kvs before LLaMA init + # col_mesh = xs.Mesh(device_ids, (1, 1, num_devices, 1)) + # for i in range(len(model.cache_kvs)): + # for t in model.cache_kvs[i]: + # xs.mark_sharding(t, col_mesh, (0,1,2,3)) + + generator = LLaMA(model, tokenizer, device, True) + print(generator) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + return generator + + +def main( + tokenizer_path: str, + temperature: float = 0.8, + top_p: float = 0.95, + max_seq_len: int = 512, + max_batch_size: int = 32, + dim: int = 4096, + n_layers: int = 32, + n_heads: int = 32, +): + server = xp.start_server(9012, only_on_master=False) + torch.manual_seed(1) + generator = init( + tokenizer_path, max_seq_len, max_batch_size, dim, n_layers, n_heads + ) + + prompts = [ + # For these prompts, the expected answer is the natural continuation of the prompt + "I believe the meaning of life is", + #"Simply put, the theory of relativity states that ", + #"Building a website can be done in 10 simple steps:\n", + # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api +# """Tweet: "I hate it when my phone battery dies." +#Sentiment: Negative +#### +#Tweet: "My day has been 👍" +#Sentiment: Positive +#### +#Tweet: "This is the link to the article" +#Sentiment: Neutral +#### +#Tweet: "This new music video was incredibile" +#Sentiment:""", +# """Translate English to French: +# +#sea otter => loutre de mer +# +#peppermint => menthe poivrée +# +#plush girafe => girafe peluche +# +#cheese =>""", + ] + with torch.no_grad(): + results = generator.generate( + prompts, max_gen_len=1, temperature=temperature, top_p=top_p + ) + with torch.no_grad(): + results = generator.generate( + prompts, max_gen_len=256, temperature=temperature, top_p=top_p + ) + if xm.is_master_ordinal(local=False): + for result in results: + print(result) + print("\n==================================\n") + + +if __name__ == "__main__": + fire.Fire(main) + # print(met.metrics_report()) From 908d15d8fdd844fe601b4466c11893ec85e0dd2c Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Mon, 31 Jul 2023 23:32:54 +0000 Subject: [PATCH 2/5] Update prompt generation and hard-code sharding dims --- example_2d_sharding.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/example_2d_sharding.py b/example_2d_sharding.py index a4cb7e689..2cc1669fc 100644 --- a/example_2d_sharding.py +++ b/example_2d_sharding.py @@ -69,8 +69,11 @@ def init( num_devices = xr.global_runtime_device_count() # updated way to get device count device_ids = np.arange(num_devices) - x_dim = math.isqrt(num_devices) // 2 - yz_dim = 2 * math.isqrt(num_devices) + # x_dim = math.isqrt(num_devices) // 2 + # yz_dim = 2 * math.isqrt(num_devices) + + x_dim = 2 # hard-coded for v5 + yz_dim = 4 # hard-coded for v5 col_mesh = xs.Mesh(device_ids, (1, num_devices)) row_mesh = xs.Mesh(device_ids, (num_devices, 1)) @@ -103,7 +106,7 @@ def init( # for t in model.cache_kvs[i]: # xs.mark_sharding(t, col_mesh, (0,1,2,3)) - generator = LLaMA(model, tokenizer, device, True) + generator = Llama(model, tokenizer, device, True) print(generator) print(f"Loaded in {time.time() - start_time:.2f} seconds") return generator @@ -152,13 +155,14 @@ def main( # #cheese =>""", ] + prompt_tokens = [generator.tokenizer.encode(x, 
bos=True, eos=False) for x in prompts] with torch.no_grad(): results = generator.generate( - prompts, max_gen_len=1, temperature=temperature, top_p=top_p + prompt_tokens=prompt_tokens, max_gen_len=1, temperature=temperature, top_p=top_p ) with torch.no_grad(): results = generator.generate( - prompts, max_gen_len=256, temperature=temperature, top_p=top_p + prompt_tokens=prompt_tokens, max_gen_len=256, temperature=temperature, top_p=top_p ) if xm.is_master_ordinal(local=False): for result in results: From e41a4793eeaaea79c4314e6e1a646a6972298a2f Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Tue, 1 Aug 2023 01:39:56 +0000 Subject: [PATCH 3/5] Remove example_2d_sharding.py and enable SPMD --- example_2d_sharding.py | 175 ------------------------------------- example_text_completion.py | 10 ++- llama/generation.py | 37 +++++++- 3 files changed, 41 insertions(+), 181 deletions(-) delete mode 100644 example_2d_sharding.py diff --git a/example_2d_sharding.py b/example_2d_sharding.py deleted file mode 100644 index 2cc1669fc..000000000 --- a/example_2d_sharding.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the GNU General Public License version 3. - -from typing import Tuple -import os -import sys -import torch -import fire -import time -import math -import torch_xla.core.xla_model as xm -import torch_xla.debug.metrics as met -import torch_xla.debug.profiler as xp -import json - -from pathlib import Path - -from llama import ModelArgs, Transformer, Tokenizer, Llama - -# TODO(yeounoh) import packages for PyTorch/XLA GSPMD -import numpy as np -import torch_xla.experimental.xla_sharding as xs -import torch_xla.experimental.pjrt as pjrt - -# For xr.global_runtime_device_count() -from torch_xla import runtime as xr - -def init( - tokenizer_path: str, - max_seq_len: int, - max_batch_size: int, - dim: int = 4096, - n_layers: int = 32, - n_heads: int = 32, -) -> Llama: - start_time = time.time() - # checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - # TODO the checkpoint for large models seems to be sharded as well - # assert world_size == len( - # checkpoints - # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" - # ckpt_path = checkpoints[rank] - print("Loading") - # checkpoint = torch.load(ckpt_path, map_location="cpu") - # with open(Path(ckpt_dir) / "params.json", "r") as f: - # params = json.loads(f.read()) - params = {"dim": dim, - "n_layers": n_layers, - "n_heads": n_heads, - } - model_args: ModelArgs = ModelArgs( - max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params - ) - tokenizer = Tokenizer(model_path=tokenizer_path) - model_args.vocab_size = tokenizer.n_words - # torch.set_default_tensor_type(torch.cuda.HalfTensor) # TODO: this line puts the model to cuda device - torch.set_default_tensor_type(torch.BFloat16Tensor) - model = Transformer(model_args) - device = xm.xla_device() - model = model.to(device) - - # for i in range(len(model.cache_kvs)): - # model.cache_kvs[i] = tuple(t.to(device) for t in model.cache_kvs[i]) - # torch.set_default_tensor_type(torch.FloatTensor) - - # model.load_state_dict(checkpoint, strict=False) - - # num_devices = pjrt.global_device_count() - num_devices = xr.global_runtime_device_count() # updated way to get device count - device_ids = np.arange(num_devices) - - # x_dim = math.isqrt(num_devices) // 2 - # yz_dim = 2 * math.isqrt(num_devices) - - x_dim = 2 # hard-coded for v5 - yz_dim = 4 # hard-coded for v5 
- - col_mesh = xs.Mesh(device_ids, (1, num_devices)) - row_mesh = xs.Mesh(device_ids, (num_devices, 1)) - - print(f'[WONJOO] device_ids={device_ids}') - print(f'[WONJOO] x_dim={x_dim}') - print(f'[WONJOO] yz_dim={yz_dim}') - two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) - two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim)) - - for name, layer in model.named_modules(): - if 'tok_embeddings' in name: - xs.mark_sharding(layer.weight, row_mesh, (0, 1)) - if 'attention.' in name: - if 'wo' in name: - xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) - else: - xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) - if 'feed_forward.' in name: - if 'w2' in name: - xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) - else: - xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) - if 'output' in name: - xs.mark_sharding(layer.weight, col_mesh, (0, 1)) - - # TODO(yeounoh) shard cache_kvs before LLaMA init - # col_mesh = xs.Mesh(device_ids, (1, 1, num_devices, 1)) - # for i in range(len(model.cache_kvs)): - # for t in model.cache_kvs[i]: - # xs.mark_sharding(t, col_mesh, (0,1,2,3)) - - generator = Llama(model, tokenizer, device, True) - print(generator) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - return generator - - -def main( - tokenizer_path: str, - temperature: float = 0.8, - top_p: float = 0.95, - max_seq_len: int = 512, - max_batch_size: int = 32, - dim: int = 4096, - n_layers: int = 32, - n_heads: int = 32, -): - server = xp.start_server(9012, only_on_master=False) - torch.manual_seed(1) - generator = init( - tokenizer_path, max_seq_len, max_batch_size, dim, n_layers, n_heads - ) - - prompts = [ - # For these prompts, the expected answer is the natural continuation of the prompt - "I believe the meaning of life is", - #"Simply put, the theory of relativity states that ", - #"Building a website can be done in 10 simple steps:\n", - # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api -# """Tweet: "I hate it when my phone battery dies." 
-#Sentiment: Negative -#### -#Tweet: "My day has been 👍" -#Sentiment: Positive -#### -#Tweet: "This is the link to the article" -#Sentiment: Neutral -#### -#Tweet: "This new music video was incredibile" -#Sentiment:""", -# """Translate English to French: -# -#sea otter => loutre de mer -# -#peppermint => menthe poivrée -# -#plush girafe => girafe peluche -# -#cheese =>""", - ] - prompt_tokens = [generator.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - with torch.no_grad(): - results = generator.generate( - prompt_tokens=prompt_tokens, max_gen_len=1, temperature=temperature, top_p=top_p - ) - with torch.no_grad(): - results = generator.generate( - prompt_tokens=prompt_tokens, max_gen_len=256, temperature=temperature, top_p=top_p - ) - if xm.is_master_ordinal(local=False): - for result in results: - print(result) - print("\n==================================\n") - - -if __name__ == "__main__": - fire.Fire(main) - # print(met.metrics_report()) diff --git a/example_text_completion.py b/example_text_completion.py index fc084b019..da89d443e 100755 --- a/example_text_completion.py +++ b/example_text_completion.py @@ -25,6 +25,7 @@ def main( max_gen_len: int = 64, max_batch_size: int = 4, dynamo: bool = True, + spmd: bool = True, ): if not USE_CUDA: server = xp.start_server(9012, only_on_master=False) @@ -34,6 +35,7 @@ def main( max_seq_len=max_seq_len, max_batch_size=max_batch_size, dynamo=dynamo, + spmd=spmd, ) prompts = [ @@ -77,12 +79,13 @@ def _fn( max_gen_len: int = 64, max_batch_size: int = 4, dynamo: bool = True, + spmd: bool = True, ): if USE_CUDA: os.environ['WORLD_SIZE'] = torch.cuda.device_count() os.environ['RANK'] = idx os.environ['LOCAL_RANK'] = idx - main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo) + main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd) def mp_main( @@ -95,6 +98,7 @@ def mp_main( max_gen_len: int = 64, max_batch_size: int = 4, dynamo: bool = True, + spmd: bool = True, ): if mp: if USE_CUDA: @@ -103,9 +107,9 @@ def mp_main( else: kwargs = {} xmp.spawn(_fn, - args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo), **kwargs) + args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd), **kwargs) else: - main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo) + main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd) if __name__ == "__main__": diff --git a/llama/generation.py b/llama/generation.py index 043f188c7..0da3f4f47 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -20,6 +20,9 @@ # Some how xla init will slow down the CUDA speed. 
if not USE_CUDA: import torch_xla.core.xla_model as xm + import torch_xla.experimental.xla_sharding as xs + from torch_xla import runtime as xr + import numpy as np Role = Literal["system", "user", "assistant"] @@ -60,6 +63,7 @@ def build( max_batch_size: int, model_parallel_size: Optional[int] = None, dynamo: bool = True, + spmd: bool = True, ) -> "Llama": # if not model_parallel_is_initialized(): # if model_parallel_size is None: @@ -118,14 +122,41 @@ def build( model = model.to(device) print(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, device, dynamo) + return Llama(model, tokenizer, device, dynamo, spmd) - def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True): + def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True, spmd: bool = True): self.model = model self.tokenizer = tokenizer self.device = device - self._generate_one_token_fn = self._generate_one_token + + if spmd: + num_devices = xr.global_runtime_device_count() # updated way to get device count + device_ids = np.arange(num_devices) + x_dim = 2 # hard-coded for v5 + yz_dim = 4 # hard-coded for v5 + + col_mesh = xs.Mesh(device_ids, (1, num_devices)) + row_mesh = xs.Mesh(device_ids, (num_devices, 1)) + two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) + two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim)) + + for name, layer in model.named_modules(): + if 'tok_embeddings' in name: + xs.mark_sharding(layer.weight, row_mesh, (0, 1)) + if 'attention.' in name: + if 'wo' in name: + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + else: + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + if 'feed_forward.' in name: + if 'w2' in name: + xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + else: + xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + if 'output' in name: + xs.mark_sharding(layer.weight, col_mesh, (0, 1)) + if dynamo: if USE_CUDA: # Inductor errors out when compiles _generate_one_token_fn. 
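For reference, a minimal standalone sketch of the 2D GSPMD weight sharding that patch 3 moves into Llama.__init__ — assuming an 8-device v5-8 slice, SPMD execution mode enabled for the process, and the experimental torch_xla xla_sharding API already imported in the patch (xs.Mesh, xs.mark_sharding); the weight shape below is illustrative, not taken from the model.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs
    from torch_xla import runtime as xr

    # Arrange all addressable devices into a 2x4 logical mesh, mirroring the
    # hard-coded x_dim=2 / yz_dim=4 used for the v5-8 slice in the patch.
    num_devices = xr.global_runtime_device_count()  # expected to be 8 here
    device_ids = np.arange(num_devices)
    two_d_mesh = xs.Mesh(device_ids, (2, 4))

    # Illustrative 2D weight on the XLA device. Partition spec (0, 1) maps
    # tensor dim 0 to mesh axis 0 (split 2-way) and dim 1 to mesh axis 1
    # (split 4-way), so each device holds a 2048 x 2752 shard.
    weight = torch.zeros(4096, 11008, device=xm.xla_device())
    xs.mark_sharding(weight, two_d_mesh, (0, 1))
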
From da5c70d88b88879e0bddb541e28a118e9057b07e Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Wed, 2 Aug 2023 20:18:38 +0000 Subject: [PATCH 4/5] Manually shard the KV cache --- llama/generation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llama/generation.py b/llama/generation.py index 0da3f4f47..fadd4fa44 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -136,6 +136,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.devic x_dim = 2 # hard-coded for v5 yz_dim = 4 # hard-coded for v5 + # manually shard the kv cache + four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim)) + for layer in model.layers: + xs.mark_sharding(layer.attention.cache_k, four_d_mesh, (0, 1, 2, None)) + xs.mark_sharding(layer.attention.cache_v, four_d_mesh, (0, 1, 2, None)) + col_mesh = xs.Mesh(device_ids, (1, num_devices)) row_mesh = xs.Mesh(device_ids, (num_devices, 1)) two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) From f45f5b1b1861f0844c612299736b7043830d2ad7 Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Mon, 7 Aug 2023 22:40:04 +0000 Subject: [PATCH 5/5] Update to latest --- example_text_completion.py | 14 +++++++++++++- llama/generation.py | 37 +++++++++++++++++++++++++++++-------- llama/model.py | 9 +++++++++ 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/example_text_completion.py b/example_text_completion.py index da89d443e..077921430 100755 --- a/example_text_completion.py +++ b/example_text_completion.py @@ -28,7 +28,8 @@ def main( spmd: bool = True, ): if not USE_CUDA: - server = xp.start_server(9012, only_on_master=False) + # server = xp.start_server(9012, only_on_master=False) + pass generator = Llama.build( ckpt_dir=ckpt_dir, tokenizer_path=tokenizer_path, @@ -38,6 +39,8 @@ def main( spmd=spmd, ) + print(f'[WONJOO] max_batch_size={max_batch_size}') + prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt "I believe the meaning of life is", @@ -55,6 +58,13 @@ def main( # plush girafe => girafe peluche # cheese =>""", ] + + import time + print("About to start in 15 seconds") + server = xp.start_server(9012, only_on_master=False) + time.sleep(15) + print("Starting!") + for _ in range(2): with torch.no_grad(): results = generator.text_completion( @@ -68,6 +78,8 @@ def main( print(f"> {result['generation']}") print("\n==================================\n") + print("Finished!") + def _fn( idx, diff --git a/llama/generation.py b/llama/generation.py index fadd4fa44..2c00de05c 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -110,6 +110,9 @@ def build( max_batch_size=max_batch_size, **params, ) + + model_args.print_values() + tokenizer = Tokenizer(model_path=tokenizer_path) model_args.vocab_size = tokenizer.n_words if USE_CUDA: @@ -132,9 +135,10 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.devic if spmd: num_devices = xr.global_runtime_device_count() # updated way to get device count + # num_devices = 8 # hard-coded for v5-8 device_ids = np.arange(num_devices) - x_dim = 2 # hard-coded for v5 - yz_dim = 4 # hard-coded for v5 + x_dim = 2 # hard-coded for v5-8 + yz_dim = 4 # hard-coded for v5-8 # manually shard the kv cache four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim)) @@ -144,25 +148,42 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.devic col_mesh = xs.Mesh(device_ids, (1, num_devices)) row_mesh = xs.Mesh(device_ids, (num_devices, 1)) - two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) - two_d_mesh_transpose = 
xs.Mesh(device_ids, (yz_dim, x_dim)) for name, layer in model.named_modules(): if 'tok_embeddings' in name: xs.mark_sharding(layer.weight, row_mesh, (0, 1)) if 'attention.' in name: if 'wo' in name: - xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + xs.mark_sharding(layer.weight, row_mesh, (0, 1)) else: - xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + xs.mark_sharding(layer.weight, col_mesh, (0, 1)) if 'feed_forward.' in name: if 'w2' in name: - xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + xs.mark_sharding(layer.weight, row_mesh, (0, 1)) else: - xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + xs.mark_sharding(layer.weight, col_mesh, (0, 1)) if 'output' in name: xs.mark_sharding(layer.weight, col_mesh, (0, 1)) + # Sharding strategy for 2D sharding + # two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim)) + # two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim)) + # for name, layer in model.named_modules(): + # if 'tok_embeddings' in name: + # xs.mark_sharding(layer.weight, row_mesh, (0, 1)) + # if 'attention.' in name: + # if 'wo' in name: + # xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + # else: + # xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + # if 'feed_forward.' in name: + # if 'w2' in name: + # xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1)) + # else: + # xs.mark_sharding(layer.weight, two_d_mesh, (0, 1)) + # if 'output' in name: + # xs.mark_sharding(layer.weight, col_mesh, (0, 1)) + if dynamo: if USE_CUDA: # Inductor errors out when compiles _generate_one_token_fn. diff --git a/llama/model.py b/llama/model.py index 88cc7ca1d..e9482c9eb 100755 --- a/llama/model.py +++ b/llama/model.py @@ -35,6 +35,15 @@ class ModelArgs: max_seq_len: int = 2048 quant: bool = False + def print_values(self): + print(f'[WONJOO] ModelArgs') + print(f'[WONJOO] dim={self.dim}') + print(f'[WONJOO] n_layers={self.n_layers}') + print(f'[WONJOO] n_heads={self.n_heads}') + print(f'[WONJOO] max_batch_size={self.max_batch_size}') + print(f'[WONJOO] max_seq_len={self.max_seq_len}') + print(f'[WONJOO] quant={self.quant}') + class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6):
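
For reference, a minimal sketch of the manual KV-cache sharding introduced in patch 4 and kept in patch 5 — assuming an 8-device v5-8 slice, SPMD execution mode enabled, and a cache layout of (max_batch_size, max_seq_len, n_heads, head_dim); the cache shape below is illustrative, and only xs.Mesh and xs.mark_sharding with spec (0, 1, 2, None) come from the patches.

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs
    from torch_xla import runtime as xr

    num_devices = xr.global_runtime_device_count()  # expected to be 8 here
    device_ids = np.arange(num_devices)
    x_dim, yz_dim = 2, 4  # hard-coded for a v5-8 slice, as in the patches

    # 4D logical mesh matching the cache rank; the two leading size-1 axes
    # keep the batch and sequence dimensions unsharded.
    four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim))

    # Illustrative cache tensor (batch=32, seq=512, heads=32, head_dim=128).
    cache_k = torch.zeros(32, 512, 32, 128, device=xm.xla_device())

    # Spec (0, 1, 2, None): dims 0-2 follow mesh axes 0-2, so the heads
    # dimension is split 2-way; head_dim maps to no mesh axis and is
    # replicated (as is the tensor over the remaining 4-way mesh axis).
    xs.mark_sharding(cache_k, four_d_mesh, (0, 1, 2, None))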