From 051f85089848e1527dedc7530f4d509b3950301c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 12 Feb 2024 21:43:33 -0800 Subject: [PATCH] Fixed TIES merging to calculate sign before applying weights --- .../lorax_server/utils/merges/strategies.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/server/lorax_server/utils/merges/strategies.py b/server/lorax_server/utils/merges/strategies.py index 99dd5d81a..938a7b3a9 100644 --- a/server/lorax_server/utils/merges/strategies.py +++ b/server/lorax_server/utils/merges/strategies.py @@ -1,7 +1,7 @@ from abc import ABC from collections import defaultdict import copy -from typing import TYPE_CHECKING, Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union import torch from peft import LoraConfig @@ -17,8 +17,11 @@ from lorax_server.utils.adapter import ModuleMap -def _apply_weights(tensors: List[torch.Tensor], w: torch.Tensor) -> torch.Tensor: - t = torch.stack(tensors, dim=0) +def _apply_weights(tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + t = tensors + else: + t = torch.stack(tensors, dim=0) # element-wise weighting of each task tensor # need to unsqueeze weights to match task tensor dimensions @@ -50,11 +53,12 @@ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="magnitude") for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=self.majority_sign_method) weighted_task_tensors = _apply_weights(task_tensors, weights) - # elect sign - majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) - # disjoint merge return disjoint_merge(weighted_task_tensors, majority_sign_mask) @@ -78,11 +82,12 @@ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs def merge(self, task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor: # sparsify task_tensors = [prune(tensor, self.density, method="random", rescale=True) for tensor in task_tensors] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask(task_tensors, method=self.majority_sign_method) weighted_task_tensors = _apply_weights(task_tensors, weights) - # elect sign - majority_sign_mask = calculate_majority_sign_mask(weighted_task_tensors, method=self.majority_sign_method) - # disjoint merge mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) return mixed_task_tensors