
Commit

adapting the functions Callbacks
yunss-ML committed Aug 26, 2023
1 parent c4a717b commit 9c30ffc
Showing 3 changed files with 97 additions and 72 deletions.
68 changes: 40 additions & 28 deletions core/Adapter.py
@@ -1,38 +1,50 @@
 import bitsandbytes as bnb
 import loratorch as LoraT
 import loralib as lora
-import torch.nn as nn
-from typing import dict , Optional , Union
+from typing import List, Optional
+import torch.nn as nn

-class Adapters(object):
-
-    def __init__(self, layerTyep:list, Method:func)-> nn.Module:
-        self.layer = layerTyep
-
-    @staticmethod
-    def LayerType(self , layer):
-        layers = ["nn.Linear" , "nn.Embedding", "nn.Conv1d","nn.Conv2d"]
-        AdaptedLayer = []
-        for i in layer:
-            for j in layers:
-                if layer[i] == layers[j]:
-                    AdaptedLayer.append(layer[i])
-                return f"{layers[i]} not support Please Visit \n Docs to list correct Layer support"
-        return AdaptedLayer
-
-    def __call__(self, fn):
-        if self.LayerType(self.layer):
-            def __fn():
-                print(f"Layers to adjusted Used AdapterLoRa: {[layer for layer in self.layer]}")
-                print("Adapter Applied:", fn.__name__)
-                fn()
-            return __fn
-
-
-class Optimzer:
-    def __init__(self, Optimzer: nn.Module):
-        pass
+class Adapters:
+    def __init__(self, layer_type: List[str]):
+        """
+        Initialize an Adapters object with a list of supported layer types.
+        Args:
+            layer_type (List[str]): List of supported layer types.
+        """
+        self.layer_type = layer_type
+
+    @staticmethod
+    def layer_type_check(layer: str) -> bool:
+        """
+        Check if a given layer type is supported.
+        Args:
+            layer (str): The layer type to check.
+        Returns:
+            bool: True if the layer type is supported, False otherwise.
+        """
+        layers = ["nn.Linear", "nn.Embedding", "nn.Conv1d", "nn.Conv2d"]
+        return layer in layers
+
+    def __call__(self, fn):
+        """
+        Decorator to apply an adapter function to specified layers.
+        Args:
+            fn (Callable): The adapter function to be applied.
+        Returns:
+            Callable: Decorated function with adapter applied.
+        """
+        def decorated_fn():
+            if all(self.layer_type_check(layer) for layer in self.layer_type):
+                print(f"Layers to be adjusted using AdapterLoRa: {self.layer_type}")
+                print("Adapter Applied:", fn.__name__)
+                fn()
+            else:
+                print("Some layer types are not supported.")
+        return decorated_fn
+
+
+class Optimizer:
+    def __init__(self, optimizer: nn.Module):
+        pass  # You can add initialization logic here
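For orientation, a minimal usage sketch of the reworked Adapters decorator as it stands after this commit. The function name, layer list, and import path below are illustrative assumptions, not part of the diff:

from core.Adapter import Adapters  # import path assumed from this repository layout

@Adapters(["nn.Linear", "nn.Embedding"])  # hypothetical layer list
def apply_lora():
    # Stand-in for the real adapter routines in core/LayersAdaptes.py.
    print("applying LoRA adapters")

apply_lora()
# Expected output, given decorated_fn above:
#   Layers to be adjusted using AdapterLoRa: ['nn.Linear', 'nn.Embedding']
#   Adapter Applied: apply_lora
#   applying LoRA adapters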
4 changes: 2 additions & 2 deletions core/LayersAdaptes.py
@@ -25,7 +25,7 @@ def Layer(model, new_layer):

return new_layer

-@Adapters(layertyep)
+@Adapters(LAYERS)
def LoRaLinear(method: str, model: nn.Module, Rank: Optional[int], threshold: Optional[int]):
"""
Replace a linear layer with a quantized layer using specified method.
@@ -73,7 +73,7 @@ def LoRaLinear(method: str, model: nn.Module, Rank: Optional[int], threshold: Optional[int]):
else:
raise ValueError(f"Unsupported method or invalid method name: {method}")

-@Adapters(layertyep)
+@Adapters(LAYERS)
def LoRaEmbedding(
method: str,
model: nn.Module,
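The decorator argument LAYERS is referenced but not defined in the hunks shown here. A plausible definition, assuming it mirrors the strings accepted by Adapters.layer_type_check, would be:

# Assumed module-level constant (not visible in this commit's hunks):
LAYERS = ["nn.Linear", "nn.Embedding", "nn.Conv1d", "nn.Conv2d"]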
97 changes: 55 additions & 42 deletions core/utils.py
@@ -1,41 +1,55 @@
import torch.nn as nn
-from typing import Optional, Callable
+from typing import Optional
from .Adapter import Adapters

-def quantize_layer(method,layer, quantize_fn, quantize_fn_, Rank):
+def quantize_layer(**kwargs):
"""
Apply the appropriate quantization function to the given layer.
Args:
-    layer (nn.Module): The layer to be quantized.
-    quantize_fn (Callable): The function to quantize a linear layer.
-    quantize_fn_ (Callable): The function to quantize an embedding layer.
-    Rank (int): The rank parameter for LoRA adaptation.
+    **kwargs: Dictionary containing the quantization parameters.
Returns:
nn.Module: The quantized layer.
"""
-if isinstance(layer, nn.Linear) and quantize_fn is not None:
-    return quantize_fn(
-        method,
-        model,
-        Rank,
-        threshold
-    )
-elif isinstance(layer, nn.Embedding) and quantize_fn_ is not None:
-    return quantize_fn_(
-        method,
-        model,
-        Rank,
-        lora_alpha,
-        scale_grad_by_freq,
-        padding_idx,
-        max_norm
-    )
-else:
-    return layer
+method = kwargs["method"]
+model = kwargs["model"]
+Rank = kwargs.get("Rank")
+lora_alpha = kwargs.get("lora_alpha")
+scale_grad_by_freq = kwargs.get("scale_grad_by_freq")
+padding_idx = kwargs.get("padding_idx")
+max_norm = kwargs.get("max_norm")
+threshold = kwargs.get("threshold")
+
+if isinstance(model, nn.Linear):
+    quantize_fn = kwargs["quantize_fn"]
+    if quantize_fn is not None:
+        return quantize_fn(
+            method=method,
+            model=model,
+            Rank=Rank,
+            **kwargs
+        )
+
+elif isinstance(model, nn.Embedding):
+    quantize_fn_ = kwargs["quantize_fn_"]
+    if quantize_fn_ is not None:
+        return quantize_fn_(
+            method=method,
+            model=model,
+            Rank=Rank,
+            lora_alpha=lora_alpha,
+            scale_grad_by_freq=scale_grad_by_freq,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            **kwargs
+        )
+
+return model


def make_lora_replace(
-model, method:str ,LayerType, quantize_fn=None, quantize_fn_=None, Rank=0, layers=None,
+model, method: str, LayerType, quantize_fn=None, quantize_fn_=None, Rank=0, layers=None,
depth=1, path="", verbose=True
):
"""
@@ -64,26 +78,25 @@ def make_lora_replace(
if verbose:
print(f"Found linear layer to quantize: {path}", type(model))
if quantize_fn is not None:
-return quantize_fn(
-    method,
-    model,
-    Rank,
-    threshold
-)
+return quantize_fn(
+    method=method,
+    model=model,
+    Rank=Rank
+)

if LayerType[1] in AdaptersLayer and isinstance(model, nn.Embedding) and any(item in path for item in layers):
if verbose:
print(f"Found embedding layer to quantize: {path}", type(model))
if quantize_fn_ is not None:
return quantize_fn_(
-    method,
-    model,
-    Rank,
-    lora_alpha,
-    scale_grad_by_freq,
-    padding_idx,
-    max_norm
-)
+    method=method,
+    model=model,
+    Rank=Rank,
+    lora_alpha=lora_alpha,
+    scale_grad_by_freq=scale_grad_by_freq,
+    padding_idx=padding_idx,
+    max_norm=max_norm
+)

for key, module in model.named_children():
if isinstance(module, (nn.Linear, nn.Embedding)) and any(item in path for item in layers):
@@ -94,14 +107,14 @@
elif isinstance(module, (nn.ModuleList, nn.ModuleDict)):
for i, elem in enumerate(module):
layer = make_lora_replace(
-elem, LayerType, quantize_fn, quantize_fn_, Rank, layers,
+elem, method, LayerType, quantize_fn, quantize_fn_, Rank, layers,
depth + 1, f"{path}:{key}[{i}]", verbose=verbose
)
if layer is not None:
module[i] = layer
else:
layer = make_lora_replace(
-module, LayerType, quantize_fn, quantize_fn_, Rank, layers,
+module, method, LayerType, quantize_fn, quantize_fn_, Rank, layers,
depth + 1, f"{path}:{key}", verbose=verbose
)
if layer is not None:
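As a side note, the pattern quantize_layer now follows can be summarized with a small self-contained sketch. This illustrates the kwargs-based dispatch only and is not the repository's code; the name dispatch_adapter and the toy adapter are invented for the example:

import torch.nn as nn

def dispatch_adapter(**kwargs):
    # Simplified stand-in for the dispatch in core/utils.py: inspect the module
    # type, pick the matching adapter callable from the keyword arguments,
    # and fall back to returning the module unchanged.
    model = kwargs["model"]
    if isinstance(model, nn.Linear) and kwargs.get("quantize_fn") is not None:
        return kwargs["quantize_fn"](model, kwargs.get("Rank", 4))
    if isinstance(model, nn.Embedding) and kwargs.get("quantize_fn_") is not None:
        return kwargs["quantize_fn_"](model, kwargs.get("Rank", 4))
    return model

# Toy adapter that just wraps the layer; a real one would inject LoRA weights.
wrapped = dispatch_adapter(
    model=nn.Linear(16, 8),
    quantize_fn=lambda layer, rank: nn.Sequential(layer),
    Rank=4,
)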
