Commit

Fix llm example
Giuseppe5 committed Feb 6, 2024
1 parent 7551c9b commit 432bf3f
Showing 7 changed files with 104 additions and 340 deletions.
17 changes: 3 additions & 14 deletions src/brevitas_examples/llm/llm_quant/bias_corr.py
@@ -6,21 +6,10 @@
 import torch

 from brevitas.graph.calibrate import bias_correction_mode
-from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
-
-
-@torch.no_grad()
-def bias_corr_iter(curr_layer, inps, outs, cached_values):
-    curr_layer = curr_layer.cuda()
-    with bias_correction_mode(curr_layer):
-        for j in range(len(inps)):
-            inp = inps[j].unsqueeze(0).cuda()
-            curr_out = curr_layer(inp, **cached_values)[0]
-            outs[j] = curr_out
-    curr_layer.cpu()
-    return outs


 @torch.no_grad()
 def apply_bias_correction(model, dataloader):
-    apply_layer_ptq_fn(model, dataloader, inference_fn=bias_corr_iter)
+    with bias_correction_mode(model):
+        for inps in dataloader:
+            model(**inps)
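For reference, the simplified entry point can be exercised end to end on a toy module. The sketch below is illustrative and not part of the commit: the ToyLM class and random batches stand in for a Brevitas-quantized LLM and its tokenized calibration set.

import torch
from brevitas.nn import QuantLinear

from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction


class ToyLM(torch.nn.Module):
    # Accepts keyword batches the way a causal LM would (input_ids, attention_mask).

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(32, 16)
        self.proj = QuantLinear(16, 16, bias=True)

    def forward(self, input_ids, attention_mask=None):
        return self.proj(self.embed(input_ids))


# Batches are dicts of keyword arguments, matching the model(**inps) call above.
calib_loader = [
    {'input_ids': torch.randint(0, 32, (1, 8)), 'attention_mask': torch.ones(1, 8)}
    for _ in range(4)]
apply_bias_correction(ToyLM(), calib_loader)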
22 changes: 2 additions & 20 deletions src/brevitas_examples/llm/llm_quant/calibrate.py
@@ -7,28 +7,10 @@
 from tqdm import tqdm

 from brevitas.graph.calibrate import calibration_mode
-from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
-from brevitas_examples.optimum.utils import offload_model
-from brevitas_examples.optimum.utils import remove_hooks
-
-
-@torch.no_grad()
-def calibration_iter(curr_layer, inps, outs, cached_values):
-    curr_layer = curr_layer.cuda()
-    with calibration_mode(curr_layer):
-        for j in range(len(inps)):
-            inp = inps[j].unsqueeze(0).cuda()
-            curr_out = curr_layer(inp, **cached_values)[0]
-            outs[j] = curr_out
-    curr_layer.cpu()
-    return outs


 @torch.no_grad()
-def apply_calibration(model, dataloader, forward_call):
-    model = offload_model(model)
+def apply_calibration(model, dataloader):
     with calibration_mode(model):
         for inps in tqdm(dataloader):
-            forward_call(model, inps)
-    # Remove all accelerate hooks
-    remove_hooks(model)
+            model(**inps)
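The refactored helpers all iterate a dataloader of keyword-argument dicts and call the full model directly, so device placement (for example via offload_model) and batch preparation are now the caller's responsibility. A hedged sketch of building such batches with a Hugging Face tokenizer; the checkpoint name and texts are placeholders:

from transformers import AutoTokenizer

# Placeholder checkpoint; any causal-LM tokenizer yields the same dict layout.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
texts = [
    "Quantization trades precision for efficiency.",
    "Calibration estimates activation ranges."]
# Each entry carries 'input_ids' and 'attention_mask', so it unpacks as model(**inps).
calib_loader = [tokenizer(t, return_tensors='pt') for t in texts]

With a quantized model already placed on the target device, apply_calibration(model, calib_loader) then runs the loop shown in the diff.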
24 changes: 3 additions & 21 deletions src/brevitas_examples/llm/llm_quant/equalize.py
@@ -3,35 +3,19 @@
 # SPDX-License-Identifier: BSD-3-Clause
 """

-from accelerate.hooks import remove_hook_from_module
 import torch
 from tqdm import tqdm

 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import EqualizeGraph
-from brevitas_examples.optimum.utils import offload_model
-from brevitas_examples.optimum.utils import remove_hooks
-
-
-@torch.no_grad()
-def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
-    curr_layer = curr_layer.cuda()
-    with activation_equalization_mode(curr_layer, alpha, add_mul_node=True, layerwise=True):
-        for j in range(len(inps)):
-            inp = inps[j].unsqueeze(0).cuda()
-            curr_out = curr_layer(inp, **cached_values)[0]
-            outs[j] = curr_out
-    curr_layer.cpu()
-    return outs


 @torch.no_grad()
-def apply_act_equalization(model, act_equalization_type, dataloader, forward_call, alpha=0.5):
-    model = offload_model(model)
+def apply_act_equalization(model, act_equalization_type, dataloader, alpha=0.5):
     if act_equalization_type == 'layerwise':
         with activation_equalization_mode(model, alpha, add_mul_node=True, layerwise=True):
             for inps in tqdm(dataloader):
-                forward_call(model, inps)
+                model(**inps)

     elif act_equalization_type == 'fx':
         assert model is not None, "FX Model is required to perform FX SmoothQuant"
@@ -41,12 +25,10 @@ def apply_act_equalization(model, act_equalization_type, dataloader, forward_call, alpha=0.5):
                 layerwise=False,
                 co_optimize_act_weights=True):
             for inps in tqdm(dataloader):
-                forward_call(model, inps)
+                model(**inps)

     else:
         raise RuntimeError(f"{act_equalization_type} not supported.")
-    # Remove all accelerate hooks
-    remove_hooks(model)


 @torch.no_grad()
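As a rough illustration of the 'layerwise' path, which does not require an FX-traced model, the sketch below equalizes a small float network. The toy module and random batches are assumptions made for the example, not part of the commit; the call pattern matches the entry point above.

import torch
from torch import nn

from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization


class TwoLayerNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 16)

    def forward(self, inp, attention_mask=None):
        return self.fc2(torch.relu(self.fc1(inp)))


calib_loader = [{'inp': torch.randn(1, 8, 16)} for _ in range(4)]
# SmoothQuant-style scaling with the default alpha; 'fx' would need a traced model instead.
apply_act_equalization(TwoLayerNet(), 'layerwise', calib_loader, alpha=0.5)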
86 changes: 12 additions & 74 deletions src/brevitas_examples/llm/llm_quant/eval.py
@@ -20,91 +20,29 @@
 from torch import nn
 from tqdm import tqdm

-from brevitas_examples.llm.llm_quant.run_utils import apply_layer_inference_fn
-from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
-from brevitas_examples.llm.llm_quant.run_utils import InputCatcherException
-

-def eval_inference_fn(curr_layer, inps, outs, cached_values):
-    curr_layer.cuda()
-    for j in range(len(inps)):
-        outs[j] = curr_layer(inps[j].unsqueeze(0).cuda(), **cached_values)[0]
-    curr_layer.cpu()
-
-
-@torch.no_grad()
-def model_eval(model, valenc, seqlen):
-
-    nsamples = valenc.numel() // seqlen
-
-    def eval_input_capture_fn(model, data):
-        for i in range(nsamples):
-            batch = data[:, (i * seqlen):((i + 1) * seqlen)].cuda()
-            try:
-                model(batch)
-            except InputCatcherException:
-                pass
-
-    inps = apply_layer_inference_fn(
-        model,
-        valenc,
-        input_capture_fn=eval_input_capture_fn,
-        inference_fn=eval_inference_fn,
-    )
-
-    model_impl = get_model_impl(model)
-    use_cache = model.config.use_cache
-    model.config.use_cache = False
-
-    if hasattr(model_impl, 'norm') and model_impl.norm is not None:
-        model_impl.norm = model_impl.norm.cuda()
-    if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None:
-        model_impl.final_layer_norm = model_impl.final_layer_norm.cuda()
-    if hasattr(model_impl, 'project_out') and model_impl.project_out is not None:
-        model_impl.project_out = model_impl.project_out.cuda()
-    if hasattr(model, 'lm_head'):
-        model.lm_head = model.lm_head.cuda()
-
-    valenc = valenc.cuda()
-    nlls = []
+def create_validation_dataloader(data, seqlen):
+    nsamples = data['input_ids'].numel() // seqlen
+    val_dataloader = []
     for i in tqdm(range(nsamples)):
-        hidden_states = inps[i].unsqueeze(0)
-        if hasattr(model_impl, 'norm') and model_impl.norm is not None:
-            hidden_states = model_impl.norm(hidden_states)
-        if hasattr(model_impl, 'final_layer_norm') and model_impl.final_layer_norm is not None:
-            hidden_states = model_impl.final_layer_norm(hidden_states)
-        if hasattr(model_impl, 'project_out') and model_impl.project_out is not None:
-            hidden_states = model_impl.project_out(hidden_states)
-        lm_logits = hidden_states
-        if hasattr(model, 'lm_head'):
-            lm_logits = model.lm_head(lm_logits)
-        shift_logits = lm_logits[:, :-1, :].contiguous()
-        shift_labels = valenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:]
-        loss_fct = nn.CrossEntropyLoss()
-        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-        neg_log_likelihood = loss.float() * seqlen
-        nlls.append(neg_log_likelihood)
-
-    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
-    model.config.use_cache = use_cache
-    return ppl
+        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
+        attention_mask = torch.ones_like(batch)
+        val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
+    return val_dataloader


 @torch.no_grad()
-def model_eval_accelerate(model, valenc, seqlen, forward_call):
+def model_eval(model, valenc, seqlen):

-    nsamples = valenc['input_ids'].numel() // seqlen
+    nsamples = len(valenc)
     use_cache = model.config.use_cache
     model.config.use_cache = False
     with torch.no_grad():
         nlls = []
-        for i in tqdm(range(nsamples)):
-            batch = valenc['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
-            attention_mask = torch.ones_like(batch)
-            lm_logits = forward_call(model, {
-                'input_ids': batch, 'attention_mask': attention_mask})['logits']
+        for inps in valenc:
+            lm_logits = model(**inps)['logits']
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            shift_labels = (valenc['input_ids'][:, (i * seqlen):((i + 1) * seqlen)][:, 1:]).cuda()
+            shift_labels = inps['input_ids'][:, 1:].cuda()
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
             neg_log_likelihood = loss.float() * seqlen
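A hedged sketch of wiring the two new helpers together. The checkpoint, text, and sequence length are placeholders, a CUDA device is assumed because create_validation_dataloader moves each window to the GPU, and a real run would evaluate the quantized model rather than the float checkpoint.

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from brevitas_examples.llm.llm_quant.eval import create_validation_dataloader
from brevitas_examples.llm.llm_quant.eval import model_eval

checkpoint = 'facebook/opt-125m'  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).cuda()

# One long token stream, sliced into fixed-length windows by the helper.
text = "\n\n".join(["Brevitas quantizes neural networks."] * 2000)
valenc = tokenizer(text, return_tensors='pt')
val_loader = create_validation_dataloader(valenc, seqlen=128)
ppl = model_eval(model, val_loader, seqlen=128)  # perplexity over the windows
print(f"perplexity: {ppl.item():.2f}")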
29 changes: 2 additions & 27 deletions src/brevitas_examples/llm/llm_quant/gptq.py
@@ -3,37 +3,14 @@
 # SPDX-License-Identifier: BSD-3-Clause
 """

-from accelerate.hooks import remove_hook_from_module
 import torch
 from tqdm import tqdm

 from brevitas.graph.gptq import gptq_mode
-from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
-from brevitas_examples.optimum.utils import offload_model
-from brevitas_examples.optimum.utils import remove_hooks
-
-
-@torch.no_grad()
-def gptq_iter(curr_layer, inps, outs, cached_values, act_order):
-    curr_layer = curr_layer.cuda()
-    with gptq_mode(curr_layer, use_quant_activations=False, act_order=act_order) as gptq:
-        gptq_layer = gptq.model
-        for _ in range(gptq.num_layers):
-            for j in range(len(inps)):
-                curr_inp = inps[j].unsqueeze(0).cuda()
-                gptq_layer(curr_inp, **cached_values)
-            gptq.update()
-        for j in range(len(inps)):
-            inp = inps[j].unsqueeze(0).cuda()
-            curr_out = curr_layer(inp, **cached_values)[0]
-            outs[j] = curr_out
-    curr_layer.cpu()
-    return outs


 @torch.no_grad()
-def apply_gptq(model, dataloader, forward_call, act_order=True, group_of_parallel_layers=None):
-    model = offload_model(model)
+def apply_gptq(model, dataloader, act_order=True, group_of_parallel_layers=None):
     with gptq_mode(model,
                    use_quant_activations=False,
                    group_of_parallel_layers=group_of_parallel_layers,
@@ -42,7 +19,5 @@ def apply_gptq(model, dataloader, forward_call, act_order=True, group_of_parallel_layers=None):
         gptq_model = gptq.model
         for _ in tqdm(range(gptq.num_layers)):
             for inps in dataloader:
-                forward_call(gptq_model, inps)
+                gptq_model(**inps)
             gptq.update()
-    # Remove all accelerate hooks
-    remove_hooks(model)
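GPTQ follows the same calling pattern. The sketch below is illustrative only: the toy block and random batches stand in for the quantized LLM and its calibration dataloader, and group_of_parallel_layers is left at its default.

import torch
from brevitas.nn import QuantLinear

from brevitas_examples.llm.llm_quant.gptq import apply_gptq


class ToyBlock(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = QuantLinear(16, 32, bias=True)
        self.fc2 = QuantLinear(32, 16, bias=True)

    def forward(self, inp, attention_mask=None):
        return self.fc2(torch.relu(self.fc1(inp)))


calib_loader = [{'inp': torch.randn(1, 8, 16)} for _ in range(4)]
# apply_gptq replays the calibration batches internally, as in the loop shown in the diff.
apply_gptq(ToyBlock(), calib_loader, act_order=True)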
(Diffs for the remaining changed files are not shown.)
