Rework import quark dataset
Import the quark dataset directly to GGUF format.
dan-garvey committed Aug 28, 2024
1 parent b56451d commit 7acf9d6
Showing 3 changed files with 433 additions and 41 deletions.
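For context on "import directly to GGUF format": in the diff below, the else branch of the attention block uses GGUF-style tensor names (token_embd, attn_norm, attn_q, ffn_up, ...), while the quark path exercises Hugging Face-style names (model.embed_tokens, input_layernorm, self_attn.q_proj, ...). A minimal, hypothetical sketch of renaming a quark/HF safetensors checkpoint into GGUF tensor names in one pass is shown next; the name map, file paths, and use of the safetensors and gguf Python packages are assumptions for illustration, not sharktank's actual importer.

# Hypothetical sketch (not sharktank's importer): rename a Hugging Face /
# quark safetensors checkpoint to GGUF tensor names and write a .gguf file.
import re
from safetensors import safe_open
import gguf

def hf_name_to_gguf(name: str) -> str:
    # e.g. "model.layers.0.self_attn.q_proj.weight" -> "blk.0.attn_q.weight"
    name = name.replace("model.embed_tokens", "token_embd")
    name = re.sub(r"^model\.layers\.(\d+)\.", r"blk.\1.", name)
    for src, dst in {
        "self_attn.q_proj": "attn_q",
        "self_attn.k_proj": "attn_k",
        "self_attn.v_proj": "attn_v",
        "self_attn.o_proj": "attn_output",
        "input_layernorm": "attn_norm",
        "post_attention_layernorm": "ffn_norm",
        "mlp.gate_proj": "ffn_gate",
        "mlp.up_proj": "ffn_up",
        "mlp.down_proj": "ffn_down",
    }.items():
        name = name.replace(src, dst)
    return name

writer = gguf.GGUFWriter("model.gguf", arch="llama")
writer.add_architecture()
with safe_open("quark_model.safetensors", framework="pt") as f:
    for hf_name in f.keys():
        tensor = f.get_tensor(hf_name)  # torch.Tensor on CPU
        writer.add_tensor(hf_name_to_gguf(hf_name), tensor.numpy())
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()

The actual importer in this commit lives in the other changed files (not shown here) and presumably also handles quark quantization metadata; this sketch only illustrates the renaming step.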
89 changes: 48 additions & 41 deletions sharktank/sharktank/models/llama/llama.py
@@ -121,10 +121,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.activation_dtype = config.activation_dtype
self.use_hf = config.use_hf


key = "token_embd"
if key not in list(theta.keys):
key = "model.embed_tokens"
self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta(key), dtype=config.activation_dtype),
@@ -173,16 +170,12 @@ def prefill(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
print("tokens.device: ")
print(tokens.device)
self._assert_device(tokens)
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state, dtype=self.activation_dtype)
h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)
#with safe_open("/home/nod/cuda_dumps/model_layers_0_input_layernorm.safetensors", "pt") as f:
#h =f.get_tensor("input").to(self.device, dtype=torch.float16)

# Iterate over attention blocks.
for block_idx, block in enumerate(self.attn_blocks):
@@ -298,35 +291,68 @@ def __init__(
):
super().__init__(theta)
if "input_layernorm" in list(theta.keys):
# tensor = theta("self_attn.qkv.weight").tensor
# tensor = tensor.reshape(head_count_kv, head_count // head_count_kv + 2, head_dim, head_dim * head_count)
# print(tensor)
self.add_module(
"attn_norm", RMSNormLayer(theta("input_layernorm"), epsilon=rms_epsilon, debug_save_file=f"input_layernorm_{block_index}.safetensors")
"attn_norm",
RMSNormLayer(
theta("input_layernorm"),
epsilon=rms_epsilon,
debug_save_file=f"input_layernorm_{block_index}.safetensors",
),
)
self.add_module(
"attn_q",
LinearLayer(
theta("self_attn.q_proj"),
debug_save_file=f"attn_q_{block_index}.safetensors",
),
)
self.add_module(
"attn_k",
LinearLayer(
theta("self_attn.k_proj"),
debug_save_file=f"attn_k_{block_index}.safetensors",
),
)
self.add_module(
"attn_v",
LinearLayer(
theta("self_attn.v_proj"),
debug_save_file=f"attn_v_{block_index}.safetensors",
),
)
self.add_module(
"attn_output",
LinearLayer(
theta("self_attn.o_proj"),
debug_save_file=f"attn_output_{block_index}.safetensors",
),
)
# self.add_module("attn_qkv", LinearLayer(theta("self_attn.qkv")))
self.add_module("attn_q", LinearLayer(theta("self_attn.q_proj"), debug_save_file=f"attn_q_{block_index}.safetensors"))
print(f"self.attn_q.weightshape: {self.attn_q.weight.shape}")
self.add_module("attn_k", LinearLayer(theta("self_attn.k_proj"), debug_save_file=f"attn_k_{block_index}.safetensors"))
self.add_module("attn_v", LinearLayer(theta("self_attn.v_proj"), debug_save_file=f"attn_v_{block_index}.safetensors"))
self.add_module("attn_output", LinearLayer(theta("self_attn.o_proj"), debug_save_file=f"attn_output_{block_index}.safetensors"))
self.add_module(
"ffn_norm",
RMSNormLayer(theta("post_attention_layernorm"), epsilon=rms_epsilon),
)
self.add_module("ffn_gate", LinearLayer(theta("mlp.gate_proj")))
self.add_module("ffn_up", LinearLayer(theta("mlp.up_proj"), debug_save_file=f"ffn_up_{block_index}.safetensors"))
self.add_module("ffn_down", LinearLayer(theta("mlp.down_proj"), debug_save_file=f"ffn_down_{block_index}.safetensors"))
self.add_module(
"ffn_up",
LinearLayer(
theta("mlp.up_proj"),
debug_save_file=f"ffn_up_{block_index}.safetensors",
),
)
self.add_module(
"ffn_down",
LinearLayer(
theta("mlp.down_proj"),
debug_save_file=f"ffn_down_{block_index}.safetensors",
),
)
else:
self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
self.add_module("attn_q", LinearLayer(theta("attn_q")))
print(f"self.attn_q.weightshape: {self.attn_q.weight.shape}")
self.add_module("attn_k", LinearLayer(theta("attn_k")))
print(f"self.attn_k.weightshape: {self.attn_k.weight.shape}")
self.add_module("attn_v", LinearLayer(theta("attn_v")))
print(f"self.attn_v.weightshape: {self.attn_v.weight.shape}")
self.add_module("attn_output", LinearLayer(theta("attn_output")))
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
@@ -360,8 +386,6 @@ def forward(
):
assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)
x = self.attn_norm(h)
#with safe_open("/home/nod/quant_linear.safetensors", "pt") as f:
# x = f.get_tensor("input").to(torch.float32)
bs, batch_seq_len, feature_dim = x.shape
assert feature_dim == self.head_count * self.head_dim
xq = self.attn_q(x)
@@ -371,19 +395,11 @@ def forward(
xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim)
xk = xk.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
xv = xv.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
#save_dict = {"xq": xq.detach().clone().contiguous(), "xk": xk.detach().clone().contiguous(), "xv": xv.detach().clone().contiguous()}
#with safe_open(f"/home/nod/cuda_dumps/hf_attn_block{self.block_index}.safetensors", "pt") as f:
#xq = f.get_tensor("xq").to("cuda:0").transpose(1,2)
#xk = f.get_tensor("xk").to("cuda:0").transpose(1,2)
#xv = f.get_tensor("xv").to("cuda:0").transpose(1,2)

# Fast path to start_index based embedding lookup if available.
# Falls back to a slower position based index lookup.
if start_index is not None:
print(f"using start index: {start_index}")
xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
#save_dict["embed_xq"]= xq.detach().clone().contiguous()
#save_dict["embed_xk"]= xk.detach().clone().contiguous()
else:
xq, xk = embedding.apply_batched_mask(
xq=xq, xk=xk, mask=embedding_batch_mask
@@ -411,8 +427,6 @@ def forward(
kv_seq_len=kv_seq_len,
cache_state=cache_state,
)
#save_dict["cache_xk"] = xk.detach().clone().contiguous()
#save_dict["cache_xv"] = xv.detach().clone().contiguous()
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")

@@ -431,8 +445,6 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

xk = repeat_kv(xk)
xv = repeat_kv(xv)
#save_dict["repeat_xk"] = xk.detach().clone().contiguous()
#save_dict["repeat_xv"] = xv.detach().clone().contiguous()

# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
@@ -441,21 +453,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
#save_dict["attn_weights"] = attn_weights.detach().clone().contiguous()
self.assert_not_nan(attn_weights)

# Apply attention mask.
self.trace_tensor("attn_weights", attn_weights, values=False)
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask
#save_dict["masked_attn_weights"] = attn_weights.detach().clone().contiguous()
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
#save_dict["softmax_attn_weights"] = attn_weights.detach().clone().contiguous()
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)
#save_dict["pre_projection"] = attn_output.detach().clone().contiguous()
#save_file(save_dict, f"attn_block_{self.block_index}.safetensors")
# Project.
attn_output = self.attn_output(attn_output)

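The debug_save_file arguments threaded through RMSNormLayer and LinearLayer above imply that each layer can dump an intermediate tensor to a safetensors file for comparison against reference activations (the commented-out safe_open/save_file blocks being deleted in this diff served the same purpose by hand). The layer implementations are not part of this hunk; a minimal sketch of such a hook, assuming the layer simply writes its output under a fixed key, could look like:

# Hypothetical debug-dump helper; the real LinearLayer/RMSNormLayer code is
# not shown in this diff.
from typing import Optional
import torch
from safetensors.torch import save_file

class DebugDump:
    def __init__(self, debug_save_file: Optional[str] = None):
        self.debug_save_file = debug_save_file

    def maybe_dump(self, key: str, tensor: torch.Tensor) -> None:
        if self.debug_save_file is None:
            return
        # safetensors expects contiguous tensors; move to CPU before saving.
        save_file({key: tensor.detach().contiguous().cpu()}, self.debug_save_file)

A layer constructed with debug_save_file=f"attn_q_{block_index}.safetensors" would then call maybe_dump("output", y) right after computing its result, producing the per-block dump files named in the diff.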