Add 'as e' to selective_scan_interface.py #2

Open · wants to merge 19 commits into base: main
Binary file added .DS_Store
191 changes: 176 additions & 15 deletions attention.py
@@ -1,12 +1,171 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union
import transformer_engine as te
from typing import Union, Optional
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from rotary import *
from enums import AttnMaskType

class CustomDotProductAttention(nn.Module):
"""
Memory-efficient dot product attention implementation.
Optimized for both training and inference, supporting causal and non-causal attention.
"""

def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float = 0.0,
causal: bool = False
):
super().__init__()
if not isinstance(num_attention_heads, int) or num_attention_heads <= 0:
raise ValueError(f"num_attention_heads must be a positive integer, got {num_attention_heads}")
if not isinstance(kv_channels, int) or kv_channels <= 0:
raise ValueError(f"kv_channels must be a positive integer, got {kv_channels}")
if not 0.0 <= attention_dropout < 1.0:
raise ValueError(f"attention_dropout must be in [0.0, 1.0), got {attention_dropout}")

self.num_attention_heads = num_attention_heads
self.kv_channels = kv_channels
self.dropout_p = attention_dropout
self.causal = causal

# Register scaling factor - use fp32 for numerical stability
self.register_buffer(
'scale_factor',
torch.tensor(1.0 / math.sqrt(kv_channels), dtype=torch.float32),
persistent=False
)

def _check_inputs(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
"""Validate input tensors shapes and types."""
if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise ValueError(
f"Expected 4D tensors, got query: {query.dim()}D, key: {key.dim()}D, value: {value.dim()}D"
)

seq_len_q, batch_size, num_heads_q, head_dim = query.shape
seq_len_k, batch_size_k, num_heads_k, _ = key.shape
seq_len_v, batch_size_v, num_heads_v, _ = value.shape

if not (batch_size == batch_size_k == batch_size_v):
raise ValueError(f"Batch sizes must match: {batch_size}, {batch_size_k}, {batch_size_v}")

if num_heads_q != self.num_attention_heads:
raise ValueError(f"Query heads {num_heads_q} != expected heads {self.num_attention_heads}")

if head_dim != self.kv_channels:
raise ValueError(f"Head dimension {head_dim} != expected {self.kv_channels}")

if seq_len_k != seq_len_v:
raise ValueError(f"Key/Value sequence lengths must match: {seq_len_k}, {seq_len_v}")

return seq_len_q, seq_len_k

def _create_attention_bias(
self,
L: int, # query sequence length
S: int, # key sequence length
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
inference_params: Optional[dict] = None,
) -> torch.Tensor:
"""Create attention bias incorporating causal and attention masks."""
# Initialize bias with zeros of proper shape
attn_bias = torch.zeros(L, S, dtype=dtype, device=device)

# Apply causal mask if needed and not in generation mode
if self.causal and not (inference_params and inference_params.sequence_len_offset > 0):
# Create and apply causal mask efficiently
causal_mask = torch.triu(
torch.ones(L, S, dtype=torch.bool, device=device),
diagonal=1
)
attn_bias.masked_fill_(causal_mask, float("-inf"))

# Apply attention mask if provided
if attention_mask is not None:
if attention_mask.dtype == torch.bool:
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
else:
# Handle additive attention mask
attn_bias = attn_bias + attention_mask.to(dtype=dtype)

return attn_bias

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
inference_params: Optional[dict] = None,
) -> torch.Tensor:
"""
Compute scaled dot-product attention.

Args:
query: shape [seq_len_q, batch_size, num_heads, head_dim]
key: shape [seq_len_kv, batch_size, num_heads, head_dim]
value: shape [seq_len_kv, batch_size, num_heads, head_dim]
attention_mask: Optional mask tensor
inference_params: Optional inference parameters

Returns:
Output tensor of shape [seq_len_q, batch_size, num_heads, head_dim]
"""
# Input validation
L, S = self._check_inputs(query, key, value)

# Compute attention scores with automatic mixed precision handling
scale = self.scale_factor.to(query.dtype)

# Attention scores for the [seq_len, batch, heads, head_dim] layout:
# [L, B, H, D] x [S, B, H, D] -> [L, B, H, S]
attn_weight = torch.einsum('lbhd,sbhd->lbhs', query, key) * scale

# Create attention bias
attn_bias = self._create_attention_bias(
L, S,
dtype=query.dtype,
device=query.device,
attention_mask=attention_mask,
inference_params=inference_params
)

# Add bias more efficiently using a single view operation
attn_weight = attn_weight + attn_bias.view(L, 1, 1, S)

# Compute attention probabilities
attn_weight = F.softmax(attn_weight, dim=-1, dtype=torch.float32)

# Cast back to input dtype if necessary
if attn_weight.dtype != query.dtype:
attn_weight = attn_weight.to(query.dtype)

# Apply dropout during training only
if self.training and self.dropout_p > 0:
if not (inference_params and getattr(inference_params, 'no_dropout', False)):
attn_weight = F.dropout(attn_weight, p=self.dropout_p, training=True)

# Weighted sum over the key dimension: [L, B, H, S] x [S, B, H, D] -> [L, B, H, D]
output = torch.einsum('lbhs,sbhd->lbhd', attn_weight, value)

return output

def extra_repr(self) -> str:
"""Returns a string containing extra information about the module."""
return (f'num_attention_heads={self.num_attention_heads}, '
f'kv_channels={self.kv_channels}, '
f'attention_dropout={self.dropout_p}, '
f'causal={self.causal}')

class CausalSelfAttention(nn.Module):

def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **kwargs):
@@ -25,16 +184,18 @@ def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **
self.hidden_size_per_attention_head = self.query_projection_size // self.config.num_attention_heads
self.num_attention_heads_per_partition = self.config.num_attention_heads
self.num_query_groups_per_partition = self.config.num_query_groups
self.dpa = te.pytorch.DotProductAttention(num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=0.0,
layer_number=layer_number,
attn_mask_type="causal"
)
self.dpa_generation = te.pytorch.DotProductAttention(num_attention_heads=self.config.num_attention_heads,
kv_channels =self.config.kv_channels,
attention_dropout=0.0, layer_number=layer_number,
attn_mask_type="no_mask")
self.dpa = CustomDotProductAttention(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=0.0,
causal=True
)
self.dpa_generation = CustomDotProductAttention(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=0.0,
causal=False
)

if self.config.use_shared_attention_lora:
self.linear_q_lora_A_list = nn.ParameterList([])
@@ -188,10 +349,10 @@ def forward(self, hidden_states, attention_mask, key_value_states=None, inferenc


if inference_params is None or inference_params.sequence_len_offset == 0:
y = self.dpa(query, key, value)
y = self.dpa(query, key, value, attention_mask=attention_mask, inference_params=inference_params)
else:
y = self.dpa_generation(query, key, value)
y = self.dpa_generation(query, key, value, attention_mask=attention_mask, inference_params=inference_params)

y = self.linear_proj(y)

return y
return y
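
For reviewers: a minimal, hedged sketch of how the new CustomDotProductAttention (which replaces transformer_engine's DotProductAttention above) can be exercised standalone, assuming attention.py is importable. The shapes and the comparison against torch's built-in scaled_dot_product_attention are sanity-check assumptions, not part of this PR.

import torch
import torch.nn.functional as F
from attention import CustomDotProductAttention

# [seq_len, batch, heads, head_dim] layout, matching the class docstring
L, B, H, D = 8, 2, 4, 16
q, k, v = (torch.randn(L, B, H, D) for _ in range(3))

attn = CustomDotProductAttention(
    num_attention_heads=H, kv_channels=D, attention_dropout=0.0, causal=True
)
out = attn(q, k, v)              # causal prefill path
print(out.shape)                 # torch.Size([8, 2, 4, 16])

# With dropout disabled, the result should agree (up to fp32-softmax rounding)
# with the reference implementation on the permuted [batch, heads, seq, dim] layout.
ref = F.scaled_dot_product_attention(
    q.permute(1, 2, 0, 3), k.permute(1, 2, 0, 3), v.permute(1, 2, 0, 3),
    is_causal=True,
).permute(2, 0, 1, 3)
print(torch.allclose(out, ref, atol=1e-5))
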
23 changes: 18 additions & 5 deletions hf_utils.py
@@ -1,17 +1,30 @@
import json

import os
import torch

from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file

HF_TOKEN = os.getenv("HF_TOKEN")

def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=True, force_download=False)
resolved_archive_file = cached_file(
model_name,
CONFIG_NAME,
token=HF_TOKEN,
_raise_exceptions_for_missing_entries=True,
force_download=False
)
return json.load(open(resolved_archive_file))

def load_state_dict_hf(model_name, device=None, dtype=None):
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
WEIGHTS_NAME = "Zamba2_2p7b_direct_from_pytorch.pt"
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=True, force_download=False)
return torch.load(resolved_archive_file, map_location=mapped_device)
WEIGHTS_NAME = "pytorch_model.bin"
resolved_archive_file = cached_file(
model_name,
WEIGHTS_NAME,
token=HF_TOKEN,
_raise_exceptions_for_missing_entries=True,
force_download=False
)
return torch.load(resolved_archive_file, map_location=mapped_device)
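
A hedged usage sketch of the updated loaders. The repo id below is illustrative, HF_TOKEN is only needed for gated or private repositories, and the call assumes the repository actually ships the pytorch_model.bin checkpoint expected by the new WEIGHTS_NAME; note that hf_utils reads the token from the environment once, at import time.

import torch

# Export HF_TOKEN in the environment before this import; hf_utils reads it at import time.
from hf_utils import load_config_hf, load_state_dict_hf

model_name = "Zyphra/Zamba2-2.7B"   # assumed repo id, replace as appropriate
config = load_config_hf(model_name)                        # dict parsed from config.json
state_dict = load_state_dict_hf(model_name, dtype=torch.float32)
print(config.get("hidden_size"), len(state_dict))
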
17 changes: 17 additions & 0 deletions hf_utils_old.py
@@ -0,0 +1,17 @@
import json

import torch

from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=True, force_download=False)
return json.load(open(resolved_archive_file))

def load_state_dict_hf(model_name, device=None, dtype=None):
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
WEIGHTS_NAME = "/workspace/Zamba2-1.2B/pytorch_model.bin"
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=True, force_download=False)
return torch.load(resolved_archive_file, map_location=mapped_device)