Commit

allow for general padding_idx
iamlemec committed May 6, 2024
1 parent 153201c commit 889e14c
Showing 1 changed file with 4 additions and 3 deletions.
convert-hf-to-gguf.py: 4 additions & 3 deletions
@@ -2603,12 +2603,12 @@ class XLMRobertaModel(BertModel):
 
     def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
+        self.pad_token_id = self.hparams["pad_token_id"]
 
     def set_gguf_parameters(self):
         # set this and pop it so super doesn't write it too
         context_length_train = self.hparams.pop("max_position_embeddings")
-        pad_token_id = self.hparams["pad_token_id"]
-        context_length = context_length_train - pad_token_id - 1  # since padding_idx=1
+        context_length = context_length_train - self.pad_token_id - 1  # since padding_idx=1
         self.gguf_writer.add_context_length(context_length)
 
         super().set_gguf_parameters()
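The subtraction encodes the RoBERTa-family convention that real tokens receive position ids starting at padding_idx + 1, so only max_position_embeddings - pad_token_id - 1 slots are usable context. Below is a minimal sketch of that convention, assuming typical XLM-RoBERTa values (514 positions, pad_token_id = 1); the helper mirrors the position-id computation in Hugging Face transformers and is not part of this commit:

# Sketch: RoBERTa-style position ids start at padding_idx + 1 (mirrors the
# logic of HF transformers' create_position_ids_from_input_ids).
# All concrete values here are illustrative assumptions.
import torch

def position_ids_from_input_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    mask = input_ids.ne(padding_idx).int()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

max_position_embeddings = 514  # typical XLM-RoBERTa checkpoint value
pad_token_id = 1               # positions 0..1 are never used by real tokens
context_length = max_position_embeddings - pad_token_id - 1  # 512 usable positions

input_ids = torch.tensor([[0, 42, 43, 2, 1, 1]])  # trailing 1s are padding
print(position_ids_from_input_ids(input_ids, pad_token_id))
# tensor([[2, 3, 4, 5, 1, 1]]) -- real tokens occupy positions 2..5

With pad_token_id = 1 this reproduces the previously hardcoded behavior; the change in this commit simply lets any padding index fall out of the same arithmetic.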
@@ -2700,7 +2700,8 @@ def write_tensors(self):
 
             # chop off position embeddings by two to handle padding_idx offset (1 + padding_token_id)
             if name == "embeddings.position_embeddings.weight":
-                data_torch = data_torch[2:]
+                context_chop = self.pad_token_id + 1
+                data_torch = data_torch[context_chop:]
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
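Rows 0 through pad_token_id of the trained position-embedding matrix are never reached by real tokens, so dropping the first pad_token_id + 1 rows re-bases the table at position 0 for GGUF. A hypothetical sketch of the effect; the tensor shape and variable names are illustrative assumptions, not taken from the converter:

# Sketch: dropping the first pad_token_id + 1 rows re-bases position
# embeddings at 0. Shapes are illustrative, not from the converter.
import torch

pad_token_id = 1
pos_emb = torch.randn(514, 8)  # stand-in for embeddings.position_embeddings.weight

context_chop = pad_token_id + 1
chopped = pos_emb[context_chop:]  # shape (512, 8), matching the context length above

# 0-based position p now maps to the row the model trained at p + pad_token_id + 1
p = 5
assert torch.equal(chopped[p], pos_emb[p + pad_token_id + 1])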
