-
Notifications
You must be signed in to change notification settings - Fork 14
[Docs] add integrate_with_tunix.md #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you very much for the proposal! Generally look good to me. There is one thing I'd like to pointed out, we have 2 modes to run Tunix: Multi Controller Jax or Pathways. In order to comply with Pathways single process requirement, we need to run sglang backend in the main process. Can you describe what changes are necessary in the proposal? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
@wang2yn84 Currently we have relatively little understanding of Pathways, and we're not quite clear on how it implements single controller + distributed computing, as well as how to use and integrate it. It would be great if there were some more detailed documentation available. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
|
||
nnx.update(self._model, current_state) | ||
else: | ||
nnx.update(self._model, updated_weights) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tunix implementation of the training model might be different from SGLang's version. So the weight name mapping, or potentially transposing or other operations might be necessary. Really depends on the model implementation on both sides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it! Thanks for your suggestion. cc @JamesBrianD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wang2yn84 Currently I find out model(like llama3) implements methods like to_hf_mappings, lora_to_hf_mappings
, etc. These methods are used to update_params
. Different inference engines have different weight names, will Tunix follow this solution to implement every mapping for different framework every model? cc @JamesBrianD
Another question: Not all models implement all methods. For example, llama3 does not implement to_hf_hook_fns
, but model.to_hf_hook_fns()
will be called when initializing VllmSampler
in rl/rollout/vllm_rollout.py. This confuses me. Could you explain it for me?
self._sampler = vllm_sampler.VllmSampler(
tokenizer=tokenizer,
config=vllm_sampler.VllmConfig(
max_model_len=cache_config_or_size,
mesh=mesh,
model_version=model_version,
hbm_utilization=hbm_utilization,
init_with_random_weights=init_with_random_weights,
tpu_backend_type=tpu_backend_type,
mapping_config=vllm_sampler.MappingConfig(
to_hf_mappings=model.to_hf_mappings(),
to_hf_transpose_keys=model.to_hf_transpose_keys(),
to_hf_hook_fns=model.to_hf_hook_fns(),
lora_to_hf_mappings=model.lora_to_hf_mappings(),
lora_config=lora_config,
),
),
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wang2yn84 Thank you for your suggestion. We will consider adding key_mappings and transpose_keys (following the example of utils.transfer_state_with_mappings) to handle the mapping and transformation between states.
Yea completely understand! Can you try https://arxiv.org/abs/2203.12533? I believe this paper can give you a high level view of the system. Let me know if it helps. |
Based on paper you provided and our internal discussion, here are changes we need to do. Change1: Currently Engine will create subprocess after initialization, in order to meet the single process requirement in Pathways, we will rewrite the related parts in Engine. Change2: We get TPUs through Other discussions:
Note: Multi Controller Jax: Currently sglang-jax can run different hosts, so we think we support to run multi-controller Tunix. |
can you explain why needs single process metained in "In order to comply with Pathways single process requirement, we need to run sglang backend in the main process", as we all know ,sglang run with multi process, can sglang run using a single process ? @wang2yn84 |
According to our communication in Google Chat, here are the conclusions about discussions which may influence the integration. D1: How to be aware of Pathway environment? Q2: How to get devices like 'jax.devices()' in Pathways? D3: What are the usage of sampling seed? cc @jimoosciuc |
f3c6eb2
to
545d112
Compare
545d112
to
f167dfb
Compare
No description provided.