(WIP) Support targeting the embedding layer for LoRA #501

Open
wants to merge 13 commits into base: main

Conversation

@ajtejankar (Contributor) commented Jun 6, 2024

What does this PR do?

  1. Re-organize the code in BatchLoraWeights.load. This function was a bit hard to understand because there were multiple list comprehensions with almost the same looping logic, so I merged them into two loops for improved clarity (see the sketch after this list). @tgaddair Can you confirm this looks good? I can revert to the original code in case this change causes problems.
  2. (WIP) Support the embedding layer as a target module. This is mostly done except for multi-GPU inference.
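For reference, a minimal sketch of the kind of refactor described in item 1, with illustrative names and attributes rather than the actual BatchLoraWeights.load code: several comprehensions that each walk the same adapter list are collapsed into one explicit loop that fills every list in a single pass.

```python
def collect_adapter_tensors(adapters):
    # Hypothetical sketch only; `weights_a`, `weights_b`, and `lora_a_r` are
    # assumed attribute names, not necessarily the PR's exact fields.
    # Before the refactor this was several list comprehensions, each iterating
    # over `adapters`; here they are merged into a single pass for clarity.
    lora_a_list, lora_b_list, ranks = [], [], []
    for w in adapters:
        lora_a_list.append(w.weights_a)
        lora_b_list.append(w.weights_b)
        ranks.append(w.lora_a_r)
    return lora_a_list, lora_b_list, ranks
```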

@ajtejankar ajtejankar requested a review from tgaddair June 6, 2024 21:07
@ajtejankar ajtejankar self-assigned this Jun 8, 2024
@ajtejankar (Contributor Author):
@tgaddair I am pushing a partially done commit that supports embedding-layer LoRAs.

  • Similar to the HF implementation, lora_A is used for the embedding lookup while lora_B is applied as a matrix multiplication (see the sketch after this list)
  • Prevents the lora_A transpose when in BGMV mode
  • Contains two implementations to replace the two kernels: SGMV and BGMV
  • Both are implemented with for loops. How can we optimize them?
  • Cannot handle multi-GPU. I will need some help understanding sharding in LoRAX, as I found it confusing. :(
  • Tested crudely by comparing against generation from HF, but I need to add a proper test case.
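For illustration, a minimal sketch of the lookup-then-project path described in the first bullet, mirroring the HF/PEFT approach. The function name, tensor shapes, layout, and the scaling argument are assumptions, not the PR's actual SGMV/BGMV replacement loops.

```python
import torch
import torch.nn.functional as F

def lora_embedding_delta(
    input_ids: torch.Tensor,   # (batch, seq) token ids
    lora_a: torch.Tensor,      # assumed (vocab_size, r) layout, i.e. no transpose in BGMV mode
    lora_b: torch.Tensor,      # (r, hidden_size)
    scaling: float = 1.0,
) -> torch.Tensor:
    # lora_A acts as a small embedding table indexed by token id ...
    a_out = F.embedding(input_ids, lora_a)   # (batch, seq, r)
    # ... and lora_B projects the looked-up rows up to the hidden size.
    return (a_out @ lora_b) * scaling        # (batch, seq, hidden_size)
```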

@ajtejankar ajtejankar linked an issue Jun 8, 2024 that may be closed by this pull request
@@ -40,14 +41,20 @@ def map_weights_for_model(
        adapter_weight_names = set()
        module_map = {}
        for weight_name in weight_names:
            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
            if EMBED_TOKENS in weight_name:
Contributor:
We might need to make this embed_tokens name a property of the model rather than a constant, as I imagine it will vary from one architecture to the next.

Contributor Author:
Makes sense. I will make the change.
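A minimal sketch of the reviewer's suggestion, making the embedding module name a per-model property rather than a shared constant. The class and property names here are assumptions for illustration, not LoRAX's actual API; "transformer.wte" is the GPT-2-style name used only as an example of an architecture that differs from the Llama-style default.

```python
class CausalLMBase:
    @property
    def embed_tokens_name(self) -> str:
        # Llama-style default; other architectures override this below.
        return "model.embed_tokens"

class GPT2StyleModel(CausalLMBase):
    @property
    def embed_tokens_name(self) -> str:
        # GPT-2 names its token embedding differently.
        return "transformer.wte"
```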

            module_map[weight_name] = {
-               "lora_A": (adapter_weights[lora_a_name], lora_a_name),
-               "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+               "lora_A": (adapter_weights.pop(lora_a_name), lora_a_name),
Contributor:
Not sure I understand the purpose of using pop here, as it doesn't look like the adapter_weights are used below (unless it's used from the caller). In general, it's good to avoid modifying input objects unless it's clear that the function does that from the name, etc.

In this case, I would suggest cloning the adapter_weights dict at the top to avoid modifying the input, and then returning the modified adapter_weights if the caller needs to check which elements haven't been popped.

Contributor Author:
Sounds good. I just realized that I may not even need to change this part, since adapter_weight_names captures which weights were consumed, and I can use it in the caller to figure out whether all of the weights were consumed.
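For reference, a minimal sketch of the clone-and-return approach suggested above: copy the input dict up front so the caller's object is never mutated, and hand back whatever was not consumed. The function signature and the membership check are illustrative assumptions, not the PR's actual code.

```python
def map_adapter_weights(adapter_weights: dict, weight_names: list) -> tuple:
    # Work on a shallow copy so the caller's dict is never mutated.
    remaining = dict(adapter_weights)
    module_map = {}
    for weight_name in weight_names:
        lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
        lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
        if lora_a_name not in remaining or lora_b_name not in remaining:
            continue
        module_map[weight_name] = {
            "lora_A": (remaining.pop(lora_a_name), lora_a_name),
            "lora_B": (remaining.pop(lora_b_name), lora_b_name),
        }
    # Whatever is left over was not consumed; the caller can log or raise on it.
    return module_map, remaining
```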

            batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
-           batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
+           batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices]
Contributor:
I would rather keep the separate variable, as the call to set(indices) on every iteration of the loop is unnecessary.

Contributor Author:
Oh, yes. I will revert to the original code. The reason for the change was to avoid having a rank_indices variable inside the for loop, since the loop itself iterates over another rank_indices variable. Maybe I can rename the rank_indices here instead.
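A small sketch of the compromise discussed here: hoist the set out of the comprehension so it is built once rather than per element, and give it a name that does not shadow the outer rank_indices. Variable and function names are illustrative only.

```python
def build_batch_indices(adapter_indices, adapter_to_segment, indices):
    # Build the set once, outside the comprehension, under a non-shadowing name.
    segment_indices_set = set(indices)
    batch_indices = [adapter_to_segment[idx] for idx in adapter_indices]
    return [idx if idx in segment_indices_set else -1 for idx in batch_indices]
```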


        # note(ajinkya): adapter weights are consumed during the mapping above; if some are left over,
        # we may not be supporting all of the weights in the adapter. That should probably be an error,
        # but for now we just log it.
        if len(adapter_weights) > 0:
Contributor:
Per above comment, would return the modified adapter weights as unused_adapter_weights or similar rather than relying on the input to be modified.

Contributor Author:
Sounds good.


        return result

    # def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
Contributor:
Why was this commented out? Was it raising an error?

I believe an all-reduce should be correct here, as the TensorParallelEmbedding implementation is row parallel.

Contributor Author (@ajtejankar, Jun 11, 2024):
No, it wasn't raising an error. I deliberately left it out since I didn't fully understand whether it would work. In TensorParallelEmbedding we're sharding the embedding weight matrix, but we don't do that for linear layers, so I wasn't sure whether the weights would be sharded for TensorParallelAdapterRowEmbedding.
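For illustration, a minimal sketch of the all-reduce the reviewer describes: with a row-parallel (vocab-sharded) embedding, each rank produces only a partial LoRA activation for the token ids it owns, so the partial outputs are summed across ranks. This is a hedged stand-in for the commented-out collect_lora_a hook; how the process group is threaded through is an assumption.

```python
import torch
import torch.distributed as dist

def collect_lora_a(a_out: torch.Tensor, group=None) -> torch.Tensor:
    # Sum the per-rank partial outputs of the sharded embedding lookup in place.
    dist.all_reduce(a_out, op=dist.ReduceOp.SUM, group=group)
    return a_out
```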

            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
            adapter_to_segment[adapter_idx] = segment_idx
Contributor:
Definitely looks cleaner. I believe we have a few unit tests to verify this is working correctly, right?

Contributor Author:
I saw some test cases, but I am planning to add the missing ones as well. In any case, I need to add test cases to make sure that our implementation and the HF implementation match.

Contributor Author:
I will take a proper look at the test cases.
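As a starting point, a hedged sketch of the kind of unit test discussed above, checking that the lookup-then-project path matches a dense reference that materializes the full delta weight first. Names, shapes, and seeds are illustrative; this is not a substitute for the HF-comparison test the author plans to add.

```python
import torch
import torch.nn.functional as F

def test_embedding_lora_matches_dense_reference():
    torch.manual_seed(0)
    vocab, hidden, r = 32, 16, 4
    input_ids = torch.randint(0, vocab, (2, 5))
    lora_a = torch.randn(vocab, r)
    lora_b = torch.randn(r, hidden)

    fast = F.embedding(input_ids, lora_a) @ lora_b    # lookup, then project
    dense = F.embedding(input_ids, lora_a @ lora_b)   # full delta weight, then lookup
    torch.testing.assert_close(fast, dense)
```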

Successfully merging this pull request may close these issues.

Supporting LmHead and Embedding Layers for Adapters