Manually fix a few remaining lints (#238)
mergennachin authored Apr 17, 2024
1 parent 170dde6 commit 5450e3e
Showing 2 changed files with 26 additions and 18 deletions.
14 changes: 11 additions & 3 deletions generate.py
@@ -314,6 +314,9 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
 
+B_INST, E_INST = "[INST]", "[/INST]"
+
+
 def _main(
     builder_args: BuilderArgs,
     speculative_builder_args: BuilderArgs,
@@ -330,6 +333,7 @@ def _main(
     # from tp import maybe_init_dist
     # rank = maybe_init_dist()
     use_tp = False
+    rank: Optional[int] = None
     # if use_tp:
     #     if rank != 0:
     #         # only print on rank 0
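The added rank: Optional[int] = None pre-declares a name whose real assignment, maybe_init_dist(), is commented out, so later references to rank do not trip an undefined-name lint (flake8's F821). A minimal sketch of the pattern, with the initializer left hypothetical as in the hunk above:

from typing import Optional

use_tp = False
rank: Optional[int] = None  # declared up front; only assigned when TP is enabled
# if use_tp:
#     rank = maybe_init_dist()  # distributed init, commented out as in the diff

if rank is not None and rank != 0:
    pass  # e.g., suppress printing on non-zero ranks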
@@ -417,8 +421,9 @@ def _main(
         period_id = tokenizer.encode(".")[0]
         done_generating = False
 
-        def callback(x):
-            nonlocal done_generating
+        def callback(
+            x, buffer=buffer, period_id=period_id, done_generating=done_generating
+        ):
             if done_generating:
                 return
             buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
@@ -430,7 +435,10 @@ def callback(x):
             # print(, end='', flush=True)
 
         else:
-            callback = lambda x: x
+
+            def callback(x):
+                return x
+
         t0 = time.perf_counter()
         import contextlib
 
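Two lint fixes drive the callback hunks above: E731 (do not assign a lambda expression, use a def) and replacing nonlocal closure state with values bound through default arguments. A minimal standalone sketch, with illustrative names rather than torchchat's:

# E731 fix: a named def replaces "callback = lambda x: x".
def callback_identity(x):
    return x


# Default-argument binding: each value is snapshotted when the inner
# function is defined, rather than looked up in the enclosing scope on
# every call, so no nonlocal declaration is needed.
def make_callback():
    buffer = []
    done_generating = False

    def callback(x, buffer=buffer, done_generating=done_generating):
        if done_generating:
            return
        buffer.append(str(x))

    return callback


cb = make_callback()
cb("token")  # appends "token" to the snapshotted buffer

One subtlety: defaults freeze the binding at definition time, so a later done_generating = True in the enclosing scope would not be seen inside the callback. The original nonlocal suggests the callback also rebound that flag; after this change any such rebinding stays local to a single call.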
30 changes: 15 additions & 15 deletions quantized_ops.py
@@ -8,7 +8,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.library import impl, impl_abstract
+from torch.library import impl
 
 torchchat_lib = torch.library.Library("torchchat", "DEF")
 
@@ -25,21 +25,21 @@ def embedding_int8(
 ) -> torch.Tensor:
     indices = input
     # embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1)
+    # groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1)
     # ET definition
-    if False:
-        weight_zero_points = None
-        weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
-            weight,
-            weight_scales,
-            weight_zero_points,
-            weight_quant_min,
-            weight_quant_max,
-            weight.dtype,
-            groupsize,
-            weight_scales.dtype,
-        )
-        return torch.ops.aten.embedding.default(weight, indices)
+    # if False:
+    #     weight_zero_points = None
+    #     weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+    #         weight,
+    #         weight_scales,
+    #         weight_zero_points,
+    #         weight_quant_min,
+    #         weight_quant_max,
+    #         weight.dtype,
+    #         groupsize,
+    #         weight_scales.dtype,
+    #     )
+    # return torch.ops.aten.embedding.default(weight, indices)
 
     scales = scales.view(weight.shape[0], -1)
     result_weights = F.embedding(indices, weight)
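The quantized_ops.py hunks are two more lint-driven cleanups: dropping the unused impl_abstract import (flake8's F401, module imported but unused) and turning an unreachable if False: block, plus the groupsize value only that block consumed, into comments. A minimal standalone sketch of the pattern, with illustrative names rather than torchchat's real signature:

import torch
import torch.nn.functional as F
# from torch.library import impl_abstract  # removed: imported but never used (F401)


def embedding_lookup(indices: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # An unreachable "if False:" branch (the ExecuTorch-style dequantize path
    # in the original) is kept as comments instead of live-but-dead code, so
    # linters stop flagging it while the alternative stays readable.
    # if False:
    #     return torch.ops.aten.embedding.default(weight, indices)
    return F.embedding(indices, weight)


out = embedding_lookup(torch.tensor([0, 2]), torch.randn(4, 8))
print(out.shape)  # torch.Size([2, 8])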
