Long T5 #179
base: master
Changes from all commits: 572da07, 2b44bf8, 2c3860a, 7f870e8, 5e63b05, 4a4997e, e4f8e49, 4c8864e, 7b635a2, ff939af
```diff
@@ -1,3 +1,5 @@
 from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig
 from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig
 from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
+from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration
+from longformer.longformer_encoder_decoder import LongformerT5Config
```
Large diffs are not rendered by default.
```diff
@@ -1,23 +1,31 @@
+import copy
 from typing import List, Optional, Tuple, Dict
 from torch import nn, Tensor
 from longformer.longformer import LongformerSelfAttention
 from transformers.modeling_bart import BartConfig, BartForConditionalGeneration
+from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration, T5Stack


 class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
     def __init__(self, config):
         super().__init__(config)
-        if config.attention_mode == 'n2':
+        if config.attention_mode == "n2":
             pass  # do nothing, use BertSelfAttention instead
         else:
             for i, layer in enumerate(self.model.encoder.layers):
                 layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


 class LongformerEncoderDecoderConfig(BartConfig):
-    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
-                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
-                 gradient_checkpointing: bool = False, **kwargs):
+    def __init__(
+        self,
+        attention_window: List[int] = None,
+        attention_dilation: List[int] = None,
+        autoregressive: bool = False,
+        attention_mode: str = "sliding_chunks",
+        gradient_checkpointing: bool = False,
+        **kwargs
+    ):
         """
         Args:
             attention_window: list of attention window sizes of length = number of layers.
@@ -36,7 +44,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[
         self.autoregressive = autoregressive
         self.attention_mode = attention_mode
         self.gradient_checkpointing = gradient_checkpointing
-        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+        assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]


 class LongformerSelfAttentionForBart(nn.Module):
@@ -74,3 +82,114 @@ def forward(
         attn_output = self.output(outputs[0].transpose(0, 1))

         return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
+
+
+class LongformerT5ForConditionalGeneration(T5ForConditionalGeneration):
+    def __init__(self, config):
+        super().__init__(config)
+        if config.attention_mode == "n2":
+            pass  # do nothing, use BertSelfAttention instead
+        else:
+            for i, layer in enumerate(self.encoder.block):
+                layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)
+
+        class LongformerT5DecoderStack(T5Stack):
+            def forward(
+                self,
+                input_ids=None,
+                attention_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                inputs_embeds=None,
+                head_mask=None,
+                past_key_value_states=None,
+                use_cache=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+            ):
+                if encoder_attention_mask is not None:
+                    encoder_attention_mask = encoder_attention_mask.clamp(max=1)
+                return T5Stack.forward(
+                    self,
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    head_mask=head_mask,
+                    past_key_value_states=past_key_value_states,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+
+        self.decoder.__class__ = LongformerT5DecoderStack
+
+
+class LongformerT5Config(T5Config):
+    def __init__(
+        self,
+        attention_window: List[int] = None,
+        attention_dilation: List[int] = None,
+        autoregressive: bool = False,
+        attention_mode: str = "sliding_chunks",
+        gradient_checkpointing: bool = False,
+        long_relative_attention_num_buckets: int = 40,
+        **kwargs
+    ):
+        """
+        Args:
+            attention_window: list of attention window sizes of length = number of layers.
+                window size = number of attention locations on each side.
+                For an effective window size of 512, use `attention_window=[256]*num_layers`
+                which is 256 on each side.
+            attention_dilation: list of attention dilation of length = number of layers.
+                attention dilation of `1` means no dilation.
+            autoregressive: do autoregressive attention or have attention on both sides
+            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
+                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
+        """
+        super().__init__(**kwargs)
+        self.attention_window = attention_window
+        self.attention_dilation = attention_dilation
+        self.autoregressive = autoregressive
+        self.attention_mode = attention_mode
+        self.gradient_checkpointing = gradient_checkpointing
+        self.long_relative_attention_num_buckets = long_relative_attention_num_buckets
+        assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]
+
+
+class LongformerSelfAttentionForT5(nn.Module):
+    """
+    Replacement for T5Attention, but only for the encoder stack
+    """
+
+    def __init__(self, config, layer_id):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.has_relative_attention_bias = layer_id == 0
+        self.longformer_self_attn = LongformerSelfAttention(
+            config, layer_id=layer_id, bias=False, attention_dim_scale=False
+        )
+        self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+    def forward(
+        self,
+        input,
+        mask=None,
+        kv=None,
+        position_bias=None,
+        past_key_value_state=None,
+        head_mask=None,
+        query_length=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        outputs = self.longformer_self_attn(
+            input, attention_mask=mask, position_bias=position_bias, output_attentions=output_attentions,
+        )
+        outputs = (self.output(outputs[0]), None) + outputs[1:]
+
+        return outputs
```

> **Review comment** (on `LongformerSelfAttentionForT5.forward`): An alternative I considered was to let this class inherit `LongformerSelfAttention`. But eventually, I decided not to do so. The interfaces of the two classes are quite different. What we have here, i.e., making `LongformerSelfAttention` a member of `LongformerSelfAttentionForT5`, is probably less confusing than the alternative.
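For orientation, here is a minimal, hypothetical usage sketch of the two new classes once a converted checkpoint exists. The path is a placeholder and the window/mode values are purely illustrative; the conversion script below is what actually produces such a checkpoint.

```python
from longformer.longformer_encoder_decoder import (
    LongformerT5Config,
    LongformerT5ForConditionalGeneration,
)

# Placeholder path to a checkpoint produced by the conversion script below.
model_path = "/path/to/converted-long-t5"

config = LongformerT5Config.from_pretrained(model_path)
config.attention_mode = "sliding_chunks"                      # Longformer attention in the encoder
config.attention_window = [512] * config.num_hidden_layers    # one-sided window per encoder layer
config.attention_dilation = [1] * config.num_hidden_layers    # no dilation

model = LongformerT5ForConditionalGeneration.from_pretrained(model_path, config=config)
model.eval()
```

This mirrors the smoke test at the end of the conversion script's `main()`.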
New file (`@@ -0,0 +1,148 @@`):

```python
import argparse
import logging
import os
import copy
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from longformer.longformer_encoder_decoder import (
    LongformerSelfAttentionForT5,
    LongformerT5Config,
    LongformerT5ForConditionalGeneration,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(save_model_to, base_model, attention_window, max_pos, relative_attention_num_buckets):
    # load base model & tokenizer
    tokenizer = T5Tokenizer.from_pretrained(base_model, model_max_length=max_pos)
    model = T5ForConditionalGeneration.from_pretrained(base_model)
    print("Base model architecture")
    print(model)

    # setup config
    config = LongformerT5Config.from_pretrained(base_model)
    config.architectures = [
        "LongformerT5ForConditionalGeneration",
    ]
    # in T5 attention_probs_dropout_prob is dropout_rate
    config.attention_probs_dropout_prob = config.dropout_rate
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
    config.long_relative_attention_num_buckets = relative_attention_num_buckets

    # modify config in model
    # HF T5 includes multiple pointers to the config object
    model.config = copy.deepcopy(config)
    model.encoder.config = copy.deepcopy(config)
    model.encoder.config.use_cache = False
    model.encoder.config.is_encoder_decoder = False
    model.decoder.config = copy.deepcopy(config)
    model.decoder.config.is_decoder = True
    model.decoder.config.is_encoder_decoder = False

    # modify tokenizer
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs["model_max_length"] = max_pos

    # modify model architecture
    for i, layer in enumerate(model.encoder.block):
        self_attn = layer.layer[0].SelfAttention

        longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)

        longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
        longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
        longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v

        longformer_self_attn_for_t5.longformer_self_attn.query_global = copy.deepcopy(self_attn.q)
        longformer_self_attn_for_t5.longformer_self_attn.key_global = copy.deepcopy(self_attn.k)
        longformer_self_attn_for_t5.longformer_self_attn.value_global = copy.deepcopy(self_attn.v)

        longformer_self_attn_for_t5.output = self_attn.o

        if i == 0:
            half_num_buckets = config.long_relative_attention_num_buckets // 2
            half_t5_buckets = 16
            with torch.no_grad():
                longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
                    :half_num_buckets
                ] = self_attn.relative_attention_bias.weight[half_t5_buckets - 1]
                longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
                    half_num_buckets:
                ] = self_attn.relative_attention_bias.weight[-1]
                longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
                    :half_t5_buckets
                ] = self_attn.relative_attention_bias.weight[:half_t5_buckets]
                longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
                    half_num_buckets + 1 : half_num_buckets + half_t5_buckets
                ] = self_attn.relative_attention_bias.weight[half_t5_buckets + 1 :]

        layer.layer[0].SelfAttention = longformer_self_attn_for_t5

    # save modified model
    logger.info(f"saving model to {save_model_to}")
    model.save_pretrained(save_model_to)
    config.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return


def main():
    parser = argparse.ArgumentParser(
        description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention"
    )
    parser.add_argument(
        "--base_model", type=str, default="t5-small", help="The name or path of the base model you want to convert",
    )
    parser.add_argument("--save_model_to", type=str, required=True, help="The path to save the converted model")
    parser.add_argument(
        "--attention_window",
        type=int,
        default=512,
        help="attention window size for longformer self attention (one sided)",
    )
    parser.add_argument("--max_pos", type=int, default=4096 * 4, help="maximum encoder positions")
    parser.add_argument("--num_pos_buckets", type=int, default=40, help="number of relative position buckets")

    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.makedirs(args.save_model_to)

    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        attention_window=args.attention_window,
        max_pos=args.max_pos,
        relative_attention_num_buckets=args.num_pos_buckets,
    )

    tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
    # tokenizer = T5Tokenizer.from_pretrained(args.base_model)
    model = LongformerT5ForConditionalGeneration.from_pretrained(args.save_model_to)
    # model = T5ForConditionalGeneration.from_pretrained(args.base_model)

    model.eval()
    model.config.gradient_checkpointing = True
    model.encoder.config.gradient_checkpointing = True
    model.decoder.config.gradient_checkpointing = True
    print("Converted model architecture")
    print(model)

    TXT = "A rose is a rose is a"
    data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=args.max_pos)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    attention_mask[0, 0:4:2] = 2
    decoder_input_ids = model._shift_right(input_ids[:, :5])

    logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False,)[0]
    probs = logits[0, -1].softmax(dim=0)
    _, predictions = probs.topk(5)
    print(tokenizer.convert_ids_to_tokens(predictions))


if __name__ == "__main__":
    main()
```

> **Review comment** (on `config.long_relative_attention_num_buckets = relative_attention_num_buckets`): When increasing the model length, we probably want to increase the number of relative position buckets as well.
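To make that comment concrete, below is a small standalone sketch of T5-style relative position bucketing, reimplemented here for illustration only (it is not code from this PR or from `transformers`). With T5's defaults of 32 buckets and `max_distance=128`, every query/key offset beyond 128 tokens saturates into the last bucket, so a 16k-token encoder would share a single bias for nearly all long-range pairs; increasing the bucket count together with the distance the buckets span (the 40 / 4096 values in the second call are just illustrative) keeps distant positions distinguishable.

```python
import math


def t5_relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    """Map a query/key offset to a bucket id, following the T5 scheme:
    half the buckets cover exact small offsets, the rest grow logarithmically,
    and every offset past max_distance collapses into the final bucket."""
    # bidirectional: split the buckets between negative and positive offsets
    num_buckets //= 2
    bucket = num_buckets if relative_position > 0 else 0
    n = abs(relative_position)

    max_exact = num_buckets // 2
    if n < max_exact:
        return bucket + n
    # logarithmic buckets up to max_distance, then saturate at the last bucket
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return bucket + min(val, num_buckets - 1)


# With T5's defaults, offsets of 200, 2000 and 16000 tokens all share one bucket:
print([t5_relative_position_bucket(d) for d in (1, 64, 200, 2000, 16000)])
# With more buckets covering a larger distance, distant offsets get distinct biases:
print([t5_relative_position_bucket(d, num_buckets=40, max_distance=4096) for d in (200, 2000, 16000)])
```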
> **Review comment:** As you can see, we are getting many highly similar config classes as we extend to other transformer models. If you like, we can simplify this with a mixin: a mixin class containing all the Longformer-specific settings, with `LongformerT5Config` inheriting from both the mixin class and `T5Config`.

> **Reply:** I don't have strong feelings about this. You decide (as long as we don't change the interface of the released code).
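For what it's worth, a rough sketch of the mixin idea described above (class names and structure are illustrative only, not part of this PR; imports follow the same `transformers` version the PR targets). In this sketch the accepted keyword arguments stay the same, which seems compatible with the constraint of not changing the released interface.

```python
from typing import List

from transformers.modeling_bart import BartConfig
from transformers.modeling_t5 import T5Config


class LongformerConfigMixin:
    """Illustrative mixin holding the Longformer-specific settings shared by the
    Longformer*Config classes; not part of this PR."""

    def __init__(
        self,
        attention_window: List[int] = None,
        attention_dilation: List[int] = None,
        autoregressive: bool = False,
        attention_mode: str = "sliding_chunks",
        gradient_checkpointing: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)  # cooperative call: forwards remaining kwargs to the base config
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]


class LongformerEncoderDecoderConfig(LongformerConfigMixin, BartConfig):
    pass


class LongformerT5Config(LongformerConfigMixin, T5Config):
    def __init__(self, long_relative_attention_num_buckets: int = 40, **kwargs):
        super().__init__(**kwargs)
        self.long_relative_attention_num_buckets = long_relative_attention_num_buckets
```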