(WIP) Support targeting the embedding layer for LoRA #501
base: main
Conversation
This function was a bit hard to understand, as there were multiple list comprehensions with almost the same looping logic, so I merged them all into a single for loop for improved clarity.
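For illustration only, a minimal sketch of that kind of refactor with hypothetical data (not the actual lorax code): several comprehensions that repeat the same iteration are collapsed into one loop that fills every collection in a single pass.

```python
items = [("q_proj", 8), ("v_proj", 8), ("embed_tokens", 16)]  # illustrative (name, rank) pairs

# Before: the same iteration repeated across separate comprehensions.
names = [name for name, rank in items]
ranks = [rank for name, rank in items]
by_rank = {rank: [name for name, r in items if r == rank] for _, rank in items}

# After: one pass that populates all three collections.
names, ranks, by_rank = [], [], {}
for name, rank in items:
    names.append(name)
    ranks.append(rank)
    by_rank.setdefault(rank, []).append(name)
```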
It's the same as the one used in the outer for loop, which can cause confusion.
@tgaddair I am pushing a partially done commit that supports embedding-layer LoRAs.
server/lorax_server/adapters/lora.py
Outdated
@@ -40,14 +41,20 @@ def map_weights_for_model(
    adapter_weight_names = set()
    module_map = {}
    for weight_name in weight_names:
        lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
        lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
        if EMBED_TOKENS in weight_name:
We might need to make this `embed_tokens` name a property of the model rather than a constant, as I imagine it will vary from one architecture to the next.
Makes sense. I will make the change.
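A rough sketch of what that could look like, assuming a hypothetical `embedding_weight_name` property (names here are illustrative, not the actual lorax API):

```python
EMBED_TOKENS = "embed_tokens"  # current approach: a global constant

class ExampleModel:  # hypothetical model class
    @property
    def embedding_weight_name(self) -> str:
        # Architectures with a differently named embedding module would override this.
        return "embed_tokens"

def is_embedding_weight(model: ExampleModel, weight_name: str) -> bool:
    # Call sites compare against the model's property instead of the global constant.
    return model.embedding_weight_name in weight_name
```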
server/lorax_server/adapters/lora.py
Outdated
        module_map[weight_name] = {
            "lora_A": (adapter_weights[lora_a_name], lora_a_name),
            "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            "lora_A": (adapter_weights.pop(lora_a_name), lora_a_name),
Not sure I understand the purpose of using `pop` here, as it doesn't look like `adapter_weights` is used below (unless it's used from the caller). In general, it's good to avoid modifying input objects unless it's clear from the function's name, etc., that it does so. In this case, I would suggest cloning the `adapter_weights` dict at the top to avoid modifying the input, and then returning the modified `adapter_weights` if the caller needs to check which elements haven't been popped.
Sounds good. I just realized that I may not even need to change this part, since `adapter_weight_names` captures which weights were consumed, and I can use it in the caller to figure out whether all weights were consumed.
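For reference, a sketch of the clone-and-return pattern being discussed, with a simplified signature (the real `map_weights_for_model` takes more arguments and builds more state than shown here):

```python
from typing import Dict, Tuple

import torch

def map_weights_for_model_sketch(
    adapter_weights: Dict[str, torch.Tensor],
    weight_names: Tuple[str, ...],
):
    # Shallow copy so the caller's dict is never mutated.
    remaining = dict(adapter_weights)
    module_map = {}
    for weight_name in weight_names:
        lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
        lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
        if lora_a_name not in remaining or lora_b_name not in remaining:
            continue
        module_map[weight_name] = {
            "lora_A": (remaining.pop(lora_a_name), lora_a_name),
            "lora_B": (remaining.pop(lora_b_name), lora_b_name),
        }
    # Anything left in `remaining` was not consumed by the mapping.
    return module_map, remaining
```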
server/lorax_server/adapters/lora.py
Outdated
        batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
        batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
        batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices]
Would rather keep the separate variable, as calling `set(indices)` on each iteration of the loop is unnecessary.
Oh, yes. I will revert to the original code. The reason for the change was to avoid having a `rank_indices` variable inside the for loop, since the for loop itself iterates over another `rank_indices` variable. Maybe I can just rename the `rank_indices` here.
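For reference, a small self-contained sketch of the point about hoisting the set construction out of the comprehension (values are illustrative):

```python
indices = [0, 2]                 # segment indices that share this LoRA rank (illustrative)
batch_indices = [0, 1, 2, 2, 3]  # per-token segment index (illustrative)

# Rebuilding set(indices) on every iteration is wasted work:
#   [idx if idx in set(indices) else -1 for idx in batch_indices]
# Build the set once instead; a distinct name also avoids shadowing the outer
# rank_indices loop variable mentioned above.
segment_index_set = set(indices)
batch_indices = [idx if idx in segment_index_set else -1 for idx in batch_indices]
```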
server/lorax_server/utils/adapter.py
Outdated
    # note(ajinkya): adapter weights are consumed during the above mapping; if some are not,
    # we may not be supporting all the weights in the adapter, which should be an error,
    # but for now we just log it
    if len(adapter_weights) > 0:
Per the above comment, would return the modified adapter weights as `unused_adapter_weights` or similar rather than relying on the input being modified.
Sounds good.
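Assuming the mapping function returns the leftover weights as suggested (hypothetical names, building on the earlier sketch; standard-library logging is used here only for illustration), the caller could surface them like this:

```python
import logging

logger = logging.getLogger(__name__)

# `adapter_weights` and `weight_names` as in the earlier sketch.
module_map, unused_adapter_weights = map_weights_for_model_sketch(adapter_weights, weight_names)
if unused_adapter_weights:
    # Tensors that were not mapped onto any supported layer likely mean the adapter
    # targets modules that are not handled yet; ideally this becomes an error later.
    logger.warning(
        "Unsupported adapter weights were not loaded: %s",
        sorted(unused_adapter_weights.keys()),
    )
```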
server/lorax_server/utils/layers.py
Outdated
        return result

    # def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
Why was this commented out? Was it raising an error?
I believe an all-reduce should be correct here, as the TensorParallelEmbedding implementation is row parallel.
No, it wasn't raising an error. I deliberately left it out since I didn't fully understand whether it would work. In `TensorParallelEmbedding` we shard the embedding weight matrix, but we don't do that for linear layers, so I wasn't sure the weights would actually be sharded for `TensorParallelAdapterRowEmbedding`.
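For what it's worth, a minimal sketch of what an all-reduce over the LoRA A output could look like if the A matrix really is sharded row-parallel along with the embedding (hypothetical method, assumes a torch.distributed process group; whether that sharding holds is exactly the open question in this thread):

```python
import torch
import torch.distributed as dist

def collect_lora_a(a_out: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor:
    # With a row-parallel embedding, each rank holds a slice of the vocabulary and
    # contributes zeros for token ids outside its slice, so summing the partial
    # lora_A outputs across ranks reconstructs the full projection.
    if process_group.size() > 1:
        dist.all_reduce(a_out, op=dist.ReduceOp.SUM, group=process_group)
    return a_out
```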
server/lorax_server/adapters/lora.py
Outdated
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
            adapter_to_segment[adapter_idx] = segment_idx
Definitely looks cleaner. I believe we have a few unit tests to verify this is working correctly, right?
I saw some test cases, but I am planning to add the missing ones as well. In any case, I need to add tests to make sure that our implementation and the HF implementation match.
I will take a proper look at the test cases.
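One self-contained check that could complement those tests (not tied to the lorax or HF test suites; the shape conventions roughly follow PEFT's embedding LoRA and may differ from the actual code): applying lora_A/lora_B after the lookup should match a single lookup through the merged weight.

```python
import torch
import torch.nn.functional as F

def test_embedding_lora_matches_merged_weights():
    torch.manual_seed(0)
    vocab, dim, r, scaling = 32, 16, 4, 0.5
    weight = torch.randn(vocab, dim)
    lora_a = torch.randn(r, vocab)   # maps token ids into the rank-r space
    lora_b = torch.randn(dim, r)     # projects the rank-r space back to the embedding dim
    input_ids = torch.randint(0, vocab, (3, 7))

    # LoRA applied at runtime: base lookup, then the low-rank correction.
    base = F.embedding(input_ids, weight)
    correction = F.embedding(input_ids, lora_a.T) @ lora_b.T * scaling
    lora_out = base + correction

    # Reference: fold the update into the embedding weight and do a single lookup.
    merged = weight + (lora_b @ lora_a).T * scaling
    merged_out = F.embedding(input_ids, merged)

    torch.testing.assert_close(lora_out, merged_out)
```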
1. Make the embedding weight name a property of the model
2. Do not pop the adapter weight names
3. Uncomment the collect_lora method
What does this PR do?