Ray Disaggregated Serving MVP #106
Conversation
High-level comment: it looks like the main difference for is_disaggregated within PyTorchRayEngine is whether or not prefill returns outputs. If the prefill/decode/interleave functionality is essentially the same, then I guess it's an implementation detail for the orchestrator to trigger the transfer. If so, is it possible to exclude is_disaggregated from the worker? That would reduce the complexity.
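For context, a minimal sketch of the engine-side branch being discussed (class and method names are hypothetical, not the actual PyTorchRayEngine code): in interleave mode prefill keeps its result on the workers, while in disaggregated mode it returns the output so the orchestrator can trigger the transfer to the decode workers.

```python
import ray

# Hypothetical sketch; names and signatures are illustrative only.
class PyTorchRayEngineSketch:
    def __init__(self, workers, is_disaggregated: bool):
        self.workers = workers
        self.is_disaggregated = is_disaggregated

    def prefill(self, params, padded_tokens, true_length):
        # Fan the prefill call out to the worker actors.
        refs = [
            w.prefill.remote(
                params=params,
                padded_tokens=padded_tokens,
                true_length=true_length,
            )
            for w in self.workers
        ]
        results = ray.get(refs)
        if self.is_disaggregated:
            # Disaggregated: return the prefill output (tokens + KV cache
            # handle) so the orchestrator can push it to the decode engine.
            return results[0]
        # Interleave: the KV cache stays on these same workers, so
        # nothing needs to be returned to the caller.
        return None
```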
Simplified the prefill call on the engine side. On the worker side: yes, they are the same for insert and decode, but I feel it's better to keep separate disaggregated and interleave paths for prefill, for several reasons:
I think that makes sense to me, thanks!
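To summarize the worker-side split described above, here is a rough skeleton (a sketch only; method names are assumptions, not the PR's actual worker API): insert and decode stay shared across both modes, and only prefill has an interleave and a disaggregated variant, the latter returning its result for transfer.

```python
# Illustrative worker skeleton; not the actual worker code in this PR.
class DisaggregationCapableWorker:
    def _run_prefill(self, params, padded_tokens, true_length):
        # Placeholder for the model forward pass over the prompt.
        return {"kv_cache": ..., "first_token": ...}

    def prefill_interleave(self, params, padded_tokens, true_length):
        # Interleave: keep the result local; decode reads the KV cache
        # directly from this worker.
        self._prefill_result = self._run_prefill(params, padded_tokens, true_length)

    def prefill_disaggregated(self, params, padded_tokens, true_length):
        # Disaggregated: return the result so it can be shipped to the
        # decode pod slice.
        return self._run_prefill(params, padded_tokens, true_length)

    # insert and decode are identical in both modes, so no
    # is_disaggregated flag is needed for them.
    def insert(self, prefill_result, decode_state, slot):
        ...

    def decode(self, params, decode_state):
        ...
```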
This PR enables PyTorch engine disaggregated serving across multiple TPU Pod slices.
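At a high level, the disaggregated setup places prefill and decode on separate pod slices and has the orchestrator move the prefill result between them. A conceptual Ray sketch of that flow (actor names, resource labels, and methods are assumptions for illustration, not this PR's code):

```python
import ray

# Assumes a Ray cluster spanning both TPU pod slices that exposes the
# custom resources "prefill_slice" and "decode_slice" (illustrative names).
ray.init()

@ray.remote(resources={"prefill_slice": 1})
class PrefillWorker:
    def prefill(self, tokens):
        # Forward pass over the prompt; return the KV cache and first token.
        return {"kv_cache": ..., "first_token": ...}

@ray.remote(resources={"decode_slice": 1})
class DecodeWorker:
    def __init__(self):
        self.slots = {}

    def insert(self, prefill_result, slot):
        # Receive the transferred prefill result into a decode slot.
        self.slots[slot] = prefill_result

    def decode(self, slot):
        # Autoregressive generation from the inserted cache (placeholder).
        return self.slots[slot]["first_token"]

# Orchestrator: prefill on one slice, transfer, then decode on the other.
prefill_worker = PrefillWorker.remote()
decode_worker = DecodeWorker.remote()
result = ray.get(prefill_worker.prefill.remote([1, 2, 3]))
ray.get(decode_worker.insert.remote(result, slot=0))
first_token = ray.get(decode_worker.decode.remote(slot=0))
```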
This PR delivered:
Result validation:
Command:
```bash
python /home/{user}/jetstream-pytorch/run_interactive_disaggregated.py \
  --size=7b \
  --batch_size=1 \
  --is_disaggregated=True \
  --num_hosts=8 \
  --decode_pod_slice_name={user}-tpu-vm-2 \
  --model_name=llama-2 \
  --max_cache_length=2048 \
  --quantize_weights=False \
  --quantize_kv_cache=False \
  --checkpoint_path=/home/{user}/data/llama-2-7b-chat-safetensor/model.safetensors \
  --tokenizer_path=/home/{user}/data/tokenizer.model \
  --sharding_config=/home/{user}/jetstream-pytorch/default_shardings/llama.yaml
```
Interleave result:
Disaggregated result:
Next Steps:
5: Support multiple prefill engines and multiple decode engines