
Commit

spelling
tlrmchlsmth committed Aug 27, 2024
1 parent c2cd071 commit bc9b5cf
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions vllm/model_executor/models/mamba2.py
@@ -65,28 +65,30 @@ def __init__(self, hidden_size, eps=1e-6):
     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         input_shape = hidden_states.shape
-        hidden_states = hidden_states.to(torch.float32).view(-1, hidden_states.shape[-1])
+        hidden_states = hidden_states.to(torch.float32).view(
+            -1, hidden_states.shape[-1])

         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(
                 gate.to(torch.float32))

-        # Use Welford's online algorithm for caculating the variance in the
+        # Use Welford's online algorithm for calculating the variance in the
         # tensor parallel setting, as the hidden_states are sharded along the
-        # same axis as we are calculating the variance along.
+        # same axis as we are calculating the variance along.
         if self.tp_size > 1:
             # Calculate local sum and squared_sum
-            local_sums = torch.zeros((hidden_states[0], 3), hidden_state.dtype, hidden_state.device)
-            local_sums[:,0] = hidden_states.sum(-1, keep_dim=False)
-            local_sums[:,1] = hidden_states.pow(2).sum(-1, keep_dim=False)
+            local_sums = torch.zeros((hidden_states[0], 3), hidden_state.dtype,
+                                     hidden_state.device)
+            local_sums[:, 0] = hidden_states.sum(-1, keep_dim=False)
+            local_sums[:, 1] = hidden_states.pow(2).sum(-1, keep_dim=False)

             # Get global sum and squared sum
             global_sums = tensor_model_parallel_all_reduce(sum_and_squared_sum)

             # Calculate the variance
             count = hidden_size.shape(-1)
-            global_mean = global_sums[:,0] / count
-            variance = (global_sq_sum[:,1] / count) - global_mean.pow(2)
+            global_mean = global_sums[:, 0] / count
+            variance = (global_sq_sum[:, 1] / count) - global_mean.pow(2)

         else:
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
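For readability, here is a minimal sketch (not the committed code) of what the tensor-parallel branch above appears to be aiming for: each rank reduces its shard of hidden_states to a per-token sum and sum of squares, those statistics are all-reduced across ranks, and the variance is recovered as E[x^2] - E[x]^2. Despite the comment, this is the sum-of-squares formula rather than Welford's streaming update. Note also that the unchanged context lines still reference names that do not exist as written (`sum_and_squared_sum`, `global_sq_sum`, `hidden_size.shape(-1)`, `keep_dim`); the sketch below uses placeholder names and takes the all-reduce collective as a callback, both of which are assumptions rather than the file's actual API.

import torch

def sharded_variance(hidden_states: torch.Tensor, tp_size: int, all_reduce):
    # hidden_states: [num_tokens, hidden_size // tp_size], this rank's shard.
    # all_reduce: collective that sums a tensor across tensor-parallel ranks
    # (e.g. vLLM's tensor_model_parallel_all_reduce); passed in here to keep
    # the sketch self-contained.
    local_sums = torch.zeros((hidden_states.shape[0], 2),
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
    local_sums[:, 0] = hidden_states.sum(-1)           # per-token sum
    local_sums[:, 1] = hidden_states.pow(2).sum(-1)    # per-token sum of squares

    global_sums = all_reduce(local_sums)               # sum over all ranks

    # The full hidden size is the shard width times the number of ranks.
    count = hidden_states.shape[-1] * tp_size
    global_mean = global_sums[:, 0] / count
    variance = global_sums[:, 1] / count - global_mean.pow(2)
    return variance.unsqueeze(-1)                      # [num_tokens, 1]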
@@ -135,13 +137,13 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.use_bias = config.use_bias

         groups_time_state_size = self.n_groups * self.ssm_state_size
-        self.conv_dim = (self.intermediate_size +
-                         2 * groups_time_state_size)
+        self.conv_dim = (self.intermediate_size + 2 * groups_time_state_size)

-        self.conv1d = MergedColumnParallelLinear(
-            self.conv_kernel_size,
-            [self.intermediate_size, groups_time_state_size, groups_time_state_size],
-            bias=self.use_conv_bias)
+        self.conv1d = MergedColumnParallelLinear(self.conv_kernel_size, [
+            self.intermediate_size, groups_time_state_size,
+            groups_time_state_size
+        ],
+                                                 bias=self.use_conv_bias)

         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
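To make the merged conv1d dimensions easier to follow, here is a small worked example with assumed (hypothetical) sizes, not values taken from this diff: the depthwise convolution runs over the concatenation of the hidden states and the B and C group projections, and the linear-style weight is unsqueezed so its shape matches the [conv_dim, 1, kernel_size] layout of a depthwise nn.Conv1d weight in the checkpoint.

import torch

# Assumed example sizes (hypothetical, not from this diff's config):
intermediate_size = 4096
n_groups = 8
ssm_state_size = 128
conv_kernel_size = 4

groups_time_state_size = n_groups * ssm_state_size           # 1024
conv_dim = intermediate_size + 2 * groups_time_state_size    # 4096 + 2 * 1024 = 6144

# With tp_size == 1, the merged column-parallel linear holds a weight of shape
# [sum(output_sizes), input_size] = [conv_dim, conv_kernel_size].
linear_weight = torch.empty(conv_dim, conv_kernel_size)

# unsqueeze(1) turns it into [conv_dim, 1, conv_kernel_size], the same layout
# as a depthwise nn.Conv1d weight, so the stock weight loader can copy the
# checkpoint's conv1d weight into it directly (the unsqueeze in the hunk above).
conv_style_weight = linear_weight.unsqueeze(1)
print(conv_style_weight.shape)   # torch.Size([6144, 1, 4])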
@@ -150,12 +152,11 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

         # The sharded outputs are gate, hidden_states, B, C, and dt
-        self.in_proj = MergedColumnParallelLinear(
-            self.hidden_size,
-            [self.intermediate_size, self.intermediate_size,
-             groups_time_state_size, groups_time_state_size,
-             self.num_heads],
-            bias=self.use_bias)
+        self.in_proj = MergedColumnParallelLinear(self.hidden_size, [
+            self.intermediate_size, self.intermediate_size,
+            groups_time_state_size, groups_time_state_size, self.num_heads
+        ],
+                                                  bias=self.use_bias)

         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
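As with conv1d, the in_proj weight concatenates several logically separate projections so each can be sharded column-wise under tensor parallelism. A minimal single-GPU sketch with assumed (hypothetical) sizes, using a plain torch.nn.Linear as a stand-in for MergedColumnParallelLinear, shows how the merged output splits back into gate, hidden_states, B, C, and dt:

import torch

# Assumed example sizes (hypothetical, not from this diff's config):
hidden_size = 4096
intermediate_size = 8192
groups_time_state_size = 1024    # n_groups * ssm_state_size
num_heads = 128

split_sizes = [
    intermediate_size,           # gate
    intermediate_size,           # hidden_states
    groups_time_state_size,      # B
    groups_time_state_size,      # C
    num_heads,                   # dt
]

# Stand-in for MergedColumnParallelLinear with tp_size == 1.
in_proj = torch.nn.Linear(hidden_size, sum(split_sizes), bias=False)

x = torch.randn(2, hidden_size)  # [num_tokens, hidden_size]
gate, hidden_states, B, C, dt = in_proj(x).split(split_sizes, dim=-1)
print(gate.shape, B.shape, dt.shape)
# torch.Size([2, 8192]) torch.Size([2, 1024]) torch.Size([2, 128])

Under tensor parallelism, each of the five segments is partitioned separately across ranks, which is why the sizes are passed to MergedColumnParallelLinear as a list rather than as one summed output dimension.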
