Merge pull request #65 from tanganke/develop
add DARE Ties Merging
tanganke authored Jan 10, 2025
2 parents 5b6f07e + 4e7303d commit 955f0d3
Showing 6 changed files with 158 additions and 16 deletions.
15 changes: 15 additions & 0 deletions config/method/dare/ties_merging.yaml
@@ -0,0 +1,15 @@
_target_: fusion_bench.method.dare.DareTiesMerging

# === DARE parameters ===
sparsity_ratio: 0.5
only_on_linear_weights: false
rescale: true

# === Ties merging parameters ===
# Scaling factor $\lambda$
scaling_factor: 0.5
threshold: 20
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum
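
For reference, these values line up one-to-one with the new class's constructor. A minimal, hypothetical usage sketch (in the framework itself, Hydra instantiates the class from the _target_ entry; modelpool stands in for any BaseModelPool holding a pretrained model and one or more fine-tuned variants):

# Hypothetical sketch; Hydra normally builds this object from the YAML above.
from fusion_bench.method.dare import DareTiesMerging

algorithm = DareTiesMerging(
    sparsity_ratio=0.5,            # DARE: drop 50% of task-vector entries
    only_on_linear_weights=False,  # apply DARE to all parameters, not only nn.Linear weights
    rescale=True,                  # rescale survivors by 1 / (1 - sparsity_ratio)
    scaling_factor=0.5,            # lambda applied to the merged task vector
    threshold=20,                  # forwarded to ties_merging as reset_thresh
    remove_keys=[],
    merge_func="sum",
)
# merged_model = algorithm.run(modelpool)  # modelpool: an assumed BaseModelPool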
4 changes: 2 additions & 2 deletions fusion_bench/method/__init__.py
@@ -35,7 +35,7 @@
"weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
"task_arithmetic": ["TaskArithmeticAlgorithm"],
"ties_merging": ["TiesMergingAlgorithm"],
"dare": ["DareSimpleAverage", "DareTaskArithmetic"],
"dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
"fisher_merging": [
"FisherMergingForCLIPVisionModel",
"FisherMergingAlgorithmForGPT2",
@@ -110,7 +110,7 @@
ConcreteTaskArithmeticAlgorithmForCLIP,
ConcreteTaskWiseAdaMergingForCLIP,
)
- from .dare import DareSimpleAverage, DareTaskArithmetic
+ from .dare import DareSimpleAverage, DareTaskArithmetic, DareTiesMerging
from .dawe import DataAdaptiveWeightEnsemblingForCLIP
from .depth_upscaling import DepthUpscalingAlgorithm, DepthUpscalingForLlama
from .dummy import DummyAlgorithm
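The string-keyed registry updated above pairs with the eager imports further down in the same file. Registries like this usually back a lazy-import scheme, so heavy submodules are only imported when one of their classes is first accessed. A minimal sketch of that pattern, assuming a PEP 562 module-level __getattr__ (fusion_bench's actual mechanism may differ):

import importlib

_registry = {
    "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
}

def __getattr__(name: str):
    # Resolve an exported class to its submodule on first access.
    for submodule, symbols in _registry.items():
        if name in symbols:
            return getattr(importlib.import_module(f".{submodule}", __package__), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")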
1 change: 1 addition & 0 deletions fusion_bench/method/dare/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa F401
from .simple_average import DareSimpleAverage
from .task_arithmetic import DareTaskArithmetic
+ from .ties_merging import DareTiesMerging
21 changes: 14 additions & 7 deletions fusion_bench/method/dare/task_arithmetic.py
@@ -33,21 +33,28 @@ def __init__(
self.rescale = rescale
super().__init__(**kwargs)

+ def _load_task_vector(
+     self,
+     modelpool: BaseModelPool,
+     model_name: str,
+     pretrained_model: nn.Module,
+ ):
+     finetuned_model = modelpool.load_model(model_name)
+     task_vector = module_sub_(finetuned_model, pretrained_model)
+     return task_vector
+
@torch.no_grad()
def run(self, modelpool: BaseModelPool):
assert (
self.sparsity_ratio >= 0 and self.sparsity_ratio <= 1
), "Sparsity ratio must be between 0 and 1"
pretrained_model = modelpool.load_pretrained_model()
- finetuned_models = {
-     model_name: modelpool.load_model(model_name)
-     for model_name in modelpool.model_names
- }

# load task vectors
task_vectors = {
-     model_name: module_sub_(finetuned_models[model_name], pretrained_model)
-     for model_name in finetuned_models
+     model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
+     for model_name in modelpool.model_names
}
- del finetuned_models

# drop and rescale task vectors
for model_name, tv in task_vectors.items():
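The refactor routes model loading through _load_task_vector, so each fine-tuned model is loaded and converted to a task vector one at a time, rather than all fine-tuned models being held in memory alongside the pretrained model. A sketch of the semantics module_sub_ is assumed to have (in-place parameter subtraction, per the PyTorch trailing-underscore convention):

import torch
from torch import nn

@torch.no_grad()
def module_sub_(a: nn.Module, b: nn.Module) -> nn.Module:
    # Subtract b's parameters from a's in place, turning the fine-tuned
    # model `a` into a task vector relative to the pretrained model `b`.
    for p_a, p_b in zip(a.parameters(), b.parameters()):
        p_a.sub_(p_b)
    return a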
100 changes: 100 additions & 0 deletions fusion_bench/method/dare/ties_merging.py
@@ -0,0 +1,100 @@
from typing import Literal

import torch
from torch import Tensor, nn

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.method.ties_merging.ties_merging_utils import ties_merging
from fusion_bench.utils.parameters import state_dict_to_vector, vector_to_state_dict
from fusion_bench.utils.state_dict_arithmetic import state_dict_sum

from .utils import (
module_random_drop_,
module_sub_,
param_random_drop_,
trainable_state_dict,
)


class DareTiesMerging(BaseAlgorithm):
def __init__(
self,
# DARE parameters
sparsity_ratio: float,
only_on_linear_weights: bool,
rescale: bool,
# Ties merging parameters
scaling_factor: float,
threshold: int,
remove_keys: list[str],
merge_func: Literal["sum", "mean", "max"],
**kwargs,
):
self.sparsity_ratio = sparsity_ratio
self.only_on_linear_weights = only_on_linear_weights
self.rescale = rescale
self.scaling_factor = scaling_factor
self.threshold = threshold
self.remove_keys = remove_keys
self.merge_func = merge_func
super().__init__(**kwargs)

@torch.no_grad()
def _load_task_vector(
self,
modelpool: BaseModelPool,
model_name: str,
pretrained_model: nn.Module,
):
finetuned_model = modelpool.load_model(model_name)
task_vector = module_sub_(finetuned_model, pretrained_model)
return task_vector

def run(self, modelpool: BaseModelPool):
assert (
self.sparsity_ratio >= 0 and self.sparsity_ratio <= 1
), "Sparsity ratio must be between 0 and 1"
pretrained_model = modelpool.load_pretrained_model()

# load task vectors
task_vectors = {
model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
for model_name in modelpool.model_names
}

# drop and rescale task vectors
for model_name, tv in task_vectors.items():
if self.only_on_linear_weights:
for module_name, module in tv.named_modules():
if isinstance(module, nn.Linear):
print(f"pruning model: `{model_name}`, layer: {module_name}.")
param_random_drop_(
module.weight, self.sparsity_ratio, rescale=self.rescale
)
else:
print(f"pruning model: `{model_name}`")
module_random_drop_(tv, self.sparsity_ratio, rescale=self.rescale)

ptm_check = pretrained_model.state_dict()
flat_ptm = state_dict_to_vector(ptm_check, self.remove_keys)
tv_flat_checks = torch.vstack(
[
state_dict_to_vector(check.state_dict(), self.remove_keys)
for check in task_vectors.values()
]
)
del task_vectors

# Perform TIES Merging
merged_tv = ties_merging(
tv_flat_checks,
reset_thresh=self.threshold,
merge_func=self.merge_func,
)
merged_check = flat_ptm + self.scaling_factor * merged_tv
merged_state_dict = vector_to_state_dict(
merged_check, ptm_check, remove_keys=self.remove_keys
)

pretrained_model.load_state_dict(merged_state_dict)
return pretrained_model
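
Two steps do the heavy lifting in run. First, DARE sparsifies each task vector: every entry is dropped independently with probability sparsity_ratio and, when rescale is set, the survivors are scaled by 1 / (1 - p) so the vector's expected value is unchanged. The actual work is done by param_random_drop_ / module_random_drop_; this standalone version is only an illustration of the idea:

import torch

@torch.no_grad()
def dare_drop_(delta: torch.Tensor, p: float, rescale: bool = True) -> torch.Tensor:
    # Keep each entry with probability 1 - p, zero out the rest.
    mask = torch.bernoulli(torch.full_like(delta, 1.0 - p))
    delta.mul_(mask)
    if rescale and p < 1.0:
        # Rescale survivors so the expected value of delta is preserved.
        delta.div_(1.0 - p)
    return delta

Second, the flattened task vectors go through ties_merging, which, following the TIES-Merging procedure, trims each vector to its largest-magnitude entries (threshold is forwarded as reset_thresh), elects a majority sign per parameter, and combines the sign-consistent entries with merge_func. The result is scaled by scaling_factor and added back onto the pretrained weights.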
33 changes: 26 additions & 7 deletions fusion_bench/utils/parameters.py
@@ -45,52 +45,71 @@ def trainable_state_dict(


def state_dict_to_vector(
-     state_dict: StateDictType,
+     state_dict: Union[StateDictType, nn.Module],
remove_keys: Optional[List[str]] = None,
):
"""
Convert a state dictionary to a vector.
Args:
-     state_dict (dict): The state dictionary to convert.
+     state_dict (Union[dict[str, torch.Tensor], nn.Module]): The state dictionary to convert.
remove_keys (list, optional): List of keys to remove from the state dictionary. Defaults to [].
Returns:
torch.Tensor: The converted vector.
"""
remove_keys = remove_keys if remove_keys is not None else []
- shared_state_dict = copy.deepcopy(state_dict)

+ if isinstance(state_dict, nn.Module):
+     shared_state_dict = state_dict.state_dict()
+ else:
+     shared_state_dict = copy.copy(state_dict)

# remove the keys to be removed
for key in remove_keys:
if key in shared_state_dict:
del shared_state_dict[key]

# sort the reference dict
sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
- return nn.utils.parameters_to_vector(

+ vector = nn.utils.parameters_to_vector(
[value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
)
+ return vector


def vector_to_state_dict(
vector: torch.Tensor,
-     state_dict: StateDictType,
+     state_dict: Union[StateDictType, nn.Module],
remove_keys: Optional[List[str]] = None,
):
"""
Convert a vector to a state dictionary.
Args:
vector (torch.Tensor): The vector to convert.
-     state_dict (dict): The reference state dictionary to define the order of the vector.
+     state_dict (Union[dict[str, torch.Tensor], nn.Module]): The reference state dictionary to define the order of the vector.
remove_keys (list, optional): List of keys to remove from the reference state dictionary. Defaults to [].
Returns:
dict: The converted state dictionary.
"""
remove_keys = remove_keys if remove_keys is not None else []

# create a reference dict to define the order of the vector
- reference_dict = copy.deepcopy(state_dict)
+ if isinstance(state_dict, nn.Module):
+     reference_dict = state_dict.state_dict()
+ else:
+     # shallow copy the state_dict
+     reference_dict = copy.copy(state_dict)

# remove the keys to be removed
for key in remove_keys:
if key in reference_dict:
del reference_dict[key]

# sort the reference dict
sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

# create a shared state dict using the reference dict
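The practical effect of these changes is that both helpers now accept an nn.Module directly and copy the reference dict shallowly instead of deep-copying it, so flattening a large model no longer duplicates its tensors. A small round-trip sketch under those semantics:

import torch
from torch import nn
from fusion_bench.utils.parameters import state_dict_to_vector, vector_to_state_dict

model = nn.Linear(4, 2)
vec = state_dict_to_vector(model)  # an nn.Module is now accepted directly
restored = vector_to_state_dict(vec, model.state_dict())
assert all(torch.equal(restored[k], v) for k, v in model.state_dict().items())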
