Merge pull request #442 from allenai/shanea/add-input-embedding-arg
[HF] Add input embedding argument to HF model
2015aroras authored Feb 9, 2024 · 2 parents 3be4c1e + faccd94 · commit 97296e6
Showing 3 changed files with 12 additions and 8 deletions.
CHANGELOG.md (1 change: 1 addition, 0 deletions)
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed

 - Fixed default value of `--tokenizer` argument to `scripts/prepare_tulu_data.py` to be an absolute path, not relative path, the script can be run from other directories.
+- Added the option to directly pass input embeddings to `OLMo` and `OLMoForCausalLM`.

 ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

hf_olmo/modeling_olmo.py (2 changes: 2 additions, 0 deletions)
@@ -48,6 +48,7 @@ def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -64,6 +65,7 @@ def forward(
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model.forward(
             input_ids=input_ids,
+            input_embeddings=inputs_embeds,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
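For context, a minimal sketch of how the new argument might be exercised through the HF wrapper above. Only the `inputs_embeds` parameter, the `OLMoForCausalLM` class, and the underlying `model.transformer.wte` embedding layer come from this diff; the checkpoint name and the placeholder token IDs are assumptions for illustration.

    import torch
    from hf_olmo.modeling_olmo import OLMoForCausalLM

    # Hypothetical checkpoint name; substitute whichever OLMo checkpoint you actually use.
    model = OLMoForCausalLM.from_pretrained("allenai/OLMo-1B")
    model.eval()

    # Any (batch_size, seq_len) tensor of token IDs; the values here are placeholders.
    input_ids = torch.tensor([[100, 200, 300]])

    # Build the embeddings yourself (here simply via the model's own token-embedding
    # layer) and pass them instead of token IDs through the new `inputs_embeds` argument.
    inputs_embeds = model.model.transformer.wte(input_ids)
    outputs = model(inputs_embeds=inputs_embeds)
    print(outputs.logits.shape)  # (batch_size, seq_len, vocab/embedding size)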
olmo/model.py (17 changes: 9 additions, 8 deletions)
@@ -1137,6 +1137,7 @@ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.
     def forward(
         self,
         input_ids: torch.LongTensor,
+        input_embeddings: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1145,6 +1146,8 @@ def forward(
     ) -> OlmoOutput:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+            embeddings. When provided, it is treated as the output of the input embedding layer.
         :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
             which input IDs are masked. A `1` value in the mask means that
             the corresponding input ID should *not* be ignored. A `0` means
@@ -1174,22 +1177,20 @@ def forward(
         if past_key_values:
             assert len(past_key_values) == self.config.n_layers

-        batch_size, seq_len = input_ids.size()
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
         if past_key_values is None:
             past_length = 0
         else:
             past_length = past_key_values[0][0].size(-2)

         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

         if not (self.config.alibi or self.config.rope):
             # Get positional embeddings.
             # shape: (1, seq_len)
-            pos = torch.arange(
-                past_length, past_length + seq_len, dtype=torch.long, device=input_ids.device
-            ).unsqueeze(0)
+            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
             # shape: (1, seq_len, d_model)
             pos_emb = self.transformer.wpe(pos)  # type: ignore
             x = pos_emb + x
@@ -1229,7 +1230,7 @@ def forward(
             if attention_mask is not None:
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
-                mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
             attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

             # Add in the masking bias.
@@ -1470,7 +1471,7 @@ def generate(
         tokens_generated = 0

         def flatten_past_key_values(
-            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]
+            past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
         ) -> Dict[str, torch.Tensor]:
             out = {}
             for i, (key, value) in enumerate(past_key_values):
@@ -1479,7 +1480,7 @@ def flatten_past_key_values(
             return out

         def unflatten_past_key_values(
-            past_key_values: Dict[str, torch.Tensor]
+            past_key_values: Dict[str, torch.Tensor],
         ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
             out = []
             for i in range(self.config.n_layers):
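Taken together, the olmo/model.py changes above mean the core `Olmo.forward` can be driven from precomputed embeddings instead of token IDs: the batch and sequence dimensions and the device are then derived from `input_embeddings`, and the token-embedding lookup is skipped. Below is a minimal sketch under stated assumptions; the `ModelConfig` field values and the toy-model construction are made up for illustration, while `input_embeddings`, `transformer.wte`, and the `Olmo` class come from the diff.

    import torch
    from olmo.config import ModelConfig
    from olmo.model import Olmo

    # Hypothetical tiny configuration purely for illustration.
    config = ModelConfig(d_model=64, n_heads=4, n_layers=2, vocab_size=128, max_sequence_length=32)
    model = Olmo(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 8))   # (batch_size, seq_len)
    embeddings = model.transformer.wte(input_ids)              # (batch_size, seq_len, d_model)

    # The two calls below should produce the same logits, since the second simply
    # supplies the output of the embedding layer directly.
    out_from_ids = model(input_ids)
    out_from_embeds = model(input_ids=None, input_embeddings=embeddings)
    print(out_from_ids.logits.shape, out_from_embeds.logits.shape)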
