Error Running run_ray_serve_interleave with Llama3 8B #169

Open
ryanaoleary opened this issue Aug 7, 2024 · 0 comments

I'm receiving an error when attempting to run:

ray job submit -- python run_ray_serve_interleave.py  --tpu_chips=4 --num_hosts=1 --size=8B --model_name=llama-3 --batch_size=8 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --checkpoint_path=$output_ckpt_dir --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

on a single-host v4 TPU with a 2x2x1 topology. The error is:

ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::PyTorchRayWorker.__init__() (pid=5137, ip=10.168.0.16, actor_id=243ec964a2f41eae1707d84404000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7abe4074add0>)
  File "/home/ray/jetstream-pytorch/jetstream_pt/ray_worker.py", line 200, in __init__
    pt_model = model_exportable.Transformer(args, env)
  File "/home/ray/jetstream-pytorch/jetstream_pt/third_party/llama/model_exportable.py", line 192, in __init__
    self.tok_embeddings = Embedding(
  File "/home/ray/jetstream-pytorch/jetstream_pt/layers.py", line 57, in __init__
    table = torch.ones(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_refs/__init__.py", line 4774, in ones
    size = utils.extract_shape_from_varargs(size)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 854, in extract_shape_from_varargs
    validate_shape(shape)  # type: ignore[arg-type]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 588, in validate_shape
    validate_dim_length(l)
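The traceback is cut off before the final error line, but validate_dim_length in torch._prims_common rejects a requested tensor dimension that is not a valid non-negative integer, so the torch.ones call in the Embedding layer appears to be receiving an invalid shape. As a rough illustration of that class of failure (not the exact values from this run, and eager mode takes a different code path):

import torch

# Requesting a negative dimension fails PyTorch's shape validation.
torch.ones(-1)  # raises RuntimeError (negative dimension)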

This seems to be related to this logic in ray_worker.py that creates the pt_model:

env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)

Should the model_type for Llama models be hardcoded to llama-2, or should it be derived from the --model_name argument instead?
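For illustration, a minimal sketch of what deriving the prefix from the --model_name flag could look like; the helper name and mapping below are hypothetical, not code from this repository:

# Hypothetical sketch only: map the --model_name flag to a model_type
# prefix instead of hardcoding "llama-2-". Names are assumptions.
_MODEL_TYPE_PREFIXES = {
    "llama-2": "llama-2-",
    "llama-3": "llama-3-",
}


def resolve_model_type(model_name: str, param_size: str) -> str:
  """Builds the env_data.model_type string, e.g. 'llama-3-8B'."""
  prefix = _MODEL_TYPE_PREFIXES.get(model_name)
  if prefix is None:
    raise ValueError(f"Unsupported model_name: {model_name}")
  return prefix + param_size


# In PyTorchRayWorker.__init__ this would replace the hardcoded line, e.g.:
# env_data.model_type = resolve_model_type(model_name, param_size)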
