
Commit

including clovaai#268
matteocacciola committed Jul 11, 2024
1 parent 541fa5a commit 6827417
Showing 2 changed files with 11 additions and 3 deletions.
.gitignore: 2 additions & 1 deletion
@@ -138,4 +138,5 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-.idea
+.idea
+.DS_Store
donut/model.py: 9 additions & 2 deletions
@@ -337,11 +337,18 @@ def __init__(
                     new_bart_state_dict[x] = bart_state_dict[x]
             self.model.load_state_dict(new_bart_state_dict)
 
-    def add_special_tokens(self, list_of_tokens: List[str]):
+    def add_special_tokens(self, list_of_tokens: List[str], replace_additional_special_tokens: bool | None = False):
         """
         Add special tokens to tokenizer and resize the token embeddings
         """
-        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
+        newly_added_num = 0
+        set_of_tokens = set(list_of_tokens)
+        set_special_tokens = set(self.tokenizer.all_special_tokens)
+        if len(set_of_tokens - set_special_tokens) > 0:
+            newly_added_num = self.tokenizer.add_special_tokens(
+                {"additional_special_tokens": sorted(set_of_tokens)},
+                replace_additional_special_tokens=replace_additional_special_tokens
+            )
         if newly_added_num > 0:
             self.model.resize_token_embeddings(len(self.tokenizer))

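For context, a minimal sketch of how the patched logic behaves, not code from this commit: it uses a plain Hugging Face tokenizer ("xlm-roberta-base") as a stand-in for donut's decoder tokenizer, illustrative task tokens, and assumes a transformers version whose add_special_tokens accepts the replace_additional_special_tokens keyword, as the patch does. The model embedding resize is omitted since no model is loaded here.

# Sketch only (assumptions noted above): mirrors the patched add_special_tokens
# with a standalone tokenizer instead of donut's BARTDecoder.
from typing import List

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # stand-in tokenizer


def add_special_tokens(list_of_tokens: List[str], replace_additional_special_tokens: bool = False) -> int:
    """Call the tokenizer only if at least one requested token is not already special."""
    newly_added_num = 0
    set_of_tokens = set(list_of_tokens)
    set_special_tokens = set(tokenizer.all_special_tokens)
    if len(set_of_tokens - set_special_tokens) > 0:
        newly_added_num = tokenizer.add_special_tokens(
            {"additional_special_tokens": sorted(set_of_tokens)},
            replace_additional_special_tokens=replace_additional_special_tokens,
        )
    return newly_added_num


# First call registers the (hypothetical) task tokens; the second is a no-op
# because they now appear in all_special_tokens, so the tokenizer is not touched.
print(add_special_tokens(["<s_task>", "</s_task>"]))  # e.g. 2
print(add_special_tokens(["<s_task>", "</s_task>"]))  # 0

In the patched method, a positive return value still triggers self.model.resize_token_embeddings, but that resize is now skipped entirely when every requested token is already registered as special.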
