Merge branch 'main' into iree-packages-rename
ScottTodd authored Nov 8, 2024
2 parents ff68837 + 661a800 commit fb7fd3f
Showing 14 changed files with 205 additions and 138 deletions.

sharktank/sharktank/evaluate/perplexity_vmfb.py (40 additions, 31 deletions)

@@ -73,8 +73,8 @@ def __init__(
         self.iree_hip_target = iree_hip_target
         self.iree_hal_target_backends = iree_hal_target_backends
         self.kv_cache_type = kv_cache_type
-        self.activation_dtype = torch.float32
-        self.attention_dtype = torch.float32
+        self.activation_dtype = torch.float16
+        self.attention_dtype = torch.float16
         self.tensor_parallelism_size = tensor_parallelism_size
         self.attention_kernel = attention_kernel

@@ -166,6 +166,8 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
             external_weight_path=self.weight_path_str,
         )
 
+        self.haldevice = self.runner.config.device
+
     @timeit
     def get_prompts(self):
         test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[

@@ -189,40 +191,19 @@ def get_prompts(self):
 
     def prefill_vmfb(self, token_batch, i):
 
-        logger.debug(f"Prefill:")
-
-        logger.debug("Input:")
-        logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
-
-        token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
-            token_ids=token_batch.tolist(),
-            pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
-        )
-
-        logger.debug(f"{token_batch}")
-
-        token_batch = torch.tensor(token_batch, device=self.torch_device)
-        self.seq_lens_batch = torch.tensor(seq_lens_batch, device=self.torch_device)
-
-        self.batch = self.generator.begin_eval_batch(
-            token_batch=token_batch,
-            seq_lens_batch=self.seq_lens_batch,
-            bs=self.bs,
-        )
-
         seq_block_ids = self.batch.pad_block_ids()
         prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.cache_state,
         )
 
         prefill_logits = torch.tensor(prefill_logits[:, :, :])
 
         tokens = torch.tensor(
             self.generator.model.extract_tokens_from_logits(
-                prefill_logits, seq_lens_batch
+                prefill_logits, self.batch.seq_lens
             )
         ).unsqueeze(1)
         self.batch.add_result_token(tokens)

Expand All @@ -237,17 +218,17 @@ def decode_vmfb(self, token_batch, i):
logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
logger.debug(f"{token_batch.tolist()}")

start_positions = self.seq_lens_batch.clone()
self.seq_lens_batch.add_(1)
start_positions = self.batch.seq_lens.clone()
self.batch.seq_lens.add_(1)
self.batch.allocate_seq_block_ids()
seq_block_ids = self.batch.pad_block_ids()

decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
token_batch,
self.seq_lens_batch,
self.batch.seq_lens,
start_positions,
seq_block_ids,
self.batch.cache_state[0].to(torch.float16),
self.cache_state,
)

decode_logits = torch.tensor(decode_logits[:, :, :])
@@ -287,6 +268,7 @@ def get_logits(self):
         start = 0
         for i in tqdm(
             range(start, self.max_prompt_length - 1),
+            mininterval=300,
             desc="eval: Calculating logits",
         ):
             logger.debug(f"Iteration: {i}")

@@ -295,8 +277,35 @@ def get_logits(self):
 
             token_batch = self.token_ids[:, : i + 1]
 
+            logger.debug(f"Prefill:")
+
+            logger.debug("Input:")
+            logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+            token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+                token_ids=token_batch.tolist(),
+                pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+            )
+
+            logger.debug(f"{token_batch}")
+
+            token_batch = torch.tensor(token_batch, device=self.torch_device)
+            self.seq_lens_batch = torch.tensor(
+                seq_lens_batch, device=self.torch_device
+            )
+
+            self.batch = self.generator.begin_eval_batch(
+                token_batch=token_batch,
+                seq_lens_batch=self.seq_lens_batch,
+                bs=self.bs,
+            )
+
+            self.cache_state = ireert.asdevicearray(
+                self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
+            )
+
             prefill_logits = self.prefill_vmfb(token_batch, i)
-            self.out_logits = prefill_logits[:, 0:1, :]
+            self.out_logits = prefill_logits[:, -1:, :]
 
             is_first_token = False
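
Taken together, these changes do the padding, batch setup, and cache upload once per iteration of get_logits, then hand prefill and decode a single IREE HAL device array instead of converting the torch cache on every call. A minimal sketch of that host-to-device pattern, using only the iree.runtime call that appears in the diff (the helper name is illustrative, not from this repo):

import iree.runtime as ireert
import torch


# Illustrative helper: copy the torch KV-cache tensor to the HAL device once,
# then reuse the resulting device array across prefill/decode invocations.
def cache_to_device(haldevice, cache_state: torch.Tensor):
    # asdevicearray allocates a device buffer and copies the host data into it.
    return ireert.asdevicearray(haldevice, cache_state.to("cpu").numpy())

Prefill and decode then pass this array straight through as their cache argument, replacing the per-call self.batch.cache_state[0].to(torch.float16) conversion the old code performed.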

sharktank/sharktank/ops/default_impls.py (10 additions, 1 deletion)

@@ -355,7 +355,16 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
     rhs = unbox_tensor(rhs)
     if transpose_rhs:
         rhs = rhs.mT
-    return torch.matmul(lhs, rhs.to(lhs.dtype))
+
+    rhs = rhs.to(lhs.dtype)
+
+    if len(lhs.shape) > 2 and len(rhs.shape) < 3:
+        bdims = lhs.shape[:-1]
+        lhs = torch.flatten(lhs, 0, -2)
+        mm = torch.matmul(lhs, rhs)
+        return torch.unflatten(mm, 0, bdims)
 
+    return torch.matmul(lhs, rhs)
 
 
 # Scaled dot product attention
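
The new matmul_default body adds implicit batching: when lhs carries extra leading (batch) dimensions but rhs is a plain 2-D matrix, all leading lhs dimensions are flattened into one, a single 2-D matmul runs, and the result is unflattened back. A standalone sketch of the same reshaping trick, with the function name and example shapes chosen here for illustration:

import torch


def matmul_implicit_batch(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Match the rhs dtype to the lhs, as the changed code does.
    rhs = rhs.to(lhs.dtype)
    if lhs.dim() > 2 and rhs.dim() < 3:
        bdims = lhs.shape[:-1]                # every lhs dim except the reduction dim k
        flat = torch.flatten(lhs, 0, -2)      # [prod(bdims), k]
        mm = torch.matmul(flat, rhs)          # one 2-D x 2-D matmul -> [prod(bdims), n]
        return torch.unflatten(mm, 0, bdims)  # restore [*bdims, n]
    return torch.matmul(lhs, rhs)


lhs = torch.rand(4, 32, 16)                    # batch=4, m=32, k=16
rhs = torch.rand(16, 48, dtype=torch.float16)  # k=16, n=48
assert matmul_implicit_batch(lhs, rhs).shape == (4, 32, 48)

Collapsing to one 2-D matmul keeps the op on a simple torch.matmul path, and the dtype cast lets a float16 rhs multiply a float32 lhs, which is the mixed-dtype case the new ops test below exercises.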

sharktank/sharktank/utils/load_llm.py (1 addition, 1 deletion)

@@ -31,9 +31,9 @@ def __init__(
         self.tokenizer = tokenizer
         if model.cache.is_paged:
             self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+            self.free_pages = list(range(1, page_cache_size))
         else:
             self.shared_cache_state = None
-        self.free_pages = list(range(1, 8192))
         self.end_token = end_token
 
     @property
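
The load_llm.py fix sizes the free-page list from page_cache_size instead of a hard-coded 8192, so the generator can never hand out page indices beyond what the paged KV cache actually allocated (the range starting at 1 suggests page 0 is reserved). A toy sketch of that free-list discipline; the class and method names here are illustrative, not the repo's:

class PagePool:
    def __init__(self, page_cache_size: int):
        # Only pages the paged KV cache actually owns, matching the diff.
        self.free_pages = list(range(1, page_cache_size))

    def acquire(self, count: int) -> list[int]:
        if count > len(self.free_pages):
            raise RuntimeError("paged KV cache exhausted")
        pages, self.free_pages = self.free_pages[:count], self.free_pages[count:]
        return pages

    def release(self, pages: list[int]) -> None:
        self.free_pages.extend(pages)


pool = PagePool(page_cache_size=256)
first = pool.acquire(4)  # [1, 2, 3, 4]
pool.release(first)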

sharktank/tests/evaluate/baseline_perplexity_scores.json (101 additions, 101 deletions)

@@ -212,107 +212,107 @@
     },
     "llama3_8B_f16_decomposed_vmfb": {
         "perplexities": [
-            21194.505859,
-            19049.068359,
-            14214.751953,
-            15752.748047,
-            8948.568359,
-            9867.280273,
-            16664.880859,
-            10607.53125,
-            9715.395508,
-            14289.220703,
-            25121.929688,
-            8545.292969,
-            21990.28125,
-            8150.422363,
-            4658.82666,
-            13440.376953,
-            11978.756836,
-            9100.139648,
-            7168.022949,
-            14279.970703,
-            19406.207031,
-            13816.291016,
-            14942.27832,
-            20922.1875,
-            17307.214844,
-            10634.068359,
-            10968.188477,
-            11322.012695,
-            7898.733887,
-            7532.914062,
-            10352.375,
-            16628.289062,
-            5661.084473,
-            6998.464355,
-            7167.906738,
-            7252.662598,
-            7832.401367,
-            5824.921875,
-            12029.311523,
-            13104.125,
-            6688.567871,
-            7917.172852,
-            13455.291992,
-            7466.178223,
-            8360.422852,
-            5765.317383,
-            21530.652344,
-            13371.045898,
-            41826.242188,
-            13620.586914,
-            13886.725586,
-            13105.150391,
-            27155.019531,
-            8066.837402,
-            6860.444824,
-            9858.532227,
-            7352.963867,
-            15839.926758,
-            4746.95459,
-            8539.133789,
-            12957.833008,
-            10096.874023,
-            6436.333496,
-            6488.447754,
-            12649.62793,
-            9575.267578,
-            2897.279785,
-            12649.941406,
-            14139.443359,
-            12061.751953,
-            10646.621094,
-            15703.19043,
-            13080.764648,
-            9124.349609,
-            14409.989258,
-            10726.665039,
-            6444.680664,
-            10168.352539,
-            5474.356934,
-            10729.345703,
-            4240.486328,
-            11856.861328,
-            6184.834473,
-            16671.128906,
-            9840.30957,
-            39691.976562,
-            21551.833984,
-            6072.709961,
-            18333.572266,
-            6635.820801,
-            8460.941406,
-            14243.955078,
-            34157.90625,
-            9565.474609,
-            5573.206055,
-            9139.364258,
-            6077.837402,
-            13941.31543,
-            10590.963867,
-            12113.441406
+            6.651368,
+            22.059452,
+            15.392176,
+            17.418619,
+            15.206824,
+            7.907998,
+            8.829535,
+            22.355659,
+            8.29262,
+            20.958277,
+            7.167404,
+            14.592677,
+            9.060788,
+            7.274667,
+            16.238981,
+            6.666115,
+            6.535679,
+            7.086256,
+            10.676177,
+            8.979206,
+            10.597121,
+            42.038162,
+            11.70071,
+            65.731316,
+            47.42622,
+            20.109543,
+            18.897541,
+            13.781085,
+            9.99165,
+            5.955308,
+            10.175659,
+            23.628405,
+            14.306578,
+            9.719462,
+            5.594786,
+            14.198979,
+            5.711433,
+            17.381332,
+            9.058512,
+            8.286205,
+            8.016202,
+            18.4515,
+            11.600831,
+            3.945074,
+            13.000222,
+            10.373363,
+            12.237907,
+            21.408463,
+            37.858665,
+            25.794065,
+            15.489001,
+            14.004895,
+            7.625473,
+            10.993184,
+            14.698832,
+            11.062652,
+            5.855446,
+            15.625135,
+            8.052419,
+            14.365479,
+            5.927001,
+            6.931933,
+            2.3014,
+            15.769623,
+            40.843319,
+            8.022024,
+            12.544907,
+            10.090073,
+            9.304819,
+            10.679907,
+            8.136175,
+            21.540607,
+            3.736973,
+            15.381804,
+            24.21562,
+            14.385005,
+            17.791706,
+            16.498833,
+            8.753955,
+            12.941816,
+            12.887664,
+            13.725715,
+            13.994792,
+            10.769128,
+            14.734674,
+            26.970015,
+            17.811842,
+            9.847188,
+            15.124973,
+            15.623392,
+            29.147844,
+            12.309229,
+            32.15152,
+            33.225769,
+            14.426914,
+            17.496277,
+            14.7356,
+            15.503921,
+            12.336852,
+            16.469248
         ],
-        "mean_perplexity": 12191.57833
+        "mean_perplexity": 14.991893
     }
 }
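
The llama3_8B_f16_decomposed_vmfb baseline drops from a mean perplexity of 12191.58 to 14.99, consistent with the fp16 dtype and device-array fixes above. Assuming mean_perplexity is just the arithmetic mean of the per-prompt list (an assumption about this file's convention, not something the diff states), it can be re-derived when regenerating baselines:

import json
from statistics import mean

with open("sharktank/tests/evaluate/baseline_perplexity_scores.json") as f:
    scores = json.load(f)["llama3_8B_f16_decomposed_vmfb"]

# Recompute the summary statistic from the 100 per-prompt perplexities.
print(mean(scores["perplexities"]), scores["mean_perplexity"])  # should agree closely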

sharktank/tests/evaluate/perplexity_vmfb_test.py (1 addition, 1 deletion)

@@ -22,7 +22,7 @@
 class PerplexityTest(unittest.TestCase):
     def setUp(self):
         self.current_perplexity_all = {}
-        self.delta = 10
+        self.delta = 5e-1
         self.tensor_parallelism_size = 8
         with open(self.baseline_perplexity_scores, "r") as f:
             self.baseline_perplexity = json.load(f)
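
With per-prompt scores now in the single and double digits rather than the tens of thousands, the comparison tolerance tightens from 10 to 0.5. A self-contained sketch of how a delta-based check of this shape behaves (the values below are illustrative, not the repo's exact assertion):

import unittest


class DeltaExample(unittest.TestCase):
    def test_mean_perplexity_close_to_baseline(self):
        delta = 5e-1  # mirrors self.delta after this change
        baseline, current = 14.991893, 15.2  # illustrative numbers
        # assertAlmostEqual with delta= checks abs(current - baseline) <= delta.
        self.assertAlmostEqual(current, baseline, delta=delta)


if __name__ == "__main__":
    unittest.main()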

sharktank/tests/ops/ops_test.py (9 additions)

@@ -176,6 +176,15 @@ def testTorchImplTransposedPrimitiveRHS(self):
             ops.custom_impls.matmul_mmtfp_tensor_tensor,
         )
 
+    def testTorchImplImplicitBatch(self):
+        ops._registry._test_enable_last_op_dispatch(True)
+        t1 = torch.rand(4, 32, 16, dtype=torch.float32)
+        t2 = torch.rand(48, 16, dtype=torch.float16)
+        t2_pt = DefaultPrimitiveTensor(data=t2)
+        result = ops.matmul(t1, t2_pt.T)
+        expected = torch.matmul(t1, t2.T.to(torch.float32))
+        torch.testing.assert_close(result, expected)
+
     def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
         ops._registry._test_enable_last_op_dispatch(True)
         a_dtype = torch.float32

sharktank/version_info.json (1 addition, 1 deletion)

@@ -1,3 +1,3 @@
 {
-  "package-version": "3.0.0.dev"
+  "package-version": "2.9.0.dev"
 }