
[LoRA] introduce LoraBaseMixin to promote reusability. #8774

Merged: 59 commits merged into main from feat-lora-base-class on Jul 25, 2024

Conversation

@sayakpaul (Member) commented Jul 3, 2024

What does this PR do?

It is basically a mirror of #8670, which I accidentally merged and then reverted in #8773. Apologies for that.

Check #8774 (comment) as well.

I have left inline comments to address the questions raised by @yiyixuxu.

@sayakpaul requested reviews from DN6 and yiyixuxu, July 3, 2024 01:37
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6 (Collaborator) commented Jul 3, 2024

Nice initiative 👍🏽. There's a lot to unpack here, so perhaps it's best to start bit by bit. For now I've just gone over the pipeline-related components.

Regarding the LoraBaseMixin, at the moment I think it might be doing a bit too much.

There are quite a few methods in there that make assumptions about the inheriting class, which isn't really how a base class should behave. Loading methods tied to specific model components, e.g. load_lora_into_text_encoder, are better left out. If such a method is used across different pipelines with no changes, it's better to create a utility function that does this and call it from the inheriting class, or to redefine the method in the inheriting class and use the # Copied from convention.
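For illustration, here is a minimal, self-contained sketch of the "module-level utility called from the inheriting class" pattern; all names and signatures below are hypothetical stand-ins, not the actual diffusers APIs.

from typing import Any, Dict


def _load_lora_into_text_encoder(state_dict: Dict[str, Any], text_encoder: Any, adapter_name: str = "default") -> None:
    # In the real library this would inject LoRA layers via PEFT; here we only
    # mark where the shared, component-specific logic would live.
    print(f"injecting {len(state_dict)} tensors into {text_encoder!r} as adapter '{adapter_name}'")


class LoraBaseMixin:
    # Only component-agnostic plumbing (state-dict fetching, offload handling, ...) lives here.
    pass


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    text_encoder = "text_encoder_stub"

    def load_lora_weights(self, state_dict: Dict[str, Any], adapter_name: str = "default") -> None:
        # The inheriting class knows its own components and calls the shared utility.
        _load_lora_into_text_encoder(state_dict, self.text_encoder, adapter_name=adapter_name)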

I would assume that these are the methods that need to be defined for managing LoRAs across all pipelines?

from typing import Callable, Dict

import torch


class LoraBaseMixin:

    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        raise NotImplementedError()

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        resume_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        raise NotImplementedError()

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        raise NotImplementedError()

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    def unload_lora_weights(self, **kwargs):
        raise NotImplementedError("`unload_lora_weights()` is not implemented.")

    def fuse_lora(self, **kwargs):
        raise NotImplementedError("`fuse_lora()` is not implemented.")

    def unfuse_lora(self, **kwargs):
        raise NotImplementedError("`unfuse_lora()` is not implemented.")

    def disable_lora(self):
        raise NotImplementedError("`disable_lora()` is not implemented.")

    def enable_lora(self):
        raise NotImplementedError("`enable_lora()` is not implemented.")

    def get_active_adapters(self):
        raise NotImplementedError("`get_active_adapters()` is not implemented.")

    def delete_adapters(self, adapter_names):
        raise NotImplementedError("`delete_adapters()` is not implemented.")

    def set_lora_device(self, adapter_names):
        raise NotImplementedError("`set_lora_device()` is not implemented.")

    @staticmethod
    def pack_weights(layers, prefix):
        raise NotImplementedError()

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        raise NotImplementedError()

    @property
    def lora_scale(self) -> float:
        raise NotImplementedError()

Quite a few of these methods probably cannot be defined in the base class, such as load_lora_weights/unload_lora_weights and fuse_lora/unfuse_lora, since they deal with specific pipeline components. They might also require arguments specific to the pipeline type or its components.

I think it might be better to define these methods in a pipeline-specific class that inherits from LoraBaseMixin, or just as its own mixin class. I don't have a strong feeling about either approach. e.g. StableDiffusionLoraLoaderMixin could look like:

from typing import Dict, Optional, Union

import torch
from transformers import PreTrainedModel

from diffusers import ModelMixin


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        _load_lora_into_unet(**kwargs)
        _load_lora_into_text_encoder(**kwargs)

    def fuse_lora(self, components=["unet", "text_encoder"], **kwargs):
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not in self._lora_loadable_modules.")

            model = getattr(self, fuse_component)
            # check whether it is a diffusers model
            if isinstance(model, ModelMixin):
                model.fuse_lora()
            # handle transformers models
            if isinstance(model, PreTrainedModel):
                fuse_text_encoder()

I saw this comment about using the term "fuse_denoiser" in the fusing methods. I'm not so sure about that. If we want to fuse the LoRA in a specific component, I think it's better to pass in the actual name of the component used in the pipeline, rather than track another attribute such as denoiser.

I also think the constants and class attributes such as TEXT_ENCODER_NAME and is_unet_denoiser might not be needed if we use a single class attribute with a list of the names of the LoRA-loadable components.
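As a rough, hedged sketch of that idea (the class and attribute names here are illustrative only, not the final API), a single _lora_loadable_modules list would let the base class iterate generically over whatever components a concrete pipeline declares:

class LoraBaseMixin:
    _lora_loadable_modules: list = []

    def disable_lora(self):
        # Generic iteration: no TEXT_ENCODER_NAME constant or is_unet_denoiser flag needed.
        for name in self._lora_loadable_modules:
            component = getattr(self, name, None)
            if component is not None:
                print(f"disabling LoRA on '{name}'")


class StableDiffusionPipelineStub(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]
    unet = object()
    text_encoder = object()


StableDiffusionPipelineStub().disable_lora()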

@sayakpaul (Member, Author)

@DN6 as discussed over Slack, I have unified the PeftAdapterMixin class too so that we can have methods like fuse_lora(), delete_lora(), enable_lora(), etc. under one umbrella without having to define and copy-paste them for each model-specific loader mixin such as UNet2DConditionLoadersMixin.

One thing to note is that I still had to keep loaders/transformer_sd3.py to implement set_adapters(), since this method varies between the UNet and the transformer because their block naming is different. This is also why you will see set_adapters() in UNet2DConditionLoadersMixin.

We could have two additional classes under loaders/peft.py:

  • TransformerPeftAdapterMixin(PeftAdapterMixin)
  • UNet2DConditionPeftAdapterMixin(PeftAdapterMixin)

and reimplement this method there, using them accordingly (a rough sketch follows below).

LMK.
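For reference, a minimal stand-alone sketch of what those two classes could look like; the set_adapters() bodies are illustrative placeholders, not the actual implementations:

class PeftAdapterMixin:
    def set_adapters(self, adapter_names, weights=None):
        raise NotImplementedError("`set_adapters()` must be specialized per model type.")


class UNet2DConditionPeftAdapterMixin(PeftAdapterMixin):
    def set_adapters(self, adapter_names, weights=None):
        # UNet-style block naming (down_blocks / mid_block / up_blocks).
        print(f"setting adapters {adapter_names} using UNet block naming")


class TransformerPeftAdapterMixin(PeftAdapterMixin):
    def set_adapters(self, adapter_names, weights=None):
        # DiT-style block naming (transformer_blocks).
        print(f"setting adapters {adapter_names} using transformer block naming")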

@sayakpaul (Member, Author)

@DN6 I think this is ready for another review now.

weights = [w if w is not None else 1.0 for w in weights]

# e.g. [{...}, 7] -> [{expanded dict...}, 7]
scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
@DN6 (Collaborator) commented on this snippet, Jul 23, 2024


Let's just add a check in case this is applied to a model that doesn't exist in the mapping. It's an edge case, because we would probably always verify, but better to be safe:

if scale_expansion_fn is not None:
    ...

@sayakpaul (Member, Author)

But scale_expansion_fn cannot be None, no? We are directly indexing the dictionary here rather than using get(), so a wrong key will raise an error anyway. But LMK if I am missing something.

@DN6 (Collaborator)

Actually on second thought, the check might be overkill. If we add to a model not in the mapping, we should error out.

@DN6 (Collaborator), Jul 24, 2024

What I was thinking is that SD3Transformer2DModel doesn't even need to be in the mapping: we could use .get() to check whether a scale_expansion_fn exists for a model class and return None if it doesn't. Either approach works.
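For reference, a stand-alone illustration of the two options being weighed here; the mapping contents and model names are made up for the example:

_SET_ADAPTER_SCALE_FN_MAPPING = {
    "SD3Transformer2DModel": lambda model, weights: weights,  # placeholder expansion fn
}

model_cls_name = "SomeNewTransformerModel"  # hypothetical model not in the mapping

# Option A: direct indexing -- an unmapped model fails loudly with a KeyError.
try:
    scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[model_cls_name]
except KeyError:
    print(f"{model_cls_name} has no scale expansion fn registered")

# Option B: .get() -- returns None for unmapped models, letting the caller simply skip expansion.
scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING.get(model_cls_name)
if scale_expansion_fn is not None:
    weights = scale_expansion_fn(None, [1.0])

The reply below settles on the first option, since direct indexing already errors out for models that were never registered.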

@sayakpaul (Member, Author)

"If we add to a model not in the mapping, we should error out."

Yeah, this already works. So, I would prefer that.

@sayakpaul requested a review from DN6, July 23, 2024 15:03
@sayakpaul (Member, Author)

@DN6 anything else you would like me to address?

@DN6 (Collaborator) left a comment:

LGTM 👍🏽

@sayakpaul merged commit 527430d into main, Jul 25, 2024
18 checks passed
@sayakpaul deleted the feat-lora-base-class branch, July 25, 2024 16:11
@sayakpaul (Member, Author)

Thanks for the massive help and guidance, Dhruv!

yiyixuxu added a commit that referenced this pull request, Jul 25, 2024:
Revert "[LoRA] introduce LoraBaseMixin to promote reusability. (#8774)"

This reverts commit 527430d.
@yiyixuxu restored the feat-lora-base-class branch, July 25, 2024 19:12
@sayakpaul deleted the feat-lora-base-class branch, July 26, 2024 01:40
@sayakpaul restored the feat-lora-base-class branch, July 26, 2024 01:40