Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moe #162

Merged
merged 26 commits into from
Sep 5, 2024
Merged

moe #162

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e036804
Add Mixtral LLM
archana-ramalingam May 21, 2024
3110119
Refactoring attention, moe and ffn blocks
archana-ramalingam May 22, 2024
d1691c3
Allow _optional_int_prop to handle missing hyperparameters
archana-ramalingam May 22, 2024
a865ac3
Fixing circular dep and imports
archana-ramalingam May 23, 2024
3496258
Fix multiple expert layer weight handling + other issues
archana-ramalingam May 29, 2024
2b32fba
Add ffn_moe layers and other fixes
archana-ramalingam Jun 13, 2024
15f2a22
Edit theta slicing
archana-ramalingam Jun 13, 2024
0f155c5
Fix ffn_moe theta parsing & wraping
archana-ramalingam Jun 14, 2024
4a8bb97
Extract tensor unmerging into a function
archana-ramalingam Jun 14, 2024
36eb868
Cleaning up debug statements
archana-ramalingam Aug 19, 2024
58890f9
Fix test failure
archana-ramalingam Aug 19, 2024
99186fd
Add rope_freq_base to llama
archana-ramalingam Aug 19, 2024
c66cbe5
Rebase and fixes
IanNod Aug 28, 2024
0bc76f6
Add missing grok layers
archana-ramalingam Aug 29, 2024
96de75d
adds a test for exporting moe block
dan-garvey Sep 3, 2024
f323792
actually add the test
dan-garvey Sep 3, 2024
2cd365b
some fixes
dan-garvey Sep 5, 2024
67f112f
moe moe moe
dan-garvey Sep 5, 2024
12b2a7a
refactor paged llama
dan-garvey Sep 5, 2024
77163aa
fix format
dan-garvey Sep 5, 2024
5fba3de
rope_freq
dan-garvey Sep 5, 2024
b315fa3
saver
dan-garvey Sep 5, 2024
a2df6a4
address rope freq
dan-garvey Sep 5, 2024
47b14b6
fix llama attn
dan-garvey Sep 5, 2024
911b3a3
Merge branch 'main' into moe-wip
dan-garvey Sep 5, 2024
6a28481
add tensor name
dan-garvey Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.mixtral.mixtral import *


def main():
Expand Down Expand Up @@ -52,7 +53,10 @@ def main():
llama_config = LlamaModelConfig(hp)
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
if llama_config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
return {
Expand Down
7 changes: 6 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..types import *

# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
Expand Down Expand Up @@ -236,7 +237,11 @@ def main():
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
)
model = PagedLlamaModelV1(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)
if args.save_intermediates_path:
archana-ramalingam marked this conversation as resolved.
Show resolved Hide resolved
from ..utils.patching import SaveModuleResultTensorsPatch

Expand Down
144 changes: 144 additions & 0 deletions sharktank/sharktank/examples/validate_direct_mixtral_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.mixtral.mixtral import *


def main(args: list[str]):
from ..utils import cli

torch.no_grad().__enter__()

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
args = cli.parse(parser)

dataset = cli.get_input_dataset(args)
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
llama_config = LlamaModelConfig(hp)
llama_config.kv_cache_type = "direct"
llama_config.activation_dtype = torch.float16
model = PagedMixtralModelV1(dataset.root_theta, llama_config)

# bs ("batch size") == 1
cache_state = model.cache.allocate(bs=1)

start_index = 0
tokens = torch.tensor(
[
[
1,
1059,
31871,
1217,
322,
266,
3682,
6075,
31902,
13,
31849,
31871,
0,
0,
0,
0,
]
+ 48 * [0],
]
)
assert tokens.shape[1] % model.cache.block_seq_stride == 0
seq_block_ids = torch.tensor(
[
[127, 0, 0, 0],
]
)

# Important: Do not use a sequence length of 0 for empty batch slots
# as it will cause softmax to nan due to a mask of all -inf. This then
# propagates and causes badness.
seq_lens = torch.tensor([12])

attention_mask = model.attention_mask(
model.input_mask(seq_lens, tokens.shape[1]),
)

print(f"Step {start_index}")
logits = model.prefill(
tokens,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
# TODO: Normalize the output of extract_tokens_from_logits into tensor [bs, 1].
tokens = torch.tensor(model.extract_tokens_from_logits(logits, seq_lens)).unsqueeze(
1
)
print(f" : tokens = {tokens}")

# Decode a step.
print("Decoding...")
print(tokens.shape, tokens)
start_positions = torch.tensor([12])
seq_lens = seq_lens + 1
decode_attention_mask = model.decode_attention_mask(
model.input_mask(
seq_lens,
seq_block_ids.shape[1] * model.cache.block_seq_stride,
),
)
logits = model.decode(
tokens,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
tokens = torch.tensor(model.extract_tokens_from_logits(logits, [1])).unsqueeze(1)
print(f" : tokens = {tokens}")

def save_prefill_module(model):
from iree.compiler.extras.fx_importer import FxImporter
from iree.compiler.ir import AsmState

importer = FxImporter()

print("Generating FX graph")

class InferenceModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("prefill", model)

def forward(self, tokens, attention_mask, seq_block_ids, *cache_state):
return self.prefill.prefill(
tokens,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=list(cache_state),
)

infmod = InferenceModule()
prog = torch.export.export(
infmod, (tokens, attention_mask, seq_block_ids) + tuple(cache_state)
)

print(f"FX prog:", prog)
importer.import_program(prog, func_name="prefill")
output_file = "/tmp/prefill.mlirbc"
print("Saving to:", output_file)
with open(output_file, "wb") as f:
importer.module_op.write_bytecode(f)


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
48 changes: 48 additions & 0 deletions sharktank/sharktank/examples/validate_mixtral_ref_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.mixtral.mixtral_ref import *


def main(args: list[str]):
from ..utils import cli

torch.no_grad().__enter__()

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
args = cli.parse(parser)

dataset = cli.get_input_dataset(args)
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
ref_llama_config = RefLlamaModelConfig(hp)
ref_llama_config.activation_dtype = torch.float16
model = DirectCacheMixtralModelV1(dataset.root_theta, ref_llama_config)

kv_cache = model.create_cache(bs=1)
start_index = 0
next_tokens = [1, 1059, 31871, 1217, 322, 266, 3682, 6075, 31902, 13, 31849, 31871]
print(f"Step {start_index}")
tokens = model.forward(
torch.tensor([next_tokens]), start_index=start_index, local_kv_cache=kv_cache
)
print(f" : tokens = {tokens}")

# Decode a step.
print("Decoding...")
print(tokens.shape, tokens)
decode_token = model.forward(tokens, start_index=12, local_kv_cache=kv_cache)
print(f" : decode tokens = {decode_token}")


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,10 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .token_embedding import TokenEmbeddingLayer
from .llama_attention_block import LlamaAttentionBlock
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock

from . import configs
5 changes: 1 addition & 4 deletions sharktank/sharktank/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
from ..utils import debugging

__all__ = [
"LinearLayer",
"RotaryEmbeddingLayer",
"RMSNormLayer",
"BaseLayer",
"ThetaLayer",
"TokenEmbedding",
]


Expand Down
32 changes: 26 additions & 6 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import torch

__all__ = [
"LlamaHParams",
]
__all__ = ["LlamaHParams"]


@dataclass
Expand All @@ -36,14 +34,21 @@ class LlamaHParams:
block_count: int
feed_forward_length: int
rope_dimension_count: int
rope_freq_base: float
attention_head_count: int
attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int
expert_count: int
expert_used_count: int

@staticmethod
def from_gguf_props(p: dict[str, Any]):
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
attention_head_count = _int_prop(p, "llama.attention.head_count")

return LlamaHParams(
context_length=_int_prop(p, "llama.context_length"),
embedding_length=_int_prop(p, "llama.embedding_length"),
Expand All @@ -58,6 +63,15 @@ def from_gguf_props(p: dict[str, Any]):
attention_head_count_kv=_optional_int_prop(
p, "llama.attention.head_count_kv", attention_head_count
),
rope_freq_base=_optional_float_prop(
p, "llama.rope.freq_base", default_rope_freq_base
),
expert_count=_optional_int_prop(
p, "llama.expert_count", default_expert_count
),
expert_used_count=_optional_int_prop(
p, "llama.expert_used_count", default_expert_used_count
),
)


Expand All @@ -79,10 +93,16 @@ def _int_prop(p: dict[str, Any], name: str) -> int:
raise KeyError(f"Property '{name}' not found (among keys {p.keys()})")


def _optional_float_prop(p: dict[str, Any], name: str, default_value: float) -> float:
value = p.get(name, default_value)
try:
return float(value)
except ValueError as e:
raise ValueError(f"Property '{name}' expected to be a float and was not") from e


def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
value = p[name]
if value is None:
return default_value
value = p.get(name, default_value)
try:
return int(value)
except ValueError as e:
Expand Down
38 changes: 38 additions & 0 deletions sharktank/sharktank/layers/ffn_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

import torch
import torch.nn.functional as F

from .base import Theta, ThetaLayer
from .linear import LinearLayer

__all__ = [
"FFN",
]


class FFN(ThetaLayer):
def __init__(
self,
theta: Theta,
):
super().__init__(theta)

self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))

def forward(
self,
h: torch.Tensor,
):
ffn_gate = F.silu(self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
Loading
Loading