Stickbreaking fix #136

Open · wants to merge 18 commits into base: main
34 changes: 34 additions & 0 deletions dolomite_engine/hf_models/modeling_utils/linear.py
@@ -1,5 +1,8 @@
from typing import Union

import torch
import torch.nn as nn
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t


class ParameterizedLinear(nn.Linear):
@@ -23,3 +26,34 @@ def reset_parameters(self) -> None:
nn.init.normal_(self.weight, mean=0, std=self.std)
if hasattr(self, "bias") and self.bias is not None:
self.bias.zero_()


class ParameterizedConv1d(nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
device=None,
dtype=None,
std: float | None = None,
):
self.std = std
super().__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
)

@torch.no_grad()
def reset_parameters(self) -> None:
if self.std is None:
super().reset_parameters()
else:
nn.init.normal_(self.weight, mean=0, std=self.std)
if hasattr(self, "bias") and self.bias is not None:
self.bias.zero_()
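
For reference, the new class mirrors ParameterizedLinear: when std is given, weights are drawn from a normal distribution with that standard deviation and the bias is zeroed; otherwise PyTorch's default Conv1d initialization applies. A minimal usage sketch, assuming the class above is importable; the channel sizes, kernel size, and std value are illustrative, not taken from this PR:

import torch

conv = ParameterizedConv1d(in_channels=64, out_channels=128, kernel_size=4, bias=True, std=0.02)
x = torch.randn(2, 64, 32)                   # [batch, in_channels, seq_len]
y = conv(x)                                  # [2, 128, 29] with no padding
print(round(conv.weight.std().item(), 3))    # close to 0.02 by construction; the bias starts at zero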
@@ -101,6 +101,12 @@ def __init__(
self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers)
)

@torch.no_grad()
def reset_parameters(self) -> None:
A = torch.arange(1, self.num_heads + 1, dtype=torch.float32, device=self.A_log.device)
self.A_log.copy_(torch.log(A))
nn.init.ones_(self.D)
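
For context, this mirrors the standard Mamba2 initialization: A_log holds log(1..num_heads), so the per-head decay recovered in the forward pass (conventionally -exp(A_log)) spans -1..-num_heads, and the skip connection D starts at 1. Defining reset_parameters presumably lets parameter re-initialization flows restore these values instead of leaving them empty. A quick self-contained check of the A/A_log relationship, using an arbitrary example head count:

import torch

num_heads = 8                                    # illustrative value
A = torch.arange(1, num_heads + 1, dtype=torch.float32)
A_log = torch.log(A)
assert torch.allclose(torch.exp(A_log), A)       # -exp(A_log) then yields decays -1, -2, ..., -num_heads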

def forward(
self,
input_states: torch.Tensor,
@@ -177,14 +183,16 @@ def forward(
# for batched generation
dt = dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
# [batch_size, num_heads] -> [batch_size, num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)


dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
# [num_heads, head_dim, state_size]
dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
# [bsz, num_heads, head_dim, state_size]

# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
@@ -231,19 +239,23 @@ def forward(
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
# dt: [batch_size, seq_len, num_heads]
dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
# hidden_states: [batch_size, seq_len, num_heads, head_dim]
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
# B, C: [batch_size, seq_len, n_groups, ssm_state_size]
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
# B, C: [batch_size, seq_len, num_heads, ssm_state_size]
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

D_residual = self.D[..., None] * _pad_tensor_by_size(hidden_states, pad_size)

# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# A: [batch_size, seq_len, num_heads]

# Rearrange into blocks/chunks
hidden_states, A, B, C = [
@@ -14,7 +14,6 @@
from stickbreaking_attention import sb_attn, sb_attn_varlen


@torch.compile
def decoding_stickbreaking(q, k, v, scale=None):
"""
Stick-breaking attention weights.
@@ -27,20 +26,14 @@ def decoding_stickbreaking(q, k, v, scale=None):
original_dtype = q.dtype
q = q.float()
k = k.float()
# logits = torch.einsum('bhid,bhjd->bhij', q, k[..., :-1, :]) * scale
logits = q @ k[..., :-1, :].transpose(-1, -2) * scale
# logits = logits.float()
log_z = F.logsigmoid(logits).to(original_dtype)
log_beta = F.logsigmoid(-logits).to(original_dtype)
# re_cum_log_beta = log_beta.sum(dim=-1, keepdim=True) - log_beta.cumsum(dim=-1)
re_cum_log_beta = log_beta.flip(-1).cumsum(dim=-1).flip(-1) - log_beta
# re_cum_log_beta = log_beta.sum(dim=-1, keepdim=True) - log_beta.cumsum(dim=-1)
log_att = log_z + re_cum_log_beta
# print("log_att", log_att[0, 0, 0, -20:])
att = log_att.exp()
# print("att ", att[0, 0, 0, -20:])
out = torch.einsum("bhij,bhjd->bhid", att, v[..., :-1, :])
# out = att @ v[..., :-1, :]
att: torch.Tensor = log_att.exp()
v = v[..., :-1, :]
out = torch.einsum("bhij,bhjd->bhid", att, v)
return out, 1 - att.sum(dim=-1)
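
In words: key j receives weight sigmoid(logit_j) · Π_{k>j} sigmoid(-logit_k), so newer keys break off their share of the probability "stick" first, and the second return value is the mass not assigned to any cached key. A minimal single-step decoding call; the shapes are purely illustrative and the scale is passed explicitly since the default-scale branch is collapsed in this diff:

import torch

batch, heads, head_dim, cache_len = 1, 2, 64, 7
q = torch.randn(batch, heads, 1, head_dim)              # one decoding step
k = torch.randn(batch, heads, cache_len, head_dim)      # the last key is excluded inside the function
v = torch.randn(batch, heads, cache_len, head_dim)
out, rem = decoding_stickbreaking(q, k, v, scale=head_dim ** -0.5)
print(out.shape, rem.shape)                             # torch.Size([1, 2, 1, 64]) torch.Size([1, 2, 1])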


@@ -93,7 +86,7 @@ def forward(
max_seqlen: torch.Tensor | None = None,
sb_metadata=None,
) -> torch.Tensor:
assert past_key_values is None
# assert past_key_values is None

query, key, value = self._prepare_qkv_for_forward(hidden_states)
softmax_scale = self._get_softmax_scale()
@@ -207,9 +200,11 @@ def _prepare_qkv_for_forward_gqa(

# this needs to be a reshape instead of view sadly
query = query.reshape(total_q, -1, self.head_dim)
key = key.repeat(1, self.num_heads // self.num_key_value_heads, 1)
value = value.repeat(1, self.num_heads // self.num_key_value_heads, 1)

# key = key.repeat(1, self.num_heads // self.num_key_value_heads, 1)
# value = value.repeat(1, self.num_heads // self.num_key_value_heads, 1)
group_size = self.num_heads // self.num_key_value_heads
key = key.repeat_interleave(repeats=group_size, dim=1)
value = value.repeat_interleave(repeats=group_size, dim=1)
return query, key, value
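
The switch from repeat to repeat_interleave changes the head ordering of the expanded key/value tensors: repeat tiles the KV heads as [kv0, kv1, kv0, kv1, ...], while repeat_interleave yields [kv0, kv0, kv1, kv1, ...], i.e. consecutive query heads share a KV head, which is the usual GQA grouping. A toy illustration with made-up shapes:

import torch

k = torch.tensor([[[1.0], [2.0]]])                        # [total_q, num_kv_heads, head_dim] = [1, 2, 1]
print(k.repeat(1, 2, 1).flatten().tolist())               # [1.0, 2.0, 1.0, 2.0] -> heads interleaved
print(k.repeat_interleave(2, dim=1).flatten().tolist())   # [1.0, 1.0, 2.0, 2.0] -> grouped per KV head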

def _prepare_qkv_for_forward_mqa(
24 changes: 15 additions & 9 deletions dolomite_engine/optimization/params_group.py
@@ -89,8 +89,9 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) ->
model = model.model

normal_params = {}
no_weight_decay_params = {}
normal_no_weight_decay_params = {}
mup_params = {}
mup_no_weight_decay_params = {}

# collect parameters with mup learning rate
for module_name, module in model.named_modules():
@@ -106,19 +107,19 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) ->
mup_params[f"{module_name}.{param_name}"] = param
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)) or module.__class__.__name__.lower().endswith("norm"):
for param_name, param in module.named_parameters():
no_weight_decay_params[f"{module_name}.{param_name}"] = param
normal_no_weight_decay_params[f"{module_name}.{param_name}"] = param

# remove biases from weight decay
for param_name, param in model.named_parameters():
if param_name not in no_weight_decay_params and param_name.endswith("bias"):
no_weight_decay_params[param_name] = param
if param_name not in normal_no_weight_decay_params and param_name.endswith("bias"):
normal_no_weight_decay_params[param_name] = param

# collect parameters without mup learning rate
for param_name, param in model.named_parameters():
if param_name not in mup_params and param_name not in no_weight_decay_params:
if param_name not in mup_params and param_name not in normal_no_weight_decay_params and param_name not in mup_no_weight_decay_params:
normal_params[param_name] = param

assert len(normal_params) + len(no_weight_decay_params) + len(mup_params) == len(
assert len(normal_params) + len(normal_no_weight_decay_params) + len(mup_params) + len(mup_no_weight_decay_params) == len(
list(model.parameters())
), "params in groups don't sum up to total parameters"

@@ -128,16 +129,21 @@ def get_mup_group_with_names(model: ModelWrapper, optimizer_class_args: dict) ->
if len(normal_params) > 0:
trainable_parameters_or_param_groups.append({"params": list(normal_params.values())})
names["normal"] = list(normal_params.keys())
if len(no_weight_decay_params) > 0:
if len(normal_no_weight_decay_params) > 0:
trainable_parameters_or_param_groups.append(
{"params": list(no_weight_decay_params.values()), "weight_decay": 0}
{"params": list(normal_no_weight_decay_params.values()), "weight_decay": 0}
)
names["no_weight_decay"] = list(no_weight_decay_params.keys())
names["normal_no_weight_decay"] = list(normal_no_weight_decay_params.keys())
if len(mup_params) > 0:
trainable_parameters_or_param_groups.append(
{"params": list(mup_params.values()), "lr": optimizer_class_args["lr"] / model.config.m_width}
)
names["mup"] = list(mup_params.keys())
if len(mup_no_weight_decay_params) > 0:
trainable_parameters_or_param_groups.append(
{"params": list(mup_no_weight_decay_params.values()), "lr": optimizer_class_args["lr"] / model.config.m_width, "weight_decay": 0}
)
names["mup_no_weight_decay"] = list(mup_no_weight_decay_params.keys())

return trainable_parameters_or_param_groups, names
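
For reference, these group dictionaries follow the standard torch.optim per-group override convention: keys like "lr" and "weight_decay" override the optimizer defaults for that group, and anything unspecified is inherited. A self-contained toy mirroring the structure built above (the lr, m_width, and weight_decay values are made up):

import torch
import torch.nn as nn

layer, lr, m_width = nn.Linear(8, 8), 3e-4, 4.0
param_groups = [
    {"params": [layer.weight], "lr": lr / m_width},    # muP-style group: scaled lr, default weight decay
    {"params": [layer.bias], "weight_decay": 0},       # no-weight-decay group: default lr
]
optimizer = torch.optim.AdamW(param_groups, lr=lr, weight_decay=0.1)
print([g["lr"] for g in optimizer.param_groups])       # [7.5e-05, 0.0003]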

4 changes: 3 additions & 1 deletion dolomite_engine/train_utils.py
@@ -88,8 +88,10 @@ def get_model_tflops(
block = config.sequence_mixer_blocks[layer_idx]
sequence_mixer_type = block.sequence_mixer_type

if sequence_mixer_type == "softmax_attention":
if sequence_mixer_type in ["softmax_attention", "stickbreaking_attention"]:
attention_flops = 4 * b * s * h * (h * (1 + block.num_key_value_heads / n) + s)
elif sequence_mixer_type == "mamba2":
attention_flops = 4 * b * s * h * h * block.num_heads / n
else:
raise NotImplementedError(f"unexpected sequence_mixer_type ({sequence_mixer_type})")
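
As a reading of the attention term above: 4·b·s·h·h covers the query and output projections, 4·b·s·h·h·(num_key_value_heads/n) the key/value projections, and 4·b·s·h·s the QKᵀ and attention-times-value matmuls, all counted at 2 FLOPs per multiply-add; the same expression is reused for stickbreaking_attention since its matmul structure is identical. A quick plug-in with made-up dimensions:

b, s, h, n, kv = 4, 2048, 4096, 32, 8                     # illustrative sizes only
attention_flops = 4 * b * s * h * (h * (1 + kv / n) + s)
print(f"{attention_flops / 1e12:.2f} TFLOPs")             # roughly 0.96 TFLOPs per layer for these numbers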
