add optional alpha to TSVM
tanganke committed Jan 18, 2025
1 parent 1777862 commit 60df3e1
Showing 2 changed files with 28 additions and 2 deletions.
6 changes: 6 additions & 0 deletions
@@ -1,2 +1,8 @@
 _target_: fusion_bench.method.TaskSingularVectorMerging
 remove_keys: null
+
+# alpha is a float or a list of floats
+# example:
+# alpha: 1
+# alpha: [1, 0.5, 0.25]
+alpha: 1
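
For reference, the new `alpha` option corresponds to the constructor argument added in TSVM.py below. A minimal usage sketch (hypothetical script; the import path follows the `_target_` above, and the semantics follow the `run()` changes below):

```python
# Hypothetical usage sketch; the class path mirrors the `_target_` above.
from fusion_bench.method import TaskSingularVectorMerging

# A single float scales the merged task vector once, after the SVD merge.
method = TaskSingularVectorMerging(alpha=0.5)

# A list scales each task vector individually, before merging; its length
# must match the number of fine-tuned checkpoints in the model pool.
method = TaskSingularVectorMerging(alpha=[1.0, 0.5, 0.25])
```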
24 changes: 22 additions & 2 deletions fusion_bench/method/task_singular_vector/TSVM.py
@@ -9,15 +9,20 @@
 ```
 """
 
-from typing import List, Optional
+from typing import List, Optional, Union, Iterable
 
 import torch
 from torch import Tensor, nn
+from omegaconf import ListConfig
 
 from fusion_bench import BaseAlgorithm
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.utils import timeit_context
-from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
+from fusion_bench.utils.state_dict_arithmetic import (
+    state_dict_add,
+    state_dict_sub,
+    state_dict_mul,
+)
 from fusion_bench.utils.type import StateDictType
 
 from .utils import (
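
`state_dict_mul` is the helper this commit newly imports alongside `state_dict_add` and `state_dict_sub`. Its implementation is not part of this diff; a plausible sketch of the three helpers, assuming plain elementwise arithmetic over state dicts:

```python
# Illustrative sketch only; the real implementations live in
# fusion_bench.utils.state_dict_arithmetic and are not shown in this diff.
from typing import Dict
from torch import Tensor

def state_dict_mul(state_dict: Dict[str, Tensor], scalar: float) -> Dict[str, Tensor]:
    # Scale every tensor in the state dict by a scalar.
    return {k: scalar * v for k, v in state_dict.items()}

def state_dict_add(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # Elementwise sum of two state dicts with matching keys.
    return {k: a[k] + b[k] for k in a}

def state_dict_sub(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # Elementwise difference, e.g. fine-tuned minus pretrained = task vector.
    return {k: a[k] - b[k] for k in a}
```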
@@ -33,9 +38,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
 
     def __init__(
         self,
+        alpha: Union[float, Iterable[float]] = None,
         remove_keys: Optional[List[str]] = None,
         **kwargs,
     ):
+        self.alpha = alpha
         self.remove_keys = remove_keys if remove_keys is not None else []
         super().__init__(**kwargs)
 
@@ -50,13 +57,26 @@ def run(self, modelpool):
 
         with timeit_context("Flattening out Checkpoints"):
             task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
+        if isinstance(self.alpha, Iterable):
+            assert len(self.alpha) == len(
+                task_vectors
+            ), "Alpha and task vectors must have the same length"
+            task_vectors = [
+                state_dict_mul(state_dict=tv, scalar=alpha)
+                for alpha, tv in zip(self.alpha, task_vectors)
+            ]
 
         new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
             task_vectors,
             exclude_keys=self.remove_keys,
             accelerator=self.fabric.device,
         )
 
+        # If alpha is a float, we need to scale the new merged task vector by alpha
+        if self.alpha is not None and isinstance(self.alpha, float):
+            print(f"Scaling new merged task vector by alpha: {self.alpha}")
+            new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
+
         pretrained_model.load_state_dict(
             state_dict_add(new_merged_tv, pretrained_model.state_dict())
         )
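
Taken together, the two new branches give `alpha` type-dependent semantics: a list (or other iterable) rescales each task vector before the SVD-based merge, a float rescales the single merged task vector afterwards, and the default `alpha=None` leaves both paths inactive. A toy illustration, with plain tensors standing in for state dicts and a simple sum standing in for `TSVM_utils.compute_and_sum_svd_mem_reduction`:

```python
# Toy illustration of the two alpha code paths; plain tensors stand in for
# state dicts, and a plain sum stands in for the SVD-based merge.
import torch

task_vectors = [torch.tensor([1.0, 1.0]), torch.tensor([2.0, 2.0])]

# alpha as a list: scale each task vector before merging.
alphas = [1.0, 0.5]
merged = sum(a * tv for a, tv in zip(alphas, task_vectors))  # tensor([2., 2.])

# alpha as a float: merge first, then scale the result once.
merged = 0.5 * sum(task_vectors)  # tensor([1.5, 1.5])
```

A per-task list lets individual tasks be up- or down-weighted before merging, while a float applies one global scale afterwards. Note also that the float branch checks `isinstance(self.alpha, float)`, so an integer value such as the `alpha: 1` default in the config skips post-merge scaling entirely (harmless at 1, where scaling is the identity).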