moe #162

Merged on Sep 5, 2024

Changes from 1 commit

Commits (26)
e036804
Add Mixtral LLM
archana-ramalingam May 21, 2024
3110119
Refactoring attention, moe and ffn blocks
archana-ramalingam May 22, 2024
d1691c3
Allow _optional_int_prop to handle missing hyperparameters
archana-ramalingam May 22, 2024
a865ac3
Fixing circular dep and imports
archana-ramalingam May 23, 2024
3496258
Fix multiple expert layer weight handling + other issues
archana-ramalingam May 29, 2024
2b32fba
Add ffn_moe layers and other fixes
archana-ramalingam Jun 13, 2024
15f2a22
Edit theta slicing
archana-ramalingam Jun 13, 2024
0f155c5
Fix ffn_moe theta parsing & wraping
archana-ramalingam Jun 14, 2024
4a8bb97
Extract tensor unmerging into a function
archana-ramalingam Jun 14, 2024
36eb868
Cleaning up debug statements
archana-ramalingam Aug 19, 2024
58890f9
Fix test failure
archana-ramalingam Aug 19, 2024
99186fd
Add rope_freq_base to llama
archana-ramalingam Aug 19, 2024
c66cbe5
Rebase and fixes
IanNod Aug 28, 2024
0bc76f6
Add missing grok layers
archana-ramalingam Aug 29, 2024
96de75d
adds a test for exporting moe block
dan-garvey Sep 3, 2024
f323792
actually add the test
dan-garvey Sep 3, 2024
2cd365b
some fixes
dan-garvey Sep 5, 2024
67f112f
moe moe moe
dan-garvey Sep 5, 2024
12b2a7a
refactor paged llama
dan-garvey Sep 5, 2024
77163aa
fix format
dan-garvey Sep 5, 2024
5fba3de
rope_freq
dan-garvey Sep 5, 2024
b315fa3
saver
dan-garvey Sep 5, 2024
a2df6a4
address rope freq
dan-garvey Sep 5, 2024
47b14b6
fix llama attn
dan-garvey Sep 5, 2024
911b3a3
Merge branch 'main' into moe-wip
dan-garvey Sep 5, 2024
6a28481
add tensor name
dan-garvey Sep 5, 2024
Fixing circular dep and imports
archana-ramalingam authored and IanNod committed Aug 28, 2024
commit a865ac31b5420c60d463e9a11b42439f1ed53396
3 changes: 3 additions & 0 deletions sharktank/sharktank/layers/__init__.py
@@ -12,5 +12,8 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .token_embedding import TokenEmbeddingLayer
from .attention_block import AttentionBlock
from .ffn_block import FFN
from .mixture_of_experts_block import SparseMoeBlock

from . import configs
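
Note: these re-exports make the new blocks importable from the package root, and they are one half of the import cycle this commit untangles: layers/__init__.py imports the block modules, so the block modules must no longer import back through the package (see the attention_block.py change below). A minimal import sketch, assuming the sharktank package is importable:

# The blocks wired up above can now be pulled from the package root.
from sharktank.layers import FFN, AttentionBlock, RMSNormLayer, SparseMoeBlock
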
6 changes: 4 additions & 2 deletions sharktank/sharktank/layers/attention_block.py
@@ -11,8 +11,10 @@
import torch
import torch.nn.functional as F

from ...layers import *
from ...types import Theta
from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer

__all__ = [
"AttentionBlock",
8 changes: 1 addition & 7 deletions sharktank/sharktank/layers/base.py
@@ -16,14 +16,8 @@
from ..utils import debugging

__all__ = [
"LinearLayer",
"RotaryEmbeddingLayer",
"RMSNormLayer",
"BaseLayer",
"ThetaLayer",
"TokenEmbedding",
"AttentionBlock",
"SparseMoeBlock",
"FFN",
]


3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/ffn_block.py
@@ -7,7 +7,8 @@
import torch
import torch.nn.functional as F

from .base import Theta, ThetaLayer, LinearLayer
from .base import Theta, ThetaLayer
from .linear import LinearLayer

__all__ = [
"FFN",
5 changes: 4 additions & 1 deletion sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -10,7 +10,10 @@
import torch.nn as nn
import torch.nn.functional as F

from .base import Theta, ThetaLayer, LinearLayer, RMSNormLayer, FFN
from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .norm import RMSNormLayer
from .ffn_block import FFN

__all__ = [
"SparseMoeBlock",
29 changes: 24 additions & 5 deletions sharktank/sharktank/models/mixtral/mixtral_ref.py
@@ -6,6 +6,7 @@

from typing import Optional

from dataclasses import dataclass
import math

import torch
@@ -16,9 +17,24 @@
from ...types import Theta

__all__ = [
"RefLlamaModelConfig",
"DirectCacheMixtralModelV1",
]


################################################################################
# Config
################################################################################


@dataclass
class RefLlamaModelConfig:
hp: configs.LlamaHParams

# Dtype to use for general FP activations not otherwise configured.
activation_dtype: torch.dtype = torch.float16


################################################################################
# Models
################################################################################
@@ -27,12 +43,15 @@
class DirectCacheMixtralModelV1(ThetaLayer):
"""Simple Mixtral Model with a direct lookup KV cache for batch-1 inference."""

def __init__(self, theta: Theta, hp: configs.LlamaHParams):
def __init__(self, theta: Theta, config: RefLlamaModelConfig):
super().__init__(theta)
hp = config.hp
self.config = config
self.hp = hp
self.activation_dtype = config.activation_dtype
self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta("token_embd"), dtype=hp.activation_dtype),
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
)
self.add_module(
"attention_embedding",
@@ -52,7 +71,7 @@ def __init__(self, theta: Theta, hp: configs.LlamaHParams):
self.attn_blocks = nn.ModuleList()

for n in range(hp.block_count):
attn_blocks.append(
self.attn_blocks.append(
AttentionBlock(
theta("attn_blk", n),
embedding=self.attention_embedding,
@@ -62,7 +81,7 @@ def __init__(self, theta: Theta, hp: configs.LlamaHParams):
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
)
)
attn_blocks.append(
self.attn_blocks.append(
SparseMoeBlock(
theta("moe_blk", n),
num_experts=hp.expert_count,
@@ -80,7 +99,7 @@ def create_cache(self, bs: int) -> list[torch.Tensor]:
self.hp.attention_head_count,
self.hp.rope_dimension_count,
),
dtype=self.hp.activation_dtype,
dtype=self.activation_dtype,
)
for _ in range(self.hp.block_count * 2)
]
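
Note: the reference Mixtral model now takes a RefLlamaModelConfig instead of raw hyperparameters, so the activation dtype travels with the config rather than being read off hp. A minimal construction sketch; build_ref_model is a hypothetical helper, and its theta/hp arguments are assumed to be loaded elsewhere (e.g. from a checkpoint):

import torch

from sharktank.layers import configs
from sharktank.models.mixtral.mixtral_ref import (
    DirectCacheMixtralModelV1,
    RefLlamaModelConfig,
)
from sharktank.types import Theta


def build_ref_model(theta: Theta, hp: configs.LlamaHParams) -> DirectCacheMixtralModelV1:
    # Hypothetical helper, not part of this PR: theta and hp are assumed to be
    # loaded elsewhere; only the wiring below comes from the diff.
    config = RefLlamaModelConfig(hp=hp, activation_dtype=torch.float16)
    return DirectCacheMixtralModelV1(theta, config)

With the model built this way, create_cache(bs=1) allocates the per-block KV cache tensors in config.activation_dtype rather than reading the dtype off hp, matching the last hunk above.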