
[Feature] Batch Decoding #477

Open
wants to merge 2 commits into base: main
155 changes: 127 additions & 28 deletions mamba_ssm/modules/mamba_simple.py
@@ -126,7 +126,7 @@ def forward(self, hidden_states, inference_params=None):
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
if inference_params.seqlen_offset > 0 and seqlen == 1:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
@@ -161,20 +161,27 @@ def forward(self, hidden_states, inference_params=None):
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# if conv_state is not None:
# # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
# conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)

# if conv_state is None:
# conv_state = torch.zeros(b, d * self.expand, self.d_conv, device = x.device)
x = torch.cat((conv_state, x), dim=2)
conv_state.copy_(x[:, :, -self.d_conv:])
x = self.act(self.conv1d(x)[:, :, self.d_conv:self.d_conv + seqlen])

# if causal_conv1d_fn is None:
# x = self.act(self.conv1d(x)[..., :seqlen])
# else:
# assert self.activation in ["silu", "swish"]
# x = causal_conv1d_fn(
# x=x,
# weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
# bias=self.conv1d.bias,
# activation=self.activation,
# )

# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
@@ -186,24 +193,116 @@ def forward(self, hidden_states, inference_params=None):
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
# y = selective_scan_fn(
# x,
# dt,
# A,
# B,
# C,
# self.D.float(),
# z=z,
# delta_bias=self.dt_proj.bias.float(),
# delta_softplus=True,
# return_last_state=ssm_state is not None,
# )
y, last_state = self.selective_scan(x.float(), dt.float(), A.float(), B.float(), C.float(), self.D.float(), z=z.float(), delta_bias=self.dt_proj.bias.float(),
delta_softplus=True, ssm_state=ssm_state.float() if ssm_state is not None else None)
if ssm_state is not None:
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
out = self.out_proj(y.to(self.out_proj.weight.dtype))  # cast the float32 scan output back to the projection dtype
return out

def selective_scan(self, x, delta, A, B, C, D, z=None, delta_bias=None, delta_softplus=False, ssm_state=None):
x = rearrange(x, 'b d l -> b l d')
B = rearrange(B, 'b d l -> b l d')
C = rearrange(C, 'b d l -> b l d')

(b, l, hidden_dim) = x.shape
n = A.shape[1]

if delta_bias is not None:
delta = delta + delta_bias[..., None]
if delta_softplus:
delta = F.softplus(delta)

# Discretize continuous parameters (A, B)
# - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
# - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
# "A is the more important term and the performance doesn't change much with the simplification on B"
# Perform selective scan (see scan_SSM() in The Annotated S4 [2])
# Note that the below is sequential, while the official implementation does a much faster parallel scan that
# is additionally hardware-aware (like FlashAttention).
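# Worked form of the discretization described above (matching the two einsums below):
#   A_bar[t]        = exp(delta[t] * A)          -> deltaA
#   B_bar[t] * x[t] ~= delta[t] * B[t] * x[t]    -> deltaBX  (simplified Euler)
# so each step computes h[t] = A_bar[t] * h[t-1] + deltaBX[t], and y[t] contracts h[t]
# with C[t] over the state dimension.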
deltaA = torch.exp(torch.einsum('bdl,dn->bldn', delta, A))
deltaBX = torch.einsum('bdl,bln,bld->bldn', delta, B, x)
y = torch.empty((b, l, hidden_dim), device=deltaA.device)
# -- Build h --
if ssm_state is not None:
h = ssm_state
else:
h = torch.zeros((b, hidden_dim, n), device=deltaA.device, dtype=deltaA.dtype)
# -- TODO: how can this loop be made fast / run as a parallel scan? --
for i in range(l):
h = deltaA[:, i] * h + deltaBX[:, i]

y[:, i] = torch.einsum('bdn,bn->bd', h, C[:, i])  # BUG? h here is no longer h(t); it was already advanced to h(t+1) on the previous line

# -- Save h --
# if state is not None:
# ssm_state.copy_(h)
out = y + x * D
out = rearrange(out, 'b l d -> b d l')
if z is not None:
out = out * F.silu(z)
return out, h

def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)

# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)

x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)

# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)

out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state

def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
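For orientation, here is a minimal sketch of how the new chunked-decoding path is driven; it mirrors test_state_seq in the test file below, and the model width, device, and layer_idx are illustrative values, not part of this PR. With inference_params.seqlen_offset > 0, a single-token input still takes the existing step() fast path, while a multi-token chunk now falls through to the batched forward path added above.

import torch
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams

m = Mamba(8, layer_idx=0).to("cuda")
m.requires_grad_(False)
params = InferenceParams(max_seqlen=16, max_batch_size=1)

x = torch.rand(1, 5, 8, device="cuda")
y01 = m(x[:, 0:2], inference_params=params)   # prefill: fills conv_state / ssm_state

params.seqlen_offset = 2
y23 = m(x[:, 2:4], inference_params=params)   # seqlen == 2 -> new chunked decode path

params.seqlen_offset = 4
y4 = m(x[:, 4:5], inference_params=params)    # seqlen == 1 -> existing step() path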
132 changes: 132 additions & 0 deletions test_mamba_ssm_state.py
@@ -0,0 +1,132 @@
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams

import torch
import copy


def test_state_seq():
"""Check that Mamba([x1.x2.x3.x4]) == Mamba([x1,x2])|>step(x3)|>step(x4)"""
device = "cuda:0"
dim_model = 8

# Generate a model with random weights
m = Mamba(dim_model, layer_idx=7).to(device=device)
m.requires_grad_(False)  # allows deepcopy of tensors

# Generate the whole sequence
x_all = torch.rand(1, 5, dim_model, device=device)
y_all = m(x_all)

# Introducing empty inference parameters should not add new data
inference_all = InferenceParams(max_seqlen=16, max_batch_size=3)
y_with_inference = m(x_all, inference_params=inference_all)

assert len(inference_all.key_value_memory_dict)
assert torch.allclose(y_with_inference, y_all)

# Inference by parts
# X0,X1
inference_part = InferenceParams(
max_seqlen=inference_all.max_seqlen, max_batch_size=inference_all.max_batch_size)
y01 = m(x_all[:, 0:2], inference_params=inference_part)
assert torch.allclose(y_with_inference[:, :2], y01)

# (past state up to X1), X2, X3
inference_part.seqlen_offset = 2
inference_part_b = copy.deepcopy(inference_part)
y2 = m(x_all[:, 2:4], inference_params=inference_part)

# (past state up to X3), X4
inference_part.seqlen_offset = 4
y3 = m(x_all[:, 4:5], inference_params=inference_part)

# (past state up to X1), X2 again
inference_part_b.seqlen_offset = 2
y2_b = m(x_all[:, 2:3], inference_params=inference_part_b)
# (past state up to X2), X3 again
inference_part_b.seqlen_offset = 3
y3_b = m(x_all[:, 3:4], inference_params=inference_part_b)

# Values should match the result we got from inference over the whole sequence
assert torch.allclose(y_all[:, 0:2], y01)
assert torch.allclose(y_all[:, 2:4], y2)  # Decode chunk: finally works.
assert torch.allclose(y_all[:, 4:5], y3)
assert torch.allclose(y_all[:, 2:3], y2_b)
assert torch.allclose(y_all[:, 3:4], y3_b)

# Sanity check
assert not torch.allclose(y_all[:, 3:4], y2)


def test_state_batch_drop_empty_infer():
"""Check that you can drop a batch when inference parms are empty"""
device = "cuda"
dim_model = 8

# Generate a model with random weights
m = Mamba(dim_model, layer_idx=7).to(device=device)
m.requires_grad_(False)  # allows deepcopy of tensors

x_all = torch.rand(3, 4, dim_model, device=device)
y_all = m(x_all)

# Introducing empty inference parameters should not add new data
inference_all = InferenceParams(max_seqlen=16, max_batch_size=3)
y_all = m(x_all, inference_params=inference_all)
kv = inference_all.key_value_memory_dict[7]

# Drop batch in the middle
x_02 = x_all[(0, 2), ...]
kv = tuple(batched[(0, 2), ...] for batched in kv)
inference_all.key_value_memory_dict[7] = kv

inference_02 = InferenceParams(max_seqlen=16, max_batch_size=3)
y_02 = m(x_02, inference_params=inference_02)
y_02_a = y_all[(0, 2), ...]
assert torch.allclose(y_02, y_02_a)


def test_state_batch_drop_step():
"""Check that you can drop a batch when inference parms are filled"""

device = "cuda"
dim_model = 8

# Generate a model with random weights
m = Mamba(dim_model, layer_idx=7).to(device=device)
m.requires_grad_(False)  # allows deepcopy of tensors

x_prefix = torch.rand(3, 4, dim_model, device=device)

# Run the model forward so the inference params cache has data
inference_parms = InferenceParams(max_seqlen=16, max_batch_size=3)
_ = m(x_prefix, inference_params=inference_parms)

x_next = torch.rand(3, 1, dim_model, device=device)
inference_parms.seqlen_offset = x_prefix.shape[1]
inference_parms_bak = copy.deepcopy(inference_parms)

# Y with all 3 batches
y_next = m(x_next, inference_params=inference_parms)

# Remove middle batch from cache
kv = inference_parms_bak.key_value_memory_dict[7]
kv = tuple(batched[(0, 2), ...] for batched in kv)
inference_parms_bak.key_value_memory_dict[7] = kv

# Calculate batches without middle batch
x_02 = x_next[(0, 2), ...]
y_next_parmed = m(x_02, inference_params=inference_parms_bak)

# Check that batch was removed
y_next_a = y_next[(0, 2), ...]
assert torch.allclose(y_next_a, y_next_parmed)

# Sanity check
assert not torch.allclose(y_next[(0, 1), ...], y_next_parmed)


test_state_seq()
# test_state_batch_drop_empty_infer()
# test_state_batch_drop_step()
64 changes: 64 additions & 0 deletions time_measure.py
@@ -0,0 +1,64 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mamba_ssm.utils.generation import InferenceParams
from mamba_ssm import Mamba
import time
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel


device = "cuda:5"
x = 'hello world'
x2 = 'hello world2'
chunk = 'hello world ' * 300
question = 'What is the meaning of life?'
dtype=torch.bfloat16
repeats = 5
x3 = torch.rand(2, 64, 16, device=device)
torch.random.manual_seed(0)
model_path = 'state-spaces/mamba-2.8b'
model = MambaLMHeadModel.from_pretrained(model_path, device=device, dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", device=device, dtype=dtype)
model.requires_grad_(False)
model.eval()
input = tokenizer(chunk, truncation=False, return_tensors="pt").to(device)
input2 = tokenizer(question, truncation=False, return_tensors="pt").to(device)
input_ids2 = torch.randint(high=32000, size=(20, 20), dtype=torch.long, device=device)

times = 0
max_length = input2['input_ids'].shape[1] + 100
with torch.inference_mode():

for repeat in range(repeats):
#x = torch.randint(high=32000, size=(1, 10000), device=device)
inference_all = InferenceParams(max_seqlen=5000, max_batch_size=20)
inference_all.seqlen_offset = input['input_ids'].shape[1]
# y_all = model.generate(input_ids=input2['input_ids'], max_length=max_length, cg=True,
# return_dict_in_generate=True, output_scores=True)
y_all = model(input_ids=input['input_ids'])
t1 = time.time()
for i in range(input_ids2.shape[1]): #for i in range(input2['input_ids'].shape[1]):
inference_all.seqlen_offset = input_ids2.shape[1] + i + 1 #input['input_ids'].shape[1] + i + 1
y_with_inference = model(input_ids=input_ids2[:, i:i+1], inference_params=inference_all) #model(input_ids=input2['input_ids'][:, i:i+1], inference_params=inference_all)
t2 = time.time()
inference_time = t2 - t1
times += inference_time
print(f'{model_path}, Forward: inference time: {inference_time}')
times /= repeats
print(f"1: Average time is : {times}")

times = 0
for repeat in range(repeats):
#x = torch.randint(high=32000, size=(1, 10000), device=device)
inference_all = InferenceParams(max_seqlen=5000, max_batch_size=20)
inference_all.seqlen_offset = input['input_ids'].shape[1]
y_all = model(input_ids=input_ids2)
t1 = time.time()
y_with_inference = model(input_ids=input_ids2, inference_params=inference_all) #model(input_ids=input2['input_ids'], inference_params=inference_all)
t2 = time.time()
inference_time = t2 - t1
times += inference_time
print(f'{model_path}, Forward: inference time: {inference_time}')
times /= repeats
print(f"2: Average time is : {times}")