Long T5 #179

Open
wants to merge 10 commits into base: master
4 changes: 3 additions & 1 deletion longformer/__init__.py
@@ -1,3 +1,5 @@
from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration
from longformer.longformer_encoder_decoder import LongformerT5Config
283 changes: 234 additions & 49 deletions longformer/longformer.py

Large diffs are not rendered by default.

129 changes: 124 additions & 5 deletions longformer/longformer_encoder_decoder.py
@@ -1,23 +1,31 @@
import copy
from typing import List, Optional, Tuple, Dict
from torch import nn, Tensor
from longformer.longformer import LongformerSelfAttention
from transformers.modeling_bart import BartConfig, BartForConditionalGeneration
from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration, T5Stack


class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
if config.attention_mode == 'n2':
if config.attention_mode == "n2":
pass # do nothing, use BertSelfAttention instead
else:
for i, layer in enumerate(self.model.encoder.layers):
layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


class LongformerEncoderDecoderConfig(BartConfig):
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
gradient_checkpointing: bool = False, **kwargs):
def __init__(
self,
attention_window: List[int] = None,
attention_dilation: List[int] = None,
autoregressive: bool = False,
attention_mode: str = "sliding_chunks",
gradient_checkpointing: bool = False,
**kwargs
):
"""
Args:
attention_window: list of attention window sizes of length = number of layers.
@@ -36,7 +44,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[
self.autoregressive = autoregressive
self.attention_mode = attention_mode
self.gradient_checkpointing = gradient_checkpointing
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]


class LongformerSelfAttentionForBart(nn.Module):
@@ -74,3 +82,114 @@ def forward(
attn_output = self.output(outputs[0].transpose(0, 1))

return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)


class LongformerT5ForConditionalGeneration(T5ForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
if config.attention_mode == "n2":
pass  # do nothing, keep the regular T5Attention
else:
for i, layer in enumerate(self.encoder.block):
layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)

class LongformerT5DecoderStack(T5Stack):
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
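# The converted attention mask marks Longformer global-attention positions with 2,
# but T5's decoder cross-attention expects a 0/1 mask, so clamp those 2s back to 1.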
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask.clamp(max=1)
return T5Stack.forward(
self,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
past_key_value_states=past_key_value_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

self.decoder.__class__ = LongformerT5DecoderStack


class LongformerT5Config(T5Config):
Author:

You can see we are getting many highly similar config classes as we extend to other transformer models. If you like, we can simplify this with a mixin: a separate mixin class would contain all the Longformer-specific settings, and LongformerT5Config would inherit from both the mixin and T5Config.

Collaborator:

I don't have strong feelings about this. You decide (as long as we don't change the interface of the released code).
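
For reference, a minimal sketch of the mixin idea; the names LongformerAttentionConfigMixin, set_longformer_attention_args, and MixinLongformerT5Config are illustrative only, not part of this PR:

from typing import List
from transformers.modeling_t5 import T5Config

class LongformerAttentionConfigMixin:
    """Holds the Longformer-specific settings shared by every Longformer*Config (hypothetical)."""

    def set_longformer_attention_args(
        self,
        attention_window: List[int] = None,
        attention_dilation: List[int] = None,
        autoregressive: bool = False,
        attention_mode: str = "sliding_chunks",
        gradient_checkpointing: bool = False,
    ):
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]

class MixinLongformerT5Config(LongformerAttentionConfigMixin, T5Config):
    def __init__(self, long_relative_attention_num_buckets: int = 40, **kwargs):
        # Pull out the Longformer kwargs so T5Config's __init__ does not see them.
        longformer_keys = ("attention_window", "attention_dilation", "autoregressive",
                           "attention_mode", "gradient_checkpointing")
        longformer_kwargs = {k: kwargs.pop(k) for k in longformer_keys if k in kwargs}
        super().__init__(**kwargs)
        self.set_longformer_attention_args(**longformer_kwargs)
        self.long_relative_attention_num_buckets = long_relative_attention_num_buckets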

def __init__(
self,
attention_window: List[int] = None,
attention_dilation: List[int] = None,
autoregressive: bool = False,
attention_mode: str = "sliding_chunks",
gradient_checkpointing: bool = False,
long_relative_attention_num_buckets: int = 40,
**kwargs
):
"""
Args:
attention_window: list of attention window sizes of length = number of layers.
window size = number of attention locations on each side.
For an effective window size of 512, use `attention_window=[256]*num_layers`
which is 256 on each side.
attention_dilation: list of attention dilation of length = number of layers.
attention dilation of `1` means no dilation.
autoregressive: do autoregressive attention or attend to both sides
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
"""
super().__init__(**kwargs)
self.attention_window = attention_window
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
self.gradient_checkpointing = gradient_checkpointing
self.long_relative_attention_num_buckets = long_relative_attention_num_buckets
assert self.attention_mode in ["tvm", "sliding_chunks", "n2"]


class LongformerSelfAttentionForT5(nn.Module):
"""
Replacement for T5Attention, but only for the encoder stack
"""

def __init__(self, config, layer_id):
super().__init__()
self.embed_dim = config.d_model
self.has_relative_attention_bias = layer_id == 0
self.longformer_self_attn = LongformerSelfAttention(
config, layer_id=layer_id, bias=False, attention_dim_scale=False
)
self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

def forward(
Author:

An alternative I considered was to let this class inherit from LongformerSelfAttention, but I eventually decided not to. The interfaces of the two classes are quite different, and what we have here, making LongformerSelfAttention a member of LongformerSelfAttentionForT5, is probably less confusing than the alternative.

self,
input,
mask=None,
kv=None,
position_bias=None,
past_key_value_state=None,
head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
outputs = self.longformer_self_attn(
input, attention_mask=mask, position_bias=position_bias, output_attentions=output_attentions,
)
outputs = (self.output(outputs[0]), None) + outputs[1:]

return outputs
14 changes: 14 additions & 0 deletions scripts/cheatsheet.txt
@@ -88,3 +88,17 @@ source /anaconda3/bin/activate torch-xla-nightly

# Resume training
python scripts/summarization.py --num_workers 12 --save_prefix eval_long16k_nooverlap_large --model_path bart-large-long-16384/ --max_input_len 16368 --batch_size 2 --grad_accum 4 --grad_ckpt --attention_mode sliding_chunks_no_overlap --attention_window 340 --val_every 0.333333333 --debug --resume summarization/run_long16k_nooverlap_large/_ckpt_epoch_3_v1.ckpt --val_percent_check 1.0 --disable_checkpointing

# Convert model
python scripts/convert_t5_to_longformerencoderdecoder.py --save_model_to /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-base-16384 --base_model t5-base
python scripts/convert_t5_to_longformerencoderdecoder.py --save_model_to /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-16384 --base_model t5-small


srun --gpus=4 --nodes=1 python scripts/summarization.py --num_workers 12 \
--save_dir /net/nfs2.s2-research/haokunl/exp_files/summarization \
--save_prefix longt5-base-16k \
--model_path /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-base-16384 \
--adafactor --tokenizer t5-base \
--max_input_len 16384 --batch_size 2 --grad_accum 4 --grad_ckpt \
--attention_mode sliding_chunks --attention_window 512 \
--val_every 0.333333333 --debug --val_percent_check 1.0 --disable_checkpointing
148 changes: 148 additions & 0 deletions scripts/convert_t5_to_longformerencoderdecoder.py
@@ -0,0 +1,148 @@
import argparse
import logging
import os
import copy
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from longformer.longformer_encoder_decoder import (
LongformerSelfAttentionForT5,
LongformerT5Config,
LongformerT5ForConditionalGeneration,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(save_model_to, base_model, attention_window, max_pos, relative_attention_num_buckets):
# load base model & tokenizer
tokenizer = T5Tokenizer.from_pretrained(base_model, model_max_length=max_pos)
model = T5ForConditionalGeneration.from_pretrained(base_model)
print("Base model architecture")
print(model)

# setup config
config = LongformerT5Config.from_pretrained(base_model)
config.architectures = [
"LongformerT5ForConditionalGeneration",
]
# in T5 attention_probs_dropout_prob is dropout_rate
config.attention_probs_dropout_prob = config.dropout_rate
config.attention_window = [attention_window] * config.num_hidden_layers
config.attention_dilation = [1] * config.num_hidden_layers
Contributor:

When increasing the model length, we probably want to increase the number of relative position buckets as well (config.relative_attention_num_buckets).
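
As an aside (not part of this PR), a simplified standalone reimplementation of T5's bidirectional relative-position bucketing shows why: with the defaults num_buckets=32 and max_distance=128, every distance beyond about 128 tokens falls into the last bucket, so a 16k-token input gains little long-range positional resolution unless the bucket count is increased.

import math

# Simplified version of T5Attention._relative_position_bucket (bidirectional case).
def t5_relative_bucket(distance, num_buckets=32, max_distance=128):
    # Bidirectional bucketing: half the buckets cover each sign of the offset.
    num_buckets //= 2
    offset = num_buckets if distance > 0 else 0
    distance = abs(distance)
    max_exact = num_buckets // 2  # exact buckets for small distances
    if distance < max_exact:
        return offset + distance
    # Logarithmic bins between max_exact and max_distance, then saturation.
    log_bucket = max_exact + int(
        math.log(distance / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return offset + min(log_bucket, num_buckets - 1)

print([t5_relative_bucket(d) for d in (1, 10, 50, 512, 16384)])  # [17, 24, 29, 31, 31]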

config.long_relative_attention_num_buckets = relative_attention_num_buckets

# modify config in model
# HF T5 includes multiple pointers to the config object
model.config = copy.deepcopy(config)
model.encoder.config = copy.deepcopy(config)
model.encoder.config.use_cache = False
model.encoder.config.is_encoder_decoder = False
model.decoder.config = copy.deepcopy(config)
model.decoder.config.is_decoder = True
model.decoder.config.is_encoder_decoder = False

# modify tokenizer
tokenizer.model_max_length = max_pos
tokenizer.init_kwargs["model_max_length"] = max_pos

# modify model architecture
for i, layer in enumerate(model.encoder.block):
self_attn = layer.layer[0].SelfAttention

longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)

longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v

longformer_self_attn_for_t5.longformer_self_attn.query_global = copy.deepcopy(self_attn.q)
longformer_self_attn_for_t5.longformer_self_attn.key_global = copy.deepcopy(self_attn.k)
longformer_self_attn_for_t5.longformer_self_attn.value_global = copy.deepcopy(self_attn.v)

longformer_self_attn_for_t5.output = self_attn.o

if i == 0:
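# This block seeds the enlarged relative-position bias table from T5's 32-bucket table:
# each half is first filled with T5's farthest-distance bucket, then T5's learned
# near-distance buckets are copied over the corresponding rows.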
half_num_buckets = config.long_relative_attention_num_buckets // 2
half_t5_buckets = 16
with torch.no_grad():
longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
:half_num_buckets
] = self_attn.relative_attention_bias.weight[half_t5_buckets - 1]
longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
half_num_buckets:
] = self_attn.relative_attention_bias.weight[-1]
longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
:half_t5_buckets
] = self_attn.relative_attention_bias.weight[:half_t5_buckets]
longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[
half_num_buckets + 1 : half_num_buckets + half_t5_buckets
] = self_attn.relative_attention_bias.weight[half_t5_buckets + 1 :]

layer.layer[0].SelfAttention = longformer_self_attn_for_t5

# save modified model
logger.info(f"saving model to {save_model_to}")
model.save_pretrained(save_model_to)
config.save_pretrained(save_model_to)
tokenizer.save_pretrained(save_model_to)
return


def main():
parser = argparse.ArgumentParser(
description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention"
)
parser.add_argument(
"--base_model", type=str, default="t5-small", help="The name or path of the base model you want to convert",
)
parser.add_argument("--save_model_to", type=str, required=True, help="The path to save the converted model")
parser.add_argument(
"--attention_window",
type=int,
default=512,
help="attention window size for longformer self attention (one sided)",
)
parser.add_argument("--max_pos", type=int, default=4096 * 4, help="maximum encoder positions")
parser.add_argument("--num_pos_buckets", type=int, default=40, help="number of relative position buckets")

args = parser.parse_args()

if not os.path.exists(args.save_model_to):
os.makedirs(args.save_model_to)

create_long_model(
save_model_to=args.save_model_to,
base_model=args.base_model,
attention_window=args.attention_window,
max_pos=args.max_pos,
relative_attention_num_buckets=args.num_pos_buckets,
)

tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
# tokenizer = T5Tokenizer.from_pretrained(args.base_model)
model = LongformerT5ForConditionalGeneration.from_pretrained(args.save_model_to)
# model = T5ForConditionalGeneration.from_pretrained(args.base_model)

model.eval()
model.config.gradient_checkpointing = True
model.encoder.config.gradient_checkpointing = True
model.decoder.config.gradient_checkpointing = True
print("Converted model architecture")
print(model)

TXT = "A rose is a rose is a"
data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=args.max_pos)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
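# Longformer mask convention: 0 = padding, 1 = local attention, 2 = global attention.
# Mark a couple of leading tokens as global to exercise the global-attention path.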
attention_mask[0, 0:4:2] = 2
decoder_input_ids = model._shift_right(input_ids[:, :5])

logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False,)[0]
probs = logits[0, -1].softmax(dim=0)
_, predictions = probs.topk(5)
print(tokenizer.convert_ids_to_tokens(predictions))


if __name__ == "__main__":
main()