Add import/run tooling based on the default deepseek implementation
rsuderman committed Feb 4, 2025
1 parent f652668 commit 59b0b47
Showing 2 changed files with 250 additions and 0 deletions.
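
As a rough usage sketch (all file paths below are placeholders, and the module-style invocation assumes the sharktank package is importable), the two tools chain together: first convert the reference checkpoint into an IRPA file, then run prefill against it:

    python -m sharktank.tools.import_deepseek --config config.json --safetensors model.safetensors --irpa-path model.irpa --json-path model.json
    python -m sharktank.models.deepseek.run_deepseek --ids-path ids.npy --irpa-path model.irpa --results-path expected.npy

Here ids.npy would hold the prompt token ids as a saved numpy array, and expected.npy reference prefill results to compare against.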
45 changes: 45 additions & 0 deletions sharktank/sharktank/models/deepseek/run_deepseek.py
@@ -0,0 +1,45 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from sharktank.types.theta import Dataset
from sharktank.models.deepseek.deepseek import PagedDeepseekModelV1
from sharktank.models.llama.llama import LlamaModelConfig, LlamaHParams

import argparse
import math
import numpy
import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ids-path", type=str, required=True)
    parser.add_argument("--irpa-path", type=str, required=True)
    parser.add_argument("--results-path", type=str)

    args = parser.parse_args()

    dataset = Dataset.load(args.irpa_path)
    properties = dataset.properties
    theta = dataset.root_theta

    config = LlamaModelConfig(
        hp=LlamaHParams(**properties["hparams"]),
        block_seq_stride=8,
        activation_dtype=torch.float32,
        attention_dtype=torch.float32,
    )

    model = PagedDeepseekModelV1(theta=theta, config=config)
    x = torch.from_numpy(numpy.load(args.ids_path))
    results = model.prefill(tokens=x)

    if args.results_path:
        expected = torch.from_numpy(numpy.load(args.results_path))
        diff = expected - results
        # Root-mean-square error between the reference and computed results.
        rmse = math.sqrt(torch.sum(diff * diff) / diff.numel())
        print(f"RMS error {rmse}")
205 changes: 205 additions & 0 deletions sharktank/sharktank/tools/import_deepseek.py
@@ -0,0 +1,205 @@
import argparse
import dataclasses
import json
import logging
import torch

from safetensors.torch import safe_open
from sharktank.layers.configs.llm_configs import LlamaHParams
from sharktank.types import Dataset, Theta
from sharktank.types.tensors import DefaultPrimitiveTensor
from typing import Literal


@dataclasses.dataclass
class ModelArgs:
"""
Data class for defining model arguments and hyperparameters.
Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
"""

max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
# moe
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.0
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0


if __name__ == "__main__":
    # Configure logging so the io_report_callback output below is visible.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--safetensors", type=str, required=True)
    parser.add_argument("--irpa-path", type=str, required=True)
    parser.add_argument("--json-path", type=str, required=True)
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)
    modelargs = ModelArgs(**config)
    hp = LlamaHParams(
        model_arch="deepseek_v3",
        context_length=modelargs.original_seq_len,
        embedding_length=modelargs.vocab_size,
        block_count=modelargs.n_layers,
        feed_forward_length=modelargs.inter_dim,
        attention_head_count=modelargs.n_heads,
        attn_head_dim=modelargs.dim,
        attention_layer_norm_rms_epsilon=1e-6,
        attention_head_count_kv=-1,
        rope_freq_base=modelargs.rope_theta,
        expert_count=modelargs.n_routed_experts,
        expert_used_count=modelargs.n_activated_experts,
        expert_score_func=modelargs.score_func,
        rope_dimension_count=modelargs.qk_rope_head_dim,
        route_scale=modelargs.route_scale,
    )

    st_path = args.safetensors
    json_path = args.json_path
    irpa_path = args.irpa_path

    st = safe_open(st_path, framework="pt")

    # Maps destination (IRPA) tensor names to their source names in the safetensors file.
    baseMapping = {
        "token_embd.weight": "embed.weight",
        "output_norm.weight": "norm.weight",
        "output.weight": "head.weight",
    }

    # Maps per-layer source tensor suffixes to their destination (IRPA) suffixes.
    attnMapping = {
        "attn.kv_norm.weight": "kv_norm.weight",
        "attn.wkv_a.weight": "wkv_a.weight",
        "attn.wkv_b.weight": "wkv_b.weight",
        "attn.wo.weight": "wo.weight",
        "attn.wq.weight": "wq.weight",
        "attn_norm.weight": "attn_norm.weight",
        "ffn.w1.weight": "ffn.ffn_gate.weight",
        "ffn.w2.weight": "ffn.ffn_down.weight",
        "ffn.w3.weight": "ffn.ffn_up.weight",
        "ffn_norm.weight": "ffn_norm.weight",
        "ffn.gate.weight": "moe.ffn_gate_inp.weight",
        "ffn.shared_experts.w1.weight": "moe.shared_experts.ffn_gate.weight",
        "ffn.shared_experts.w2.weight": "moe.shared_experts.ffn_down.weight",
        "ffn.shared_experts.w3.weight": "moe.shared_experts.ffn_up.weight",
    }

    # Maps per-expert source suffixes to destination suffixes for the stacked MoE weights.
    expertMapping = {
        "w1.weight": "ffn_gate_exps.weight",
        "w2.weight": "ffn_down_exps.weight",
        "w3.weight": "ffn_up_exps.weight",
    }

    tensors = {}
    for key in baseMapping:
        tensors[key] = st.get_tensor(baseMapping[key])

    # Group the per-layer tensors ("layers.<N>.<rest>") by layer index.
    layers = {}
    for key in st.keys():
        parts = key.split(".", 2)
        if parts[0] != "layers":
            continue
        layer = int(parts[1])
        if layer not in layers:
            layers[layer] = {}
        layers[layer][parts[2]] = st.get_tensor(key)

    for layer in layers:
        weights = layers[layer]
        experts = {}
        for name in weights:
            weight = weights[name]
            if name in attnMapping:
                tensors[f"blk.{layer}.{attnMapping[name]}"] = weight
                continue

            # Collect per-expert weights ("ffn.experts.<M>.<rest>") for stacking below.
            if name.startswith("ffn.experts."):
                split = name.split(".", 3)
                expert_id = int(split[2])
                if expert_id not in experts:
                    experts[expert_id] = {}
                experts[expert_id][split[3]] = weight
                continue
            assert False, f"unhandled tensor found: {name}"

        # Stack each expert weight across experts into one [n_experts, ...] tensor.
        expert_keys = list(experts[0].keys()) if experts else []
        for key in expert_keys:
            exs = [experts[expert][key] for expert in experts]
            tensor = torch.stack(exs, dim=0)
            # Drop the per-expert references so the originals can be garbage collected.
            for expert in experts:
                del experts[expert][key]
            newKey = expertMapping[key]
            tensors[f"blk.{layer}.moe.{newKey}"] = tensor

    config_json = dataclasses.asdict(hp)
    meta_params = {k: v for k, v in config_json.items() if k.startswith("_")}
    hparams = {k: v for k, v in config_json.items() if not k.startswith("_")}
    props = {
        "meta": meta_params,
        "hparams": hparams,
    }

    tensors = [
        DefaultPrimitiveTensor(name=name, data=tensors[name]) for name in tensors.keys()
    ]
    theta = Theta(tensors)

    dataset = Dataset(props, theta)
    dataset.save(irpa_path, io_report_callback=logger.info)
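
The core transformation in this importer is regrouping the per-expert MoE weights (source keys like layers.<N>.ffn.experts.<M>.w1.weight) into one stacked tensor per layer. A toy sketch of that step, with fabricated shapes and two hypothetical experts:

    import torch

    # Two hypothetical experts, each holding the same set of weight names.
    experts = {
        0: {"w1.weight": torch.zeros(4, 8), "w2.weight": torch.zeros(8, 4)},
        1: {"w1.weight": torch.ones(4, 8), "w2.weight": torch.ones(8, 4)},
    }

    stacked = {}
    for key in list(experts[0].keys()):
        # Stack along a new leading dimension: [n_experts, out_features, in_features].
        stacked[key] = torch.stack([experts[e][key] for e in sorted(experts)], dim=0)

    assert stacked["w1.weight"].shape == (2, 4, 8)
    assert stacked["w2.weight"].shape == (2, 8, 4)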
