
Ray Disaggregated Serving MVP #106

Merged · 11 commits · May 29, 2024

Conversation

FanhaiLu1 (Collaborator) commented May 29, 2024

This PR enables PyTorch engine disaggregated serving across multiple TPU pod slices.

This PR delivers (a flow sketch follows the list):

  1. The engine runs prefill on one pod slice and decode on another pod slice
  2. Transfers the prefill result from one pod slice to the other
  3. Loads and shards weights across multiple hosts on multiple pod slices
  4. Computes meaningful decode results
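
Roughly, the flow looks like the following minimal sketch built on plain Ray actors. Everything here is illustrative: the worker names and methods are hypothetical and the model math is faked with numpy; the real PR runs sharded LLaMA workers on separate TPU pod slices.

```python
import numpy as np
import ray

ray.init()

@ray.remote
class PrefillWorker:
    """Stands in for the workers on the prefill pod slice."""

    def prefill(self, tokens: np.ndarray) -> np.ndarray:
        # Fake forward pass producing a KV-cache-shaped array. Returning a
        # host-side numpy array is what lets Ray move it between pod slices.
        return np.outer(tokens, tokens).astype(np.float32)

@ray.remote
class DecodeWorker:
    """Stands in for the workers on the decode pod slice."""

    def __init__(self):
        self.cache = {}

    def insert(self, kv_cache: np.ndarray, slot: int) -> None:
        # Receive the transferred prefill result and stage it locally.
        self.cache[slot] = kv_cache

    def generate(self, slot: int) -> int:
        # Fake one decode step against the inserted cache.
        return int(self.cache[slot].sum()) % 32000

prefill_engine = PrefillWorker.remote()
decode_engine = DecodeWorker.remote()

tokens = np.arange(1, 9)
kv_ref = prefill_engine.prefill.remote(tokens)         # 1. prefill on slice A
ray.get(decode_engine.insert.remote(kv_ref, slot=0))   # 2. transfer, 3. insert on slice B
print(ray.get(decode_engine.generate.remote(slot=0)))  # 4. decode on slice B
```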

Result validation:

  1. Disaggregated serving works; here are the logs (tpu-vm-1 is the decode engine):

```
---- Do prefill in prefill engine pod_slice_name: tpu-vm-2
---- Transfer prefill result to decode engine pod_slice_name: tpu-vm-1
---- Do insert in decode engine pod_slice_name: tpu-vm-1
```

  2. Correct output compared with interleaved serving

Command:

```
python /home/{user}/jetstream-pytorch/run_interactive_disaggregated.py --size=7b --batch_size=1 --is_disaggregated=True --num_hosts=8 --decode_pod_slice_name={user}-tpu-vm-2 --model_name=llama-2 --max_cache_length=2048 --quantize_weights=False --quantize_kv_cache=False --checkpoint_path=/home/{user}/data/llama-2-7b-chat-safetensor/model.safetensors --tokenizer_path=/home/{user}/data/tokenizer.model --sharding_config=/home/{user}/jetstream-pytorch/default_shardings/llama.yaml
```

Interleave result:

```
to find purpose and fulfillment.

I believe that everyone has a unique purpose and that it is up to each individual to discover and pursue theirs.
```

Disaggregated result:

```
to find purpose and fulfillment.

I believe that everyone has a unique purpose and that it is up to each individual to discover and pursue theirs.
```

Next Steps:

  1. Integrate with the JetStream orchestrator
  2. Add a README for both interactive and JetStream runs
  3. Fix performance issues
  4. Test Llama 2 70B
  5. Support multiple prefill engines and multiple decode engines

allenwang28 (Collaborator) left a comment

High-level comment: it looks like the main difference is_disaggregated makes within PyTorchRayEngine is whether or not prefill returns outputs.

If the prefill/decode/interleave functionality is essentially the same, then I guess it's an implementation detail for the orchestrator to trigger the transfer. If so, is it possible to exclude is_disaggregated from the worker? That would reduce the complexity.
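
For illustration, a minimal sketch of the simplification being suggested, with hypothetical names rather than PyTorchRayEngine's real interface: the worker always returns the prefill output, and only the engine decides whether to transfer it.

```python
class Worker:
    """Mode-agnostic worker: no is_disaggregated flag anywhere."""

    def prefill(self, tokens: list[int]) -> list[int]:
        # Stand-in for the real prefill; always return the result.
        return [t * 2 for t in tokens]

class Engine:
    """Only the engine knows which serving mode it is running."""

    def __init__(self, worker: Worker, is_disaggregated: bool):
        self.worker = worker
        self.is_disaggregated = is_disaggregated
        self.transfer_queue: list[list[int]] = []

    def prefill(self, tokens: list[int]) -> list[int]:
        result = self.worker.prefill(tokens)
        if self.is_disaggregated:
            # Engine-side decision: queue the result for the decode engine.
            self.transfer_queue.append(result)
        return result
```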

FanhaiLu1 (Collaborator, Author) commented May 29, 2024

> High-level comment: it looks like the main difference is_disaggregated makes within PyTorchRayEngine is whether or not prefill returns outputs.
>
> If the prefill/decode/interleave functionality is essentially the same, then I guess it's an implementation detail for the orchestrator to trigger the transfer. If so, is it possible to exclude is_disaggregated from the worker? That would reduce the complexity.

Simplified the prefill call on the engine side. On the worker side: yes, insert and decode are the same in both modes, but I feel it's better to keep separate disaggregated and interleaved prefill paths, for several reasons:

  1. The worker doesn't need to handle the disaggregated-vs-interleaved logic. A worker can do both interleaved and disaggregated prefill, and the engine chooses which one to invoke.
  2. There are large differences between disaggregated and interleaved prefill. On the interleaved side, we directly save the prefill result in local HBM. On the disaggregated side, we need an extra call to all-gather the cache and return the prefill result (see the sketch below). Another difference is that the prefill result cache is a jax array in interleaved mode but an np array in disaggregated mode.
  3. The worker already has pretty complex logic; it's better to keep some logic in the engine (the engine has simple code logic right now).
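
For concreteness, here is a minimal sketch of that extra step in point 2, assuming JAX's multihost_utils.process_allgather as the gathering primitive; the function name is illustrative, not the PR's actual code.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

def prefill_cache_for_transfer(sharded_cache: jax.Array) -> np.ndarray:
    # Interleaved prefill would stop before this point: the sharded jax
    # array simply stays in the local HBM of the prefill pod slice.
    # The disaggregated path first all-gathers the shards from every host...
    gathered = multihost_utils.process_allgather(sharded_cache)
    # ...then converts to a host-side np array that can be shipped to the
    # decode pod slice.
    return np.asarray(gathered)
```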

allenwang28 (Collaborator) replied:

> But I feel it's better to keep separate disaggregated and interleaved prefill paths, for several reasons:

I think that makes sense to me, thanks!

FanhaiLu1 merged commit c360158 into AI-Hypercomputer:main on May 29, 2024. 4 checks passed.