From 52eac324895173d31f822acdaf1bbdb49a76c811 Mon Sep 17 00:00:00 2001
From: Maykeye
Date: Sat, 3 Feb 2024 18:38:19 +0600
Subject: [PATCH] Added inputs_embeds

---
 mamba_ssm/models/mixer_seq_simple.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 5b3ddfcf..94fac8be 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -148,8 +148,10 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
             for i, layer in enumerate(self.layers)
         }
 
-    def forward(self, input_ids, inference_params=None):
-        hidden_states = self.embedding(input_ids)
+    def forward(self, input_ids, inference_params=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        hidden_states = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -225,12 +227,12 @@ def tie_weights(self):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, inputs_embeds=None):
         """
         "position_ids" is just to be compatible with Transformer generation. We don't use it.
         num_last_tokens: if > 0, only return the logits for the last n tokens
         """
-        hidden_states = self.backbone(input_ids, inference_params=inference_params)
+        hidden_states = self.backbone(input_ids, inference_params=inference_params, inputs_embeds=inputs_embeds)
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
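
The sketch below (not part of the diff) shows how the new inputs_embeds argument might be exercised once the patch is applied. It assumes a CUDA device with the mamba-ssm kernels installed; the small MambaConfig values and the tensor shapes are illustrative only, and the construction path (MambaLMHeadModel(config, device=..., dtype=...)) is the one already present in mixer_seq_simple.py.

    # Minimal, untested usage sketch of the inputs_embeds path added by this patch.
    import torch
    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    # Illustrative small config (assumed values, not from the patch).
    config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
    model = MambaLMHeadModel(config, device="cuda", dtype=torch.float32)

    input_ids = torch.randint(0, 1000, (1, 16), device="cuda")

    # Token-id path: unchanged behaviour.
    logits_from_ids = model(input_ids).logits

    # New path: pass precomputed embeddings instead of token ids, e.g. to
    # prepend soft prompts or splice in non-text embeddings.
    inputs_embeds = model.backbone.embedding(input_ids)
    logits_from_embeds = model(input_ids=None, inputs_embeds=inputs_embeds).logits

    # When the embeddings come straight from the embedding table, both paths
    # should produce the same logits.
    print(torch.allclose(logits_from_ids, logits_from_embeds))

Passing both input_ids and inputs_embeds raises the ValueError introduced in MixerModel.forward.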