adding support for GPT-J #1

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions tests/model_dec_scales.json
@@ -0,0 +1 @@
[{"attn_input_scale": 0.031619094488188976, "q_output_scale": 0.1687992125984252, "k_output_scale": 0.1347194881889764, "v_output_scale": 0.02297613188976378, "out_input_scale": 0.01796259842519685, "fc1_input_scale": 0.031619094488188976, "fc2_input_scale": 0.007831723671259843}, {"attn_input_scale": 0.011095903051181102, "q_output_scale": 0.14013287401574803, "k_output_scale": 0.14160925196850394, "v_output_scale": 0.046475147637795276, "out_input_scale": 0.03595595472440945, "fc1_input_scale": 0.011095903051181102, "fc2_input_scale": 0.00655142716535433}, {"attn_input_scale": 0.00863527312992126, "q_output_scale": 0.18553149606299213, "k_output_scale": 0.156373031496063, "v_output_scale": 0.03021961122047244, "out_input_scale": 0.030096579724409447, "fc1_input_scale": 0.00863527312992126, "fc2_input_scale": 0.029789000984251968}, {"attn_input_scale": 0.019192913385826772, "q_output_scale": 0.2233021653543307, "k_output_scale": 0.15563484251968504, "v_output_scale": 0.03804749015748032, "out_input_scale": 0.03168061023622047, "fc1_input_scale": 0.019192913385826772, "fc2_input_scale": 0.03337229330708662}, {"attn_input_scale": 0.01287217027559055, "q_output_scale": 0.13041338582677164, "k_output_scale": 0.1392716535433071, "v_output_scale": 0.062100147637795276, "out_input_scale": 0.05361097440944882, "fc1_input_scale": 0.01287217027559055, "fc2_input_scale": 0.002772053395669291}, {"attn_input_scale": 0.016901451771653545, "q_output_scale": 0.17691929133858267, "k_output_scale": 0.17704232283464566, "v_output_scale": 0.025298351377952756, "out_input_scale": 0.024913877952755906, "fc1_input_scale": 0.016901451771653545, "fc2_input_scale": 0.00285279281496063}, {"attn_input_scale": 0.016378567913385825, "q_output_scale": 0.13188976377952755, "k_output_scale": 0.15243602362204725, "v_output_scale": 0.02449864665354331, "out_input_scale": 0.020100270669291338, "fc1_input_scale": 0.016378567913385825, "fc2_input_scale": 0.0020415538877952754}, {"attn_input_scale": 0.014563853346456693, "q_output_scale": 0.15526574803149606, "k_output_scale": 0.1625246062992126, "v_output_scale": 0.02995816929133858, "out_input_scale": 0.02109990157480315, "fc1_input_scale": 0.014563853346456693, "fc2_input_scale": 0.002793199434055118}, {"attn_input_scale": 0.016701525590551183, "q_output_scale": 0.15255905511811024, "k_output_scale": 0.18061023622047245, "v_output_scale": 0.021345964566929134, "out_input_scale": 0.01842396653543307, "fc1_input_scale": 0.016701525590551183, "fc2_input_scale": 0.00299312561515748}, {"attn_input_scale": 0.017685777559055118, "q_output_scale": 0.16289370078740156, "k_output_scale": 0.18393208661417323, "v_output_scale": 0.02875861220472441, "out_input_scale": 0.026113435039370077, "fc1_input_scale": 0.017685777559055118, "fc2_input_scale": 0.0021876537893700788}, {"attn_input_scale": 0.01819328248031496, "q_output_scale": 0.1875, "k_output_scale": 0.17285925196850394, "v_output_scale": 0.03186515748031496, "out_input_scale": 0.0296505905511811, "fc1_input_scale": 0.01819328248031496, "fc2_input_scale": 0.001685915969488189}, {"attn_input_scale": 0.014271653543307087, "q_output_scale": 0.14480807086614172, "k_output_scale": 0.16510826771653545, "v_output_scale": 0.023622047244094488, "out_input_scale": 0.01714751476377953, "fc1_input_scale": 0.014271653543307087, "fc2_input_scale": 0.0016195943036417322}, {"attn_input_scale": 0.01624015748031496, "q_output_scale": 0.1733513779527559, "k_output_scale": 0.18713090551181102, "v_output_scale": 0.04856668307086614, "out_input_scale": 
0.029389148622047244, "fc1_input_scale": 0.01624015748031496, "fc2_input_scale": 0.0015542338213582678}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1546505905511811, "k_output_scale": 0.18639271653543307, "v_output_scale": 0.03380290354330709, "out_input_scale": 0.03257258858267716, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.002921998031496063}, {"attn_input_scale": 0.014686884842519685, "q_output_scale": 0.16203248031496062, "k_output_scale": 0.1969734251968504, "v_output_scale": 0.03071173720472441, "out_input_scale": 0.02066929133858268, "fc1_input_scale": 0.014686884842519685, "fc2_input_scale": 0.0026105745570866143}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1592027559055118, "k_output_scale": 0.18011811023622049, "v_output_scale": 0.028420275590551183, "out_input_scale": 0.014148622047244094, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.005417230561023622}, {"attn_input_scale": 0.017854945866141732, "q_output_scale": 0.17568897637795275, "k_output_scale": 0.19672736220472442, "v_output_scale": 0.023452878937007874, "out_input_scale": 0.02251476377952756, "fc1_input_scale": 0.017854945866141732, "fc2_input_scale": 0.0013398898868110236}, {"attn_input_scale": 0.015286663385826772, "q_output_scale": 0.1671998031496063, "k_output_scale": 0.14271653543307086, "v_output_scale": 0.019239050196850394, "out_input_scale": 0.017593503937007874, "fc1_input_scale": 0.015286663385826772, "fc2_input_scale": 0.0022145669291338582}, {"attn_input_scale": 0.016070989173228346, "q_output_scale": 0.15514271653543307, "k_output_scale": 0.15231299212598426, "v_output_scale": 0.019408218503937008, "out_input_scale": 0.016424704724409447, "fc1_input_scale": 0.016070989173228346, "fc2_input_scale": 0.006243848425196851}, {"attn_input_scale": 0.017009104330708662, "q_output_scale": 0.1422244094488189, "k_output_scale": 0.16117125984251968, "v_output_scale": 0.025221456692913386, "out_input_scale": 0.019500492125984252, "fc1_input_scale": 0.017009104330708662, "fc2_input_scale": 0.01803949311023622}]
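For orientation: this JSON holds one scale dict per decoder layer (20 entries here), and each value is an activation scale of the form max(|activation|) / 127, collected by tests/test_gptj_model.py below. A minimal sketch of loading and sanity-checking the file, assuming only the key set visible above:

import json

EXPECTED_KEYS = {
    "attn_input_scale", "q_output_scale", "k_output_scale", "v_output_scale",
    "out_input_scale", "fc1_input_scale", "fc2_input_scale",
}

with open("tests/model_dec_scales.json") as fp:
    decoder_layer_scales = json.load(fp)

for idx, layer_scales in enumerate(decoder_layer_scales):
    assert set(layer_scales) == EXPECTED_KEYS, f"layer {idx}: unexpected keys"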
97 changes: 97 additions & 0 deletions tests/test_gptj.py
@@ -0,0 +1,97 @@
import torch
from torch_int.models.gptj import Int8GPTJForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from transformers import AutoTokenizer
from icecream import ic
from datasets import load_dataset
from tqdm import tqdm
import json

class Evaluator:
def __init__(self, dataset, tokenizer, device):
self.dataset = dataset
self.tokenizer = tokenizer
self.device = device

# tokenize the dataset
def tokenize_function(examples):
example = self.tokenizer(examples['text'])
return example
self.dataset = self.dataset.map(tokenize_function, batched=True)
self.dataset.set_format(type='torch', columns=['input_ids'])

@torch.no_grad()
def evaluate2(self, model):
model.eval()
# The task is to predict the last token of the input.
total, hit = 0, 0
idx = 0
pbar = tqdm(self.dataset, desc='Evaluating')
for batch in pbar:
input_ids = batch['input_ids'].to(self.device).unsqueeze(0)
label = input_ids[:, -1]
outputs = model(input_ids.cuda())
idx += 1
last_token_logits = outputs.logits[:, -2, :]
pred = last_token_logits.argmax(dim=-1)
total += label.size(0)
hit += (pred == label).sum().item()
pbar.set_postfix({'acc': hit / total})
acc = hit / total
return acc

@torch.no_grad()
def evaluate(self, modelX, model):
model.eval()
# The task is to predict the last token of the input.
idx = 0
total, hit = 0, 0
hit2 = 0
pbar = tqdm(self.dataset, desc='Evaluating')
for batch in pbar:
input_ids = batch['input_ids'].to(self.device).unsqueeze(0)
label = input_ids[:, -1]
outputs = model(input_ids.to('cuda'))
outputs2 = modelX(input_ids.to('cuda'))
            # `d` appears to be a per-forward debug buffer on the patched
            # GPTJModel; clear it between batches so it does not grow
            model.transformer.d.clear()
            modelX.transformer.d.clear()
idx += 1
            last_token_logits = outputs.logits[:, -2, :]
            last_token_logits2 = outputs2.logits[:, -2, :]
            pred = last_token_logits.argmax(dim=-1)
            pred2 = last_token_logits2.argmax(dim=-1)
            total += label.size(0)
            hit += (pred == label).sum().item()
            hit2 += (pred2 == label).sum().item()
pbar.set_postfix({'acc': hit / total, 'accX': hit2 / total})
acc = hit / total
return acc

@torch.no_grad()
def test_gptj():
    dataset = load_dataset('lambada', split='validation[:1000]')
    dataset = dataset.shuffle(seed=42)
    checkpoint = "moyix/codegen-350M-multi-gptj"
    # checkpoint = "Salesforce/codegen-350M-multi"
    model = GPTJForCausalLM.from_pretrained(checkpoint, device_map='auto', torch_dtype='auto').cuda()
    tokenizer = AutoTokenizer.from_pretrained('Salesforce/codegen-350M-multi')
evaluator = Evaluator(dataset, tokenizer, 'cuda')
    scales_path = "./tests/model_dec_scales.json"
    with open(scales_path, 'r') as fp:
        decoder_layer_scales = json.load(fp)
    # layers listed in `k` are kept in floating point (not quantized)
    layers_to_keep = list(range(13))
    int8_model = Int8GPTJForCausalLM.from_float(model, decoder_layer_scales, k=layers_to_keep)
    acc = evaluator.evaluate2(int8_model.to('cuda'))
    ic(acc)


if __name__ == '__main__':
    test_gptj()
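A note on the indexing used by both evaluators: with teacher forcing, the logits at position i score the token at position i + 1, so the prediction for the final (label) token is read from index -2. A tiny standalone illustration with dummy tensors:

import torch

torch.manual_seed(0)
vocab, seq = 50, 8
logits = torch.randn(1, seq, vocab)       # stand-in for outputs.logits
input_ids = torch.randint(0, vocab, (1, seq))

label = input_ids[:, -1]                  # the token to be predicted
pred = logits[:, -2, :].argmax(dim=-1)    # logits at -2 score the token at -1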
57 changes: 57 additions & 0 deletions tests/test_gptj_attention.py
@@ -0,0 +1,57 @@
import torch
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJConfig
from torch_int.models.gptj import Int8GPTJAttention
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU
from icecream import ic
from functools import partial

def store_act(module, x, y, act_dict, name):
# print(f"{name}: {y.mean()}")
if isinstance(x, tuple):
x = x[0]
if isinstance(y, tuple):
y = y[0]
act_dict[name] = (x, y)


@torch.no_grad()
def test_gptj_attention():
B, L, D, H = 1, 32, 128, 1
x = torch.randn(B, L, D)
x_scale = x.abs().max() / 127
config = GPTJConfig()
config.n_embd = D
config.n_head = H
config.rotary_dim = None
attn = GPTJAttention(config)
attn.eval()
act_dict = {}
for name, module in attn.named_modules():
if isinstance(module, torch.nn.Linear):
module.register_forward_hook(
partial(store_act, act_dict=act_dict, name=name))
y = attn(x)
y = y[0]

q_output_scale = act_dict['q_proj'][1].abs().max() / 127
k_output_scale = act_dict['k_proj'][1].abs().max() / 127
v_output_scale = act_dict['v_proj'][1].abs().max() / 127
out_input_scale = act_dict['out_proj'][0].abs().max() / 127
int8_attn = Int8GPTJAttention.from_float(
attn, x_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale).cuda()
int8_attn.eval()
q_act_dict = {}
for name, module in int8_attn.named_modules():
if isinstance(module, (W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU)):
module.register_forward_hook(
partial(store_act, act_dict=q_act_dict, name=name))
q_x = (x / x_scale).round().to(torch.int8)
y_hat = int8_attn(q_x.cuda())[0].cpu()

r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean()
ic(r2)


if __name__ == '__main__':
test_gptj_attention()
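The r2 value printed by each of these tests is a relative mean squared error, E[(y - y_hat)^2] / E[y^2]; values much smaller than 1 mean the int8 path tracks the float path closely. A standalone version of the metric, for reference:

import torch

def relative_mse(y: torch.Tensor, y_hat: torch.Tensor) -> torch.Tensor:
    # squared error of the quantized output, normalized by the signal power
    return (y - y_hat).pow(2).mean() / y.pow(2).mean()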
81 changes: 81 additions & 0 deletions tests/test_gptj_block.py
@@ -0,0 +1,81 @@
import torch
from transformers.models.gptj.modeling_gptj import GPTJBlock, GPTJConfig
from torch_int.models.gptj import Int8GPTJBlock
from icecream import ic
from functools import partial

def store_act(module, x, y, act_dict, name):
# print(f"{name}: {y.mean()}")
if isinstance(x, tuple):
x = x[0]
if isinstance(y, tuple):
y = y[0]
act_dict[name] = (x, y)


@torch.no_grad()
def test_gptj_block():
    config: GPTJConfig = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono')
B, L, D, H = 1, 256, config.n_embd, config.n_head
x = torch.randn(B, L, D)
blk = GPTJBlock(config)
blk.eval()
act_dict = {}
for name, module in blk.named_modules():
if isinstance(module, torch.nn.Linear):
module.register_forward_hook(
partial(store_act, act_dict=act_dict, name=name))
if isinstance(module, torch.nn.LayerNorm):
module.register_forward_hook(
partial(store_act, act_dict=act_dict, name=name))

y = blk(x)
y = y[0].cpu()
attn_input_scale = act_dict['attn.q_proj'][0].abs().max() / 127
q_output_scale = act_dict['attn.q_proj'][1].abs().max() / 127
k_output_scale = act_dict['attn.k_proj'][1].abs().max() / 127
v_output_scale = act_dict['attn.v_proj'][1].abs().max() / 127
out_input_scale = act_dict['attn.out_proj'][0].abs().max() / 127
fc1_input_scale = act_dict['mlp.fc_in'][0].abs().max() / 127
fc2_input_scale = act_dict['mlp.fc_out'][0].abs().max() / 127
int8_blk = Int8GPTJBlock.from_float(
blk, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale, fc1_input_scale, fc2_input_scale).cuda()
int8_blk.eval()

y_hat = int8_blk(x.cuda())[0].cpu()

r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean()
ic(r2)


if __name__ == '__main__':
test_gptj_block()
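The from_float conversions exercised here also need weight scales, not just the activation scales collected by the hooks; symmetric per-tensor weight quantization is the usual SmoothQuant-style recipe. A sketch of that standard step (the generic technique, not necessarily this PR's exact from_float internals):

import torch

def quantize_weight_per_tensor(w: torch.Tensor):
    scale = w.abs().max() / 127
    w_q = (w / scale).round().clamp(-128, 127).to(torch.int8)
    return w_q, scale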
54 changes: 54 additions & 0 deletions tests/test_gptj_mlp.py
@@ -0,0 +1,54 @@
import torch
from transformers.models.gptj.modeling_gptj import GPTJMLP, GPTJConfig
from torch_int.models.gptj import Int8GPTJMLP
from icecream import ic
from functools import partial

def store_act(module, x, y, act_dict, name):
# print(f"{name}: {y.mean()}")
if isinstance(x, tuple):
x = x[0]
if isinstance(y, tuple):
y = y[0]
act_dict[name] = (x, y)


@torch.no_grad()
def test_gptj_mlp():
B, L, D, H = 1, 16, 32, 1
    x = torch.randn(B, L, D) * 40      # wide spread so the clamp below saturates
    x = torch.clamp(x, -127, 127)      # max |x| is now (almost surely) exactly 127,
    x_scale = x.abs().max() / 127      # so x_scale == 1 and rounding quantizes x directly
config = GPTJConfig()
config.n_embd = D
config.n_head = H
intermediate_size = 4*D
config.rotary_dim = None
mlp = GPTJMLP(intermediate_size, config)
mlp.eval()
act_dict = {}
for name, module in mlp.named_modules():
if isinstance(module, torch.nn.Linear):
module.register_forward_hook(
partial(store_act, act_dict=act_dict, name=name))
    y = mlp(x)  # GPTJMLP returns a plain tensor, so no tuple indexing here

fc_in_scale = act_dict['fc_in'][0].abs().max() / 127
fc_out_scale = act_dict['fc_out'][0].abs().max() / 127
int8_mlp = Int8GPTJMLP.from_float(
mlp, fc_in_scale, fc_out_scale).cuda()
int8_mlp.eval()
    q_x = (x / x_scale).round().to(torch.int8)  # x_scale is 1 here, but divide for generality
y_hat = int8_mlp(q_x.cuda()).cpu()
r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean()
ic(r2)


if __name__ == '__main__':
test_gptj_mlp()
61 changes: 61 additions & 0 deletions tests/test_gptj_model.py
@@ -0,0 +1,61 @@
import torch
from transformers.models.gptj.modeling_gptj import GPTJModel, GPTJConfig
from torch_int.models.gptj import Int8GPTJModel
from icecream import ic
from functools import partial


def store_act(module, x, y, act_dict, name):
if isinstance(x, tuple):
x = x[0]
if isinstance(y, tuple):
y = y[0]
act_dict[name] = (x, y)


@torch.no_grad()
def test_gptj_model_layer():
config = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono')

B, L, D, H = 1, 256, config.n_embd, config.n_head

x = torch.randint(0, config.vocab_size, (B, L))
model = GPTJModel(config)
model.eval()
act_dict = {}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.register_forward_hook(
partial(store_act, act_dict=act_dict, name=name))
y = model(x)[0].cuda()
    decoder_layer_scales = []
    for idx in range(config.n_layer):
        scale_dict = {}
        scale_dict["attn_input_scale"] = act_dict[f"h.{idx}.attn.q_proj"][0].abs().max() / 127
        scale_dict["q_output_scale"] = act_dict[f"h.{idx}.attn.q_proj"][1].abs().max() / 127
        scale_dict["k_output_scale"] = act_dict[f"h.{idx}.attn.k_proj"][1].abs().max() / 127
        scale_dict["v_output_scale"] = act_dict[f"h.{idx}.attn.v_proj"][1].abs().max() / 127
        scale_dict["out_input_scale"] = act_dict[f"h.{idx}.attn.out_proj"][0].abs().max() / 127
        scale_dict["fc1_input_scale"] = act_dict[f"h.{idx}.mlp.fc_in"][0].abs().max() / 127
        scale_dict["fc2_input_scale"] = act_dict[f"h.{idx}.mlp.fc_out"][0].abs().max() / 127
        decoder_layer_scales.append(scale_dict)

int8_model = Int8GPTJModel.from_float(model, decoder_layer_scales).cuda()
int8_model.eval()

y_hat = int8_model(x.cuda())[0]

r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean()
ic(r2)


if __name__ == '__main__':
test_gptj_model_layer()
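To persist the collected scales as tests/model_dec_scales.json (consumed by tests/test_gptj.py above), the 0-dim tensors have to be converted to plain floats first, since json cannot serialize tensors. A sketch, continuing from the decoder_layer_scales list built in the test:

import json

serializable = [{k: float(v) for k, v in d.items()} for d in decoder_layer_scales]
with open("tests/model_dec_scales.json", "w") as fp:
    json.dump(serializable, fp)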
21 changes: 20 additions & 1 deletion tests/test_linear_kernels.py
@@ -1,5 +1,5 @@
import torch
-from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32
+from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32, linear_gelu_a8_w8_b8_o8
from icecream import ic


@@ -85,6 +85,23 @@ def test_quant_linear_relu_a8_w8_b8_o8():
ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1))


@torch.no_grad()
def test_quant_linear_gelu_a8_w8_b8_o8():
B, M, N = 128, 512, 1024
weight = torch.randint(-128, 127, (N, M), dtype=torch.int8)
bias = torch.randint(-128, 127, (N,), dtype=torch.int8)
x = torch.randint(-128, 127, (B, M), dtype=torch.int8)
alpha, beta = 0.001, 0.01
linear = torch.nn.Linear(M, N, bias=True)
linear.weight.data = weight.float() * alpha
linear.bias.data = bias.float() * beta
    y_gt = linear(x.float())
    # apply GELU to the reference output; this kernel fuses GELU, so the
    # ReLU-style clamp(0, 127) copied from the ReLU test would be wrong here
    y_gt = torch.nn.functional.gelu(y_gt).round().clamp(-128, 127).long()
y = linear_gelu_a8_w8_b8_o8(x.cuda(), weight.cuda(),
bias.cuda(), alpha, beta).cpu().long()
ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1))


if __name__ == '__main__':
print('test_quant_linear_a8_w8_b32_o32')
test_quant_linear_a8_w8_b32_o32()
@@ -96,3 +113,5 @@ def test_quant_linear_relu_a8_w8_b8_o8():
test_quant_linear_a8_w8_b8_o8()
print('test_quant_linear_relu_a8_w8_b8_o8')
test_quant_linear_relu_a8_w8_b8_o8()
print('test_quant_linear_gelu_a8_w8_b8_o8')
test_quant_linear_gelu_a8_w8_b8_o8()
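All of the fused kernels in this file are checked against the same reference pattern: an fp32 GEMM with per-tensor scales alpha and beta, an optional activation, then saturation back to int8. A generic sketch of that reference (the function name and exact saturation bounds are assumptions of this sketch, not a documented contract of the kernels):

import torch

def fused_int8_linear_ref(x, weight, bias, alpha, beta, act=None):
    # x: (B, M) int8, weight: (N, M) int8, bias: (N,) int8
    y = alpha * (x.float() @ weight.float().t()) + beta * bias.float()
    if act is not None:
        y = act(y)  # e.g. torch.relu or torch.nn.functional.gelu
    return y.round().clamp(-128, 127).to(torch.int8)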