From 572da0780230dd9e51277a702cfab082dd2b05f6 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 7 Dec 2020 23:58:07 -0800 Subject: [PATCH 01/10] adding t5 options to longformer --- longformer/longformer.py | 91 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/longformer/longformer.py b/longformer/longformer.py index 14da60f..bdf3691 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -67,6 +67,17 @@ def __init__(self, config, layer_id): self.key = nn.Linear(config.hidden_size, self.embed_dim) self.value = nn.Linear(config.hidden_size, self.embed_dim) + # this is for the T5 setting + if "has_relative_attention_bias" in config.to_dict(): + self.is_decoder = config.is_decoder + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.has_relative_attention_bias = has_relative_attention_bias + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) + self.is_t5 = True + else: + self.is_t5 = False + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) self.key_global = nn.Linear(config.hidden_size, self.embed_dim) self.value_global = nn.Linear(config.hidden_size, self.embed_dim) @@ -85,10 +96,72 @@ def __init__(self, config, layer_id): assert not self.autoregressive # not supported assert self.attention_dilation == 1 # dilation is not supported + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, qlen, klen): + """ Compute binned relative position bias """ + relative_position = torch.tensor([[i-self.attention_window for i in range(2*self.attention_window+1)]]) + rp_bucket = self._relative_position_bucket( + relative_position, # shape (qlen, klen) + bidirectional=not self.is_decoder, + num_buckets=self.relative_attention_num_buckets, + ) + rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) +# values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) + # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need. + values = values.permute([0, 2, 1]).unsqueeze(0) # shape (1, qlen, num_heads, klen) + return values + def forward( self, hidden_states, attention_mask=None, + position_bias=None, + past_key_value_state=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -185,6 +258,24 @@ def forward( # concat to attn_weights # (bsz, seq_len, num_heads, extra attention count + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) + + if self.is_t5: + if position_bias is None: + if not self.has_relative_attention_bias: + raise ValueError("No position_bias provided and no weights to compute position_bias") + + position_bias = self.compute_bias(seq_len, seq_len) + # if key and values are already calculated + # we want only the last query position bias + if past_key_value_state is not None: + position_bias = position_bias[:, :, -1:, :] + # TODO: attention_mask should also be the same shape as position_bias. + # Sliding attention window?? 
+ # if attention_mask is not None: + # position_bias = position_bias + attention_mask # (1, num_heads, seq_len, 2*window+1) + attn_weights += position_bias + + attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability if key_padding_mask is not None: # softmax sometimes inserts NaN if all positions are masked, replace them with 0 From 2b44bf8e4784116977c456af25cd1771eb9205ef Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 8 Dec 2020 00:00:12 -0800 Subject: [PATCH 02/10] t5 encoder decoder options --- longformer/longformer_encoder_decoder.py | 76 +++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py index df38224..b21b2a4 100644 --- a/longformer/longformer_encoder_decoder.py +++ b/longformer/longformer_encoder_decoder.py @@ -2,7 +2,7 @@ from torch import nn, Tensor from longformer.longformer import LongformerSelfAttention from transformers.modeling_bart import BartConfig, BartForConditionalGeneration - +from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration): def __init__(self, config): @@ -74,3 +74,77 @@ def forward( attn_output = self.output(outputs[0].transpose(0, 1)) return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None) + + +class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + if config.attention_mode == 'n2': + pass # do nothing, use BertSelfAttention instead + else: + for i, layer in enumerate(self.encoder.block): + layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i) + + +class LongformerEncoderDecoderConfigT5(T5Config): + def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, + autoregressive: bool = False, attention_mode: str = 'sliding_chunks', + has_relative_attention_bias: bool = False, gradient_checkpointing: bool = False, + **kwargs): + """ + Args: + attention_window: list of attention window sizes of length = number of layers. + window size = number of attention locations on each side. + For an affective window size of 512, use `attention_window=[256]*num_layers` + which is 256 on each side. + attention_dilation: list of attention dilation of length = number of layers. + attention dilation of `1` means no dilation. 
+ autoregressive: do autoregressive attention or have attention of both sides + attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implemenation of Longformer + selfattention, 'sliding_chunks' for another implementation of Longformer selfattention + """ + super().__init__(**kwargs) + self.attention_window = attention_window + self.attention_dilation = attention_dilation + self.autoregressive = autoregressive + self.attention_mode = attention_mode + self.has_relative_attention_bias = has_relative_attention_bias + self.gradient_checkpointing = gradient_checkpointing + self.attention_probs_dropout_prob = self.dropout_rate + assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + +class LongformerSelfAttentionForT5(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.embed_dim = config.d_model + self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id) + self.output = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + query, + mask=None, + kv=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + outputs = self.longformer_self_attn( + query, #.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim) + #attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1, + attention_mask=mask, #.unsqueeze(dim=1).unsqueeze(dim=1)*-1, + output_attentions=output_attentions, + ) + + attn_output = self.output(outputs[0].transpose(0, 1)) + + return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None) + From 2c3860af8f978f1741865503b812b0d8594b477c Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 8 Dec 2020 00:14:32 -0800 Subject: [PATCH 03/10] adding convert script --- longformer/longformer.py | 2 +- longformer/longformer_encoder_decoder.py | 2 +- .../convert_t5_to_longformerencoderdecoder.py | 148 ++++++++++++++++++ 3 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 scripts/convert_t5_to_longformerencoderdecoder.py diff --git a/longformer/longformer.py b/longformer/longformer.py index bdf3691..ce93f8f 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -71,7 +71,7 @@ def __init__(self, config, layer_id): if "has_relative_attention_bias" in config.to_dict(): self.is_decoder = config.is_decoder self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.has_relative_attention_bias = has_relative_attention_bias + self.has_relative_attention_bias = config.has_relative_attention_bias if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) self.is_t5 = True diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py index b21b2a4..145bb6b 100644 --- a/longformer/longformer_encoder_decoder.py +++ b/longformer/longformer_encoder_decoder.py @@ -89,7 +89,7 @@ def __init__(self, config): class LongformerEncoderDecoderConfigT5(T5Config): def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, autoregressive: bool = False, attention_mode: str = 'sliding_chunks', - has_relative_attention_bias: bool = False, gradient_checkpointing: bool = False, + has_relative_attention_bias: bool = True, gradient_checkpointing: bool = False, **kwargs): 
""" Args: diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py new file mode 100644 index 0000000..6219b0f --- /dev/null +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -0,0 +1,148 @@ +import argparse +import logging +import os + +from transformers import T5Tokenizer + +from transformers import T5ForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right +from longformer.longformer_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5 +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def create_long_model( + save_model_to, + base_model, + tokenizer_name_or_path, + attention_window, + max_pos +): + model = T5ForConditionalGeneration.from_pretrained(base_model) + tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) + config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model) + model.config = config + + # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention + # expects attention_probs_dropout_prob, so set it here + config.attention_probs_dropout_prob = config.dropout_rate + config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ] + + # extend position embeddings + tokenizer.model_max_length = max_pos + tokenizer.init_kwargs['model_max_length'] = max_pos + + # current_max_pos, embed_size = model.model.embed_positions.weight.shape + # assert current_max_pos == config.max_position_embeddings + 2 + + # config.max_encoder_position_embeddings = max_pos + # config.max_decoder_position_embeddings = config.max_position_embeddings + # del config.max_position_embeddings + # # TODO: check what's the deal with T5 here. 
+ # max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2 + # assert max_pos >= current_max_pos + + # # allocate a larger position embedding matrix for the encoder + # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size) + # # copy position embeddings over and over to initialize the new position embeddings + # k = 2 + # step = current_max_pos - 2 + # while k < max_pos - 1: + # new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:] + # k += step + # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed + + # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention` + config.attention_window = [attention_window] * config.num_hidden_layers + config.attention_dilation = [1] * config.num_hidden_layers + model.encoder.block = model.encoder.block[:1] + + for i, layer in enumerate(model.encoder.block): + self_attn = layer.layer[0].SelfAttention + + longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i) + + longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q + longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k + longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v + + longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q + longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k + longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v + + longformer_self_attn_for_t5.output = self_attn.o + + layer.layer[0].SelfAttention = longformer_self_attn_for_t5 + + logger.info(f'saving model to {save_model_to}') + model.save_pretrained(save_model_to) + tokenizer.save_pretrained(save_model_to) + return model, tokenizer + + +def main(): + parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention") + parser.add_argument( + '--base_model', + type=str, + default='t5-large', + help='The name or path of the base model you want to convert' + ) + parser.add_argument( + '--tokenizer_name_or_path', + type=str, + default='t5-large', + help='The name or path of the tokenizer' + ) + parser.add_argument( + '--save_model_to', + type=str, + required=True, + help='The path to save the converted model' + ) + parser.add_argument( + '--attention_window', + type=int, + default=512, + help='attention window size for longformer self attention (one sided)' + ) + parser.add_argument( + '--max_pos', + type=int, + default=4096 * 4, + help='maximum encoder positions' + ) + + args = parser.parse_args() + + if not os.path.exists(args.save_model_to): + os.mkdir(args.save_model_to) + + create_long_model( + save_model_to=args.save_model_to, + base_model=args.base_model, + tokenizer_name_or_path=args.tokenizer_name_or_path, + attention_window=args.attention_window, + max_pos=args.max_pos + ) + + tokenizer = T5Tokenizer.from_pretrained(args.save_model_to) + TXT = "My friends are but they eat too many carbs." 
+ model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to) + model.encoder.config.gradient_checkpointing = True + model.decoder.config.gradient_checkpointing = True + data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048) + input_ids = data['input_ids'] + attention_mask = data['attention_mask'] + decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id) + logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0] + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + probs = logits[0, masked_index].softmax(dim=0) + values, predictions = probs.topk(5) + print(tokenizer.convert_ids_to_tokens(predictions)) + + +if __name__ == "__main__": + main() From 7f870e817a323fe88a61b62a4f29c1ecf4656d94 Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 10:25:13 -0400 Subject: [PATCH 04/10] add t5 --- longformer/longformer.py | 446 ++++++++++++------ longformer/longformer_encoder_decoder.py | 68 +-- .../convert_t5_to_longformerencoderdecoder.py | 144 +++--- tests/test_t5_short_sequence.py | 47 ++ 4 files changed, 434 insertions(+), 271 deletions(-) create mode 100644 tests/test_t5_short_sequence.py diff --git a/longformer/longformer.py b/longformer/longformer.py index ce93f8f..bf16c9d 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -5,14 +5,17 @@ import torch.nn.functional as F from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv -from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv +from longformer.sliding_chunks import ( + sliding_chunks_no_overlap_matmul_qk, + sliding_chunks_no_overlap_matmul_pv, +) from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM class Longformer(RobertaModel): def __init__(self, config): super(Longformer, self).__init__(config) - if config.attention_mode == 'n2': + if config.attention_mode == "n2": pass # do nothing, use BertSelfAttention instead else: for i, layer in enumerate(self.encoder.layer): @@ -22,7 +25,7 @@ def __init__(self, config): class LongformerForMaskedLM(RobertaForMaskedLM): def __init__(self, config): super(LongformerForMaskedLM, self).__init__(config) - if config.attention_mode == 'n2': + if config.attention_mode == "n2": pass # do nothing, use BertSelfAttention instead else: for i, layer in enumerate(self.roberta.encoder.layer): @@ -30,8 +33,14 @@ def __init__(self, config): class LongformerConfig(RobertaConfig): - def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, - autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs): + def __init__( + self, + attention_window: List[int] = None, + attention_dilation: List[int] = None, + autoregressive: bool = False, + attention_mode: str = "sliding_chunks", + **kwargs + ): """ Args: attention_window: list of attention window sizes of length = number of layers. 
@@ -49,38 +58,29 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[ self.attention_dilation = attention_dilation self.autoregressive = autoregressive self.attention_mode = attention_mode - assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap'] + assert self.attention_mode in ["tvm", "sliding_chunks", "n2", "sliding_chunks_no_overlap"] class LongformerSelfAttention(nn.Module): - def __init__(self, config, layer_id): + def __init__(self, config, layer_id, bias=True, attention_dim_scale=True): super(LongformerSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) self.num_heads = config.num_attention_heads self.head_dim = int(config.hidden_size / config.num_attention_heads) self.embed_dim = config.hidden_size + self.attention_dim_scale = attention_dim_scale - self.query = nn.Linear(config.hidden_size, self.embed_dim) - self.key = nn.Linear(config.hidden_size, self.embed_dim) - self.value = nn.Linear(config.hidden_size, self.embed_dim) - - # this is for the T5 setting - if "has_relative_attention_bias" in config.to_dict(): - self.is_decoder = config.is_decoder - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.has_relative_attention_bias = config.has_relative_attention_bias - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) - self.is_t5 = True - else: - self.is_t5 = False + self.query = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) + self.key = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) + self.value = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) - self.query_global = nn.Linear(config.hidden_size, self.embed_dim) - self.key_global = nn.Linear(config.hidden_size, self.embed_dim) - self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + self.query_global = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim, bias=bias) self.dropout = config.attention_probs_dropout_prob @@ -89,73 +89,23 @@ def __init__(self, config, layer_id): self.attention_dilation = config.attention_dilation[self.layer_id] self.attention_mode = config.attention_mode self.autoregressive = config.autoregressive + + if hasattr(config, "relative_attention_num_buckets") and layer_id == 0: + self.has_relative_attention_bias = True + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.num_heads + ) + else: + self.has_relative_attention_bias = False + assert self.attention_window > 0 assert self.attention_dilation > 0 - assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap'] - if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']: + assert self.attention_mode in ["tvm", "sliding_chunks", "sliding_chunks_no_overlap"] + if self.attention_mode in ["sliding_chunks", "sliding_chunks_no_overlap"]: assert not self.autoregressive # not supported assert self.attention_dilation == 1 # dilation is not supported - @staticmethod - def 
_relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_postion_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_postion_if_large = torch.min( - relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) - ) - - relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) - return relative_buckets - - def compute_bias(self, qlen, klen): - """ Compute binned relative position bias """ - relative_position = torch.tensor([[i-self.attention_window for i in range(2*self.attention_window+1)]]) - rp_bucket = self._relative_position_bucket( - relative_position, # shape (qlen, klen) - bidirectional=not self.is_decoder, - num_buckets=self.relative_attention_num_buckets, - ) - rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) - values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) -# values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) - # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need. 
- values = values.permute([0, 2, 1]).unsqueeze(0) # shape (1, qlen, num_heads, klen) - return values - def forward( self, hidden_states, @@ -167,14 +117,18 @@ def forward( encoder_attention_mask=None, output_attentions=False, ): - ''' + """ The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to -ve: no attention 0: local attention +ve: global attention - ''' - assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" - assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None" + """ + assert ( + encoder_hidden_states is None + ), "`encoder_hidden_states` is not supported and should be None" + assert ( + encoder_attention_mask is None + ), "`encoder_attention_mask` is not supported and shiould be None" if attention_mask is not None: attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) @@ -192,10 +146,13 @@ def forward( # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` # 1) selecting embeddings that correspond to global attention extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) - zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch, - device=num_extra_indices_per_batch.device) + zero_to_max_range = torch.arange( + 0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device + ) # mask indicating which values are actually going to be padding - selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) + selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze( + dim=-1 + ) # 2) location of the non-padding values in the selected global attention selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) # 3) location of the padding values in the selected global attention @@ -211,91 +168,175 @@ def forward( q = self.query(hidden_states) k = self.key(hidden_states) v = self.value(hidden_states) - q /= math.sqrt(self.head_dim) + if self.attention_dim_scale: + q /= math.sqrt(self.head_dim) q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) # attn_weights = (bsz, seq_len, num_heads, window*2+1) - if self.attention_mode == 'tvm': + if self.attention_mode == "tvm": q = q.float().contiguous() k = k.float().contiguous() - attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) + attn_weights = diagonaled_mm_tvm( + q, k, self.attention_window, self.attention_dilation, False, 0, False + ) elif self.attention_mode == "sliding_chunks": attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) elif self.attention_mode == "sliding_chunks_no_overlap": - attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) + attn_weights = sliding_chunks_no_overlap_matmul_qk( + q, k, self.attention_window, padding_value=0 + ) else: raise False mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) if remove_from_windowed_attention_mask is not None: # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size) - remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1) + remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze( + dim=-1 + 
).unsqueeze(dim=-1) # cast to float/half then replace 1's with -inf - float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0) - repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) + float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + repeat_size = ( + 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) + ) float_mask = float_mask.repeat(1, 1, repeat_size, 1) ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones # diagonal mask with zeros everywhere and -inf inplace of padding - if self.attention_mode == 'tvm': - d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False) + if self.attention_mode == "tvm": + d_mask = diagonaled_mm_tvm( + ones, + float_mask, + self.attention_window, + self.attention_dilation, + False, + 0, + False, + ) elif self.attention_mode == "sliding_chunks": - d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + d_mask = sliding_chunks_matmul_qk( + ones, float_mask, self.attention_window, padding_value=0 + ) elif self.attention_mode == "sliding_chunks_no_overlap": - d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + d_mask = sliding_chunks_no_overlap_matmul_qk( + ones, float_mask, self.attention_window, padding_value=0 + ) attn_weights += d_mask assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] - assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3] + assert attn_weights.size(dim=3) in [ + self.attention_window * 2 + 1, + self.attention_window * 3, + ] # the extra attention if extra_attention_mask is not None: - selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_k = k.new_zeros( + bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim + ) selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) - selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k)) - selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000 + selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k)) + selected_attn_weights[ + selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1] + ] = -10000 # concat to attn_weights - # (bsz, seq_len, num_heads, extra attention count + 2*window+1) + # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) - if self.is_t5: - if position_bias is None: - if not self.has_relative_attention_bias: - raise ValueError("No position_bias provided and no weights to compute position_bias") - - position_bias = self.compute_bias(seq_len, seq_len) - # if key and values are already calculated - # we want only the last query position bias - if past_key_value_state is not None: - position_bias = position_bias[:, :, -1:, :] - # TODO: attention_mask should also be the same shape as position_bias. - # Sliding attention window?? 
- # if attention_mask is not None: - # position_bias = position_bias + attention_mask # (1, num_heads, seq_len, 2*window+1) - attn_weights += position_bias + if position_bias is None and self.has_relative_attention_bias: + window_relative_position = torch.arange( + -self.attention_window, + self.attention_window + 1, + 1, + dtype=torch.long, + device=attn_weights.device, + ) + window_position_bias = ( + self.relative_attention_bias( + relative_position_bucket( + window_relative_position, num_buckets=self.relative_attention_num_buckets + ) + ) + .permute(1, 0)[None, None, :, :] + .repeat(bsz, seq_len, 1, 1) + ) # (bsz, seq_len, num_heads, 2*window+1) + perm_global_position_bias = attn_weights.new_zeros( + bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads + ) # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads) + if extra_attention_mask is not None: + selected_global_memory_position = extra_attention_mask_nonzeros[1][ + :, None + ] # (sum num_extra_indices_per_batch, 1) + selected_global_query_position = torch.arange( + seq_len, dtype=torch.long, device=attn_weights.device + )[ + None, : + ] # (1, seq_len) + selected_global_relative_position = ( + selected_global_memory_position - selected_global_query_position + ) # (sum num_extra_indices_per_batch, seq_len) + selected_global_position_bias = self.relative_attention_bias( + relative_position_bucket(selected_global_relative_position) + ) # (sum num_extra_indices_per_batch, seq_len, num_heads) + perm_global_position_bias[ + selection_padding_mask_nonzeros + ] = selected_global_position_bias # (bsz, max_num_extra_indices_per_batch, seq_len) + global_position_bias = perm_global_position_bias.permute(0, 2, 3, 1) + # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) + position_bias = torch.cat( + ( + global_position_bias, + window_position_bias, + ), + dim=-1, + ) + # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1) + else: + position_bias = window_position_bias + if position_bias is not None: + attn_weights += position_bias - attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_weights_float = F.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability if key_padding_mask is not None: # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) + attn_weights_float = torch.masked_fill( + attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0 + ) attn_weights = attn_weights_float.type_as(attn_weights) - attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + attn_probs = F.dropout( + attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training + ) v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) attn = 0 if extra_attention_mask is not None: selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) - selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_v = v.new_zeros( + bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim + ) selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - attn = 
torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) - attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() - - if self.attention_mode == 'tvm': + attn = torch.matmul( + selected_attn_probs.transpose(1, 2), + selected_v.transpose(1, 2).type_as(selected_attn_probs), + ).transpose(1, 2) + attn_probs = attn_probs.narrow( + -1, + max_num_extra_indices_per_batch, + attn_probs.size(-1) - max_num_extra_indices_per_batch, + ).contiguous() + + if self.attention_mode == "tvm": v = v.float().contiguous() - attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) + attn += diagonaled_mm_tvm( + attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False + ) elif self.attention_mode == "sliding_chunks": attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) elif self.attention_mode == "sliding_chunks_no_overlap": @@ -310,36 +351,73 @@ def forward( # For this case, we'll just recompute the attention for these indices # and overwrite the attn tensor. TODO: remove the redundant computation if extra_attention_mask is not None: - selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim) - selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]] + selected_hidden_states = hidden_states.new_zeros( + max_num_extra_indices_per_batch, bsz, embed_dim + ) + selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[ + extra_attention_mask_nonzeros[::-1] + ] q = self.query_global(selected_hidden_states) k = self.key_global(hidden_states) v = self.value_global(hidden_states) - q /= math.sqrt(self.head_dim) - - q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim) - k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) - v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) + if self.attention_dim_scale: + q /= math.sqrt(self.head_dim) + + q = ( + q.contiguous() + .view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) # (bsz * self.num_heads, max_num_extra_indices_per_batch, head_dim) + k = ( + k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + ) # (bsz * self.num_heads, seq_len, head_dim) + v = ( + v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + ) # (bsz * self.num_heads, seq_len, head_dim) attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len] - - attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) - attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 + assert list(attn_weights.size()) == [ + bsz * self.num_heads, + max_num_extra_indices_per_batch, + seq_len, + ] + + attn_weights = attn_weights.view( + bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len + ) + attn_weights[ + selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], : + ] = -10000.0 if key_padding_mask is not None: attn_weights = 
attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0, ) - attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len) - attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + attn_weights = attn_weights.view( + bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len + ) + attn_weights_float = F.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ) # use fp32 for numerical stability + attn_probs = F.dropout( + attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training + ) selected_attn = torch.bmm(attn_probs, v) - assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim] - - selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim) - nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]] - attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states) + assert list(selected_attn.size()) == [ + bsz * self.num_heads, + max_num_extra_indices_per_batch, + self.head_dim, + ] + + selected_attn_4d = selected_attn.view( + bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim + ) + nonzero_selected_attn = selected_attn_4d[ + selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1] + ] + attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view( + len(selection_padding_mask_nonzeros[0]), -1 + ).type_as(hidden_states) context_layer = attn.transpose(0, 1) if output_attentions: @@ -350,11 +428,69 @@ def forward( # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, # attn_weights are padded with -10000.0 attention scores - attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) + attn_weights = attn_weights.view( + bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len + ) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size # which is the attention weights of every token attending to its neighbours attn_weights = attn_weights.permute(0, 2, 1, 3) - outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) + + outputs = (context_layer,) + if output_attentions: + outputs = outputs + (attn_weights,) + if self.has_relative_attention_bias: + outputs = outputs + (position_bias,) + return outputs + + +def relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): + """ + Imported from Huggingface transformers + https://github.com/huggingface/transformers/blob/a0a027c2ed53b324cf4d0179ceec88d4ff414d47/src/transformers/models/t5/modeling_t5.py#L344 + Original description below: + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. 
If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py index 145bb6b..272aa6e 100644 --- a/longformer/longformer_encoder_decoder.py +++ b/longformer/longformer_encoder_decoder.py @@ -4,10 +4,11 @@ from transformers.modeling_bart import BartConfig, BartForConditionalGeneration from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration + class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration): def __init__(self, config): super().__init__(config) - if config.attention_mode == 'n2': + if config.attention_mode == "n2": pass # do nothing, use BertSelfAttention instead else: for i, layer in enumerate(self.model.encoder.layers): @@ -15,9 +16,15 @@ def __init__(self, config): class LongformerEncoderDecoderConfig(BartConfig): - def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, - autoregressive: bool = False, attention_mode: str = 'sliding_chunks', - gradient_checkpointing: bool = False, **kwargs): + def __init__( + self, + attention_window: List[int] = None, + attention_dilation: List[int] = None, + autoregressive: bool = False, + attention_mode: str = "sliding_chunks", + gradient_checkpointing: bool = False, + **kwargs + ): """ Args: attention_window: list of attention window sizes of length = number of layers. 
@@ -36,7 +43,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[ self.autoregressive = autoregressive self.attention_mode = attention_mode self.gradient_checkpointing = gradient_checkpointing - assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + assert self.attention_mode in ["tvm", "sliding_chunks", "n2"] class LongformerSelfAttentionForBart(nn.Module): @@ -76,21 +83,26 @@ def forward( return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None) -class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration): +class LongformerT5ForConditionalGeneration(T5ForConditionalGeneration): def __init__(self, config): super().__init__(config) - if config.attention_mode == 'n2': + if config.attention_mode == "n2": pass # do nothing, use BertSelfAttention instead else: for i, layer in enumerate(self.encoder.block): layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i) -class LongformerEncoderDecoderConfigT5(T5Config): - def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, - autoregressive: bool = False, attention_mode: str = 'sliding_chunks', - has_relative_attention_bias: bool = True, gradient_checkpointing: bool = False, - **kwargs): +class LongformerT5Config(T5Config): + def __init__( + self, + attention_window: List[int] = None, + attention_dilation: List[int] = None, + autoregressive: bool = False, + attention_mode: str = "sliding_chunks", + gradient_checkpointing: bool = False, + **kwargs + ): """ Args: attention_window: list of attention window sizes of length = number of layers. @@ -108,21 +120,27 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[ self.attention_dilation = attention_dilation self.autoregressive = autoregressive self.attention_mode = attention_mode - self.has_relative_attention_bias = has_relative_attention_bias self.gradient_checkpointing = gradient_checkpointing - self.attention_probs_dropout_prob = self.dropout_rate - assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + assert self.attention_mode in ["tvm", "sliding_chunks", "n2"] + class LongformerSelfAttentionForT5(nn.Module): + """ + Replacement for T5Attention, but only for the encoder stack + """ + def __init__(self, config, layer_id): super().__init__() self.embed_dim = config.d_model - self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id) - self.output = nn.Linear(self.embed_dim, self.embed_dim) + self.has_relative_attention_bias = layer_id == 0 + self.longformer_self_attn = LongformerSelfAttention( + config, layer_id=layer_id, bias=False, attention_dim_scale=False + ) + self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False) def forward( self, - query, + input, mask=None, kv=None, position_bias=None, @@ -133,18 +151,10 @@ def forward( output_attentions=False, ): - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == self.embed_dim - assert list(query.size()) == [tgt_len, bsz, embed_dim] - outputs = self.longformer_self_attn( - query, #.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim) - #attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1, - attention_mask=mask, #.unsqueeze(dim=1).unsqueeze(dim=1)*-1, - output_attentions=output_attentions, + input, attention_mask=mask, position_bias=position_bias, output_attentions=output_attentions, ) - attn_output = self.output(outputs[0].transpose(0, 1)) - - return (attn_output,) + outputs[1:] if 
len(outputs) == 2 else (attn_output, None) + outputs = (self.output(outputs[0]), None) + outputs[1:] + return outputs diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py index 6219b0f..ac7e414 100644 --- a/scripts/convert_t5_to_longformerencoderdecoder.py +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -1,64 +1,45 @@ import argparse import logging import os +import copy from transformers import T5Tokenizer from transformers import T5ForConditionalGeneration -from transformers.modeling_bart import shift_tokens_right -from longformer.longformer_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5 -from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5 +from longformer.longformer_encoder_decoder import ( + LongformerSelfAttentionForT5, + LongformerT5Config, +) +from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -def create_long_model( - save_model_to, - base_model, - tokenizer_name_or_path, - attention_window, - max_pos -): +def create_long_model(save_model_to, base_model, attention_window, max_pos): + # load base model & tokenizer model = T5ForConditionalGeneration.from_pretrained(base_model) - tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) - config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model) - model.config = config - - # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention - # expects attention_probs_dropout_prob, so set it here + tokenizer = T5Tokenizer.from_pretrained(base_model, model_max_length=max_pos) + + # setup config + config = LongformerT5Config.from_pretrained(base_model) + config.architectures = [ + "LongformerT5ForConditionalGeneration", + ] + # in T5 attention_probs_dropout_prob is dropout_rate config.attention_probs_dropout_prob = config.dropout_rate - config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ] - - # extend position embeddings - tokenizer.model_max_length = max_pos - tokenizer.init_kwargs['model_max_length'] = max_pos - - # current_max_pos, embed_size = model.model.embed_positions.weight.shape - # assert current_max_pos == config.max_position_embeddings + 2 - - # config.max_encoder_position_embeddings = max_pos - # config.max_decoder_position_embeddings = config.max_position_embeddings - # del config.max_position_embeddings - # # TODO: check what's the deal with T5 here. 
- # max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2 - # assert max_pos >= current_max_pos - - # # allocate a larger position embedding matrix for the encoder - # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size) - # # copy position embeddings over and over to initialize the new position embeddings - # k = 2 - # step = current_max_pos - 2 - # while k < max_pos - 1: - # new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:] - # k += step - # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed - - # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention` config.attention_window = [attention_window] * config.num_hidden_layers config.attention_dilation = [1] * config.num_hidden_layers - model.encoder.block = model.encoder.block[:1] + # modify config in model + # HF T5 includes multiple pointers to the config object + model.config = model.encoder.config = model.decoder.config = config + + # modify tokenizer + tokenizer.model_max_length = max_pos + tokenizer.init_kwargs["model_max_length"] = max_pos + + # modify model architecture for i, layer in enumerate(model.encoder.block): self_attn = layer.layer[0].SelfAttention @@ -68,52 +49,40 @@ def create_long_model( longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v - longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q - longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k - longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v + longformer_self_attn_for_t5.longformer_self_attn.query_global = copy.deepcopy(self_attn.q) + longformer_self_attn_for_t5.longformer_self_attn.key_global = copy.deepcopy(self_attn.k) + longformer_self_attn_for_t5.longformer_self_attn.value_global = copy.deepcopy(self_attn.v) longformer_self_attn_for_t5.output = self_attn.o + if i == 0: + longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias = self_attn.relative_attention_bias + layer.layer[0].SelfAttention = longformer_self_attn_for_t5 - logger.info(f'saving model to {save_model_to}') + # save modified model + logger.info(f"saving model to {save_model_to}") model.save_pretrained(save_model_to) tokenizer.save_pretrained(save_model_to) - return model, tokenizer + config.save_pretrained(save_model_to) + return def main(): - parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention") - parser.add_argument( - '--base_model', - type=str, - default='t5-large', - help='The name or path of the base model you want to convert' - ) - parser.add_argument( - '--tokenizer_name_or_path', - type=str, - default='t5-large', - help='The name or path of the tokenizer' + parser = argparse.ArgumentParser( + description="Convert T5 to LongT5. 
Replaces T5 encoder's T5Attention with LongformerSelfAttention" ) parser.add_argument( - '--save_model_to', - type=str, - required=True, - help='The path to save the converted model' + "--base_model", type=str, default="t5-small", help="The name or path of the base model you want to convert", ) + parser.add_argument("--save_model_to", type=str, required=True, help="The path to save the converted model") parser.add_argument( - '--attention_window', + "--attention_window", type=int, default=512, - help='attention window size for longformer self attention (one sided)' - ) - parser.add_argument( - '--max_pos', - type=int, - default=4096 * 4, - help='maximum encoder positions' + help="attention window size for longformer self attention (one sided)", ) + parser.add_argument("--max_pos", type=int, default=4096 * 4, help="maximum encoder positions") args = parser.parse_args() @@ -123,24 +92,25 @@ def main(): create_long_model( save_model_to=args.save_model_to, base_model=args.base_model, - tokenizer_name_or_path=args.tokenizer_name_or_path, attention_window=args.attention_window, - max_pos=args.max_pos + max_pos=args.max_pos, ) tokenizer = T5Tokenizer.from_pretrained(args.save_model_to) - TXT = "My friends are but they eat too many carbs." - model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to) - model.encoder.config.gradient_checkpointing = True - model.decoder.config.gradient_checkpointing = True - data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048) - input_ids = data['input_ids'] - attention_mask = data['attention_mask'] - decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id) - logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0] - masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - probs = logits[0, masked_index].softmax(dim=0) - values, predictions = probs.topk(5) + model = LongformerT5ForConditionalGeneration.from_pretrained(args.save_model_to) + model.eval() + model.config.gradient_checkpointing = True + + TXT = "A rose is a rose is a" + data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=2048) + input_ids = data["input_ids"] + attention_mask = data["attention_mask"] + attention_mask[0, 0:4:2] = 2 + decoder_input_ids = model._shift_right(input_ids[:, :5]) + + logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False,)[0] + probs = logits[0, -1].softmax(dim=0) + _, predictions = probs.topk(5) print(tokenizer.convert_ids_to_tokens(predictions)) diff --git a/tests/test_t5_short_sequence.py b/tests/test_t5_short_sequence.py new file mode 100644 index 0000000..de4627d --- /dev/null +++ b/tests/test_t5_short_sequence.py @@ -0,0 +1,47 @@ +import torch +import unittest +from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration +from longformer.sliding_chunks import pad_to_window_size +from transformers import T5Tokenizer, T5ForConditionalGeneration + + +class TestT5ShortSeq(unittest.TestCase): + def _run_test(self, INPUT_TEXT, long_model_name_or_path, base_model_name_or_path): + + tokenizer = T5Tokenizer.from_pretrained(long_model_name_or_path) + model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(long_model_name_or_path) + model.eval() + model.config.gradient_checkpointing = True + base_model = T5ForConditionalGeneration.from_pretrained(base_model_name_or_path) + base_model.eval() + + 
data = tokenizer([INPUT_TEXT], return_tensors="pt", padding="max_length", max_length=2048) + input_ids = data["input_ids"] + attention_mask = data["attention_mask"] + decoder_input_ids = model._shift_right(input_ids[:, :5]) + + output = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False,)[ + 0 + ].float() + expected_output = base_model( + input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False, + )[0].float() + + atol = 1e-4 + self.assertTrue(torch.allclose(output, expected_output, atol=atol)) + + def test_outout(self): + self._run_test( + INPUT_TEXT="Hello world!", + long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096", + base_model_name_or_path="t5-small", + ) + self._run_test( + INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkeness.", + long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096", + base_model_name_or_path="t5-small", + ) + + +if __name__ == "__main__": + unittest.main() From 5e63b0563ed0b3a590665d2ab9315bf218e7a7d1 Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 10:29:47 -0400 Subject: [PATCH 05/10] tidy imports in convert_t5_to_longformer_encoderdecoder --- scripts/convert_t5_to_longformerencoderdecoder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py index ac7e414..9680c7e 100644 --- a/scripts/convert_t5_to_longformerencoderdecoder.py +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -3,14 +3,12 @@ import os import copy -from transformers import T5Tokenizer - -from transformers import T5ForConditionalGeneration +from transformers import T5Tokenizer, T5ForConditionalGeneration from longformer.longformer_encoder_decoder import ( LongformerSelfAttentionForT5, LongformerT5Config, + LongformerT5ForConditionalGeneration, ) -from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) From 4a4997e1088b546ada72cb0c09367e14de201a99 Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 10:39:35 -0400 Subject: [PATCH 06/10] add some prints for easy inspection of model architecture --- scripts/convert_t5_to_longformerencoderdecoder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py index 9680c7e..530b44c 100644 --- a/scripts/convert_t5_to_longformerencoderdecoder.py +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -16,8 +16,10 @@ def create_long_model(save_model_to, base_model, attention_window, max_pos): # load base model & tokenizer - model = T5ForConditionalGeneration.from_pretrained(base_model) tokenizer = T5Tokenizer.from_pretrained(base_model, model_max_length=max_pos) + model = T5ForConditionalGeneration.from_pretrained(base_model) + print("Base model architecture") + print(model) # setup config config = LongformerT5Config.from_pretrained(base_model) @@ -98,6 +100,8 @@ def main(): model = LongformerT5ForConditionalGeneration.from_pretrained(args.save_model_to) model.eval() model.config.gradient_checkpointing = True + print("Converted model architecture") + print(model) TXT = "A rose is a rose is a" data = tokenizer([TXT], return_tensors="pt", padding="max_length", 
max_length=2048) From e4f8e49894dc1f86b968f7925fa3fc1d3349c591 Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 10:52:58 -0400 Subject: [PATCH 07/10] change line length --- longformer/longformer.py | 138 ++++++++++----------------------------- 1 file changed, 33 insertions(+), 105 deletions(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index bf16c9d..5711416 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -93,9 +93,7 @@ def __init__(self, config, layer_id, bias=True, attention_dim_scale=True): if hasattr(config, "relative_attention_num_buckets") and layer_id == 0: self.has_relative_attention_bias = True self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.num_heads - ) + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) else: self.has_relative_attention_bias = False @@ -123,12 +121,8 @@ def forward( 0: local attention +ve: global attention """ - assert ( - encoder_hidden_states is None - ), "`encoder_hidden_states` is not supported and should be None" - assert ( - encoder_attention_mask is None - ), "`encoder_attention_mask` is not supported and shiould be None" + assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" + assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and shiould be None" if attention_mask is not None: attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) @@ -150,9 +144,7 @@ def forward( 0, max_num_extra_indices_per_batch, device=num_extra_indices_per_batch.device ) # mask indicating which values are actually going to be padding - selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze( - dim=-1 - ) + selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) # 2) location of the non-padding values in the selected global attention selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) # 3) location of the padding values in the selected global attention @@ -177,52 +169,36 @@ def forward( if self.attention_mode == "tvm": q = q.float().contiguous() k = k.float().contiguous() - attn_weights = diagonaled_mm_tvm( - q, k, self.attention_window, self.attention_dilation, False, 0, False - ) + attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) elif self.attention_mode == "sliding_chunks": attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) elif self.attention_mode == "sliding_chunks_no_overlap": - attn_weights = sliding_chunks_no_overlap_matmul_qk( - q, k, self.attention_window, padding_value=0 - ) + attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) else: raise False mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) if remove_from_windowed_attention_mask is not None: # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size) - remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze( + remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze( dim=-1 - ).unsqueeze(dim=-1) + ) # cast to float/half then replace 1's with -inf float_mask = 
remove_from_windowed_attention_mask.type_as(q).masked_fill( remove_from_windowed_attention_mask, -10000.0 ) - repeat_size = ( - 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) - ) + repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) float_mask = float_mask.repeat(1, 1, repeat_size, 1) ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones # diagonal mask with zeros everywhere and -inf inplace of padding if self.attention_mode == "tvm": d_mask = diagonaled_mm_tvm( - ones, - float_mask, - self.attention_window, - self.attention_dilation, - False, - 0, - False, + ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False, ) elif self.attention_mode == "sliding_chunks": - d_mask = sliding_chunks_matmul_qk( - ones, float_mask, self.attention_window, padding_value=0 - ) + d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) elif self.attention_mode == "sliding_chunks_no_overlap": - d_mask = sliding_chunks_no_overlap_matmul_qk( - ones, float_mask, self.attention_window, padding_value=0 - ) + d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) attn_weights += d_mask assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] @@ -233,32 +209,22 @@ def forward( # the extra attention if extra_attention_mask is not None: - selected_k = k.new_zeros( - bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim - ) + selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k)) - selected_attn_weights[ - selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1] - ] = -10000 + selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000 # concat to attn_weights # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) if position_bias is None and self.has_relative_attention_bias: window_relative_position = torch.arange( - -self.attention_window, - self.attention_window + 1, - 1, - dtype=torch.long, - device=attn_weights.device, + -self.attention_window, self.attention_window + 1, 1, dtype=torch.long, device=attn_weights.device, ) window_position_bias = ( self.relative_attention_bias( - relative_position_bucket( - window_relative_position, num_buckets=self.relative_attention_num_buckets - ) + relative_position_bucket(window_relative_position, num_buckets=self.relative_attention_num_buckets) ) .permute(1, 0)[None, None, :, :] .repeat(bsz, seq_len, 1, 1) @@ -270,9 +236,7 @@ def forward( selected_global_memory_position = extra_attention_mask_nonzeros[1][ :, None ] # (sum num_extra_indices_per_batch, 1) - selected_global_query_position = torch.arange( - seq_len, dtype=torch.long, device=attn_weights.device - )[ + selected_global_query_position = torch.arange(seq_len, dtype=torch.long, device=attn_weights.device)[ None, : ] # (1, seq_len) selected_global_relative_position = ( @@ -286,13 +250,7 @@ def forward( ] = selected_global_position_bias # (bsz, max_num_extra_indices_per_batch, seq_len) global_position_bias = perm_global_position_bias.permute(0, 2, 3, 1) # (bsz, seq_len, num_heads, 
max_num_extra_indices_per_batch) - position_bias = torch.cat( - ( - global_position_bias, - window_position_bias, - ), - dim=-1, - ) + position_bias = torch.cat((global_position_bias, window_position_bias,), dim=-1,) # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1) else: position_bias = window_position_bias @@ -300,43 +258,32 @@ def forward( if position_bias is not None: attn_weights += position_bias - attn_weights_float = F.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ) # use fp32 for numerical stability + attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability if key_padding_mask is not None: # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_weights_float = torch.masked_fill( attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0 ) attn_weights = attn_weights_float.type_as(attn_weights) - attn_probs = F.dropout( - attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training - ) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) attn = 0 if extra_attention_mask is not None: selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) - selected_v = v.new_zeros( - bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim - ) + selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) attn = torch.matmul( - selected_attn_probs.transpose(1, 2), - selected_v.transpose(1, 2).type_as(selected_attn_probs), + selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs), ).transpose(1, 2) attn_probs = attn_probs.narrow( - -1, - max_num_extra_indices_per_batch, - attn_probs.size(-1) - max_num_extra_indices_per_batch, + -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch, ).contiguous() if self.attention_mode == "tvm": v = v.float().contiguous() - attn += diagonaled_mm_tvm( - attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False - ) + attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) elif self.attention_mode == "sliding_chunks": attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) elif self.attention_mode == "sliding_chunks_no_overlap": @@ -351,9 +298,7 @@ def forward( # For this case, we'll just recompute the attention for these indices # and overwrite the attn tensor. 
TODO: remove the redundant computation if extra_attention_mask is not None: - selected_hidden_states = hidden_states.new_zeros( - max_num_extra_indices_per_batch, bsz, embed_dim - ) + selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim) selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[ extra_attention_mask_nonzeros[::-1] ] @@ -382,26 +327,15 @@ def forward( seq_len, ] - attn_weights = attn_weights.view( - bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len - ) - attn_weights[ - selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], : - ] = -10000.0 + attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) + attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 if key_padding_mask is not None: - attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - -10000.0, - ) - attn_weights = attn_weights.view( - bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len - ) + attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) + attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len) attn_weights_float = F.softmax( attn_weights, dim=-1, dtype=torch.float32 ) # use fp32 for numerical stability - attn_probs = F.dropout( - attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training - ) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) selected_attn = torch.bmm(attn_probs, v) assert list(selected_attn.size()) == [ bsz * self.num_heads, @@ -409,9 +343,7 @@ def forward( self.head_dim, ] - selected_attn_4d = selected_attn.view( - bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim - ) + selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim) nonzero_selected_attn = selected_attn_4d[ selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1] ] @@ -428,9 +360,7 @@ def forward( # It doesn't not return local attention # In case of variable number of global attantion in the rows of a batch, # attn_weights are padded with -10000.0 attention scores - attn_weights = attn_weights.view( - bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len - ) + attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) else: # without global attention, return local attention probabilities # batch_size x num_heads x sequence_length x window_size @@ -446,9 +376,7 @@ def forward( return outputs -def relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 -): +def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """ Imported from Huggingface transformers https://github.com/huggingface/transformers/blob/a0a027c2ed53b324cf4d0179ceec88d4ff414d47/src/transformers/models/t5/modeling_t5.py#L344 From 4c8864e3b65005f34cffb9c682a265e7ac057d1b Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 10:56:26 -0400 Subject: [PATCH 08/10] clean up --- longformer/longformer_encoder_decoder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py index 272aa6e..b36e9af 100644 --- a/longformer/longformer_encoder_decoder.py +++ b/longformer/longformer_encoder_decoder.py @@ 
-150,11 +150,9 @@ def forward( use_cache=False, output_attentions=False, ): - outputs = self.longformer_self_attn( input, attention_mask=mask, position_bias=position_bias, output_attentions=output_attentions, ) - outputs = (self.output(outputs[0]), None) + outputs[1:] return outputs From 7b635a2c859fa2cc41c9191e8a682001f7dc7fcd Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Wed, 24 Mar 2021 12:05:43 -0400 Subject: [PATCH 09/10] add missing annotation --- longformer/longformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index 5711416..fc58cbd 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -221,7 +221,7 @@ def forward( if position_bias is None and self.has_relative_attention_bias: window_relative_position = torch.arange( -self.attention_window, self.attention_window + 1, 1, dtype=torch.long, device=attn_weights.device, - ) + ) # (2*window+1,) window_position_bias = ( self.relative_attention_bias( relative_position_bucket(window_relative_position, num_buckets=self.relative_attention_num_buckets) From ff939af96808cb37f731fe5bd45dca96e7f77cdc Mon Sep 17 00:00:00 2001 From: Haokun Liu Date: Mon, 24 May 2021 21:03:32 -0400 Subject: [PATCH 10/10] update --- longformer/__init__.py | 4 +- longformer/longformer.py | 62 +++-- longformer/longformer_encoder_decoder.py | 39 ++- scripts/cheatsheet.txt | 14 + .../convert_t5_to_longformerencoderdecoder.py | 42 ++- scripts/summarization.py | 257 ++++++++++-------- scripts/temp.py | 53 ++++ tests/test_t5_short_sequence.py | 11 +- 8 files changed, 349 insertions(+), 133 deletions(-) create mode 100644 scripts/temp.py diff --git a/longformer/__init__.py b/longformer/__init__.py index d3e343c..47d5319 100644 --- a/longformer/__init__.py +++ b/longformer/__init__.py @@ -1,3 +1,5 @@ from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig -from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration \ No newline at end of file +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration +from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration +from longformer.longformer_encoder_decoder import LongformerT5Config diff --git a/longformer/longformer.py b/longformer/longformer.py index fc58cbd..09ca674 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -92,7 +92,7 @@ def __init__(self, config, layer_id, bias=True, attention_dim_scale=True): if hasattr(config, "relative_attention_num_buckets") and layer_id == 0: self.has_relative_attention_bias = True - self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_num_buckets = config.long_relative_attention_num_buckets self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) else: self.has_relative_attention_bias = False @@ -224,12 +224,19 @@ def forward( ) # (2*window+1,) window_position_bias = ( self.relative_attention_bias( - relative_position_bucket(window_relative_position, num_buckets=self.relative_attention_num_buckets) + relative_position_bucket( + window_relative_position, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.attention_window, + ) ) .permute(1, 0)[None, None, :, :] .repeat(bsz, seq_len, 1, 1) ) # (bsz, seq_len, num_heads, 2*window+1) - perm_global_position_bias = 
attn_weights.new_zeros( + perm_global_position_bias_from_g = attn_weights.new_zeros( + bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads + ) # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads) + perm_global_position_bias_to_g = attn_weights.new_zeros( bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads ) # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads) if extra_attention_mask is not None: @@ -242,21 +249,43 @@ def forward( selected_global_relative_position = ( selected_global_memory_position - selected_global_query_position ) # (sum num_extra_indices_per_batch, seq_len) - selected_global_position_bias = self.relative_attention_bias( - relative_position_bucket(selected_global_relative_position) + selected_global_position_bias_from_g = self.relative_attention_bias( + relative_position_bucket( + selected_global_relative_position, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.attention_window, + ) + ) # (sum num_extra_indices_per_batch, seq_len, num_heads) + perm_global_position_bias_from_g[ + selection_padding_mask_nonzeros + ] = selected_global_position_bias_from_g # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads) + selected_global_position_bias_to_g = self.relative_attention_bias( + relative_position_bucket( + -selected_global_relative_position, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.attention_window, + ) ) # (sum num_extra_indices_per_batch, seq_len, num_heads) - perm_global_position_bias[ + perm_global_position_bias_to_g[ selection_padding_mask_nonzeros - ] = selected_global_position_bias # (bsz, max_num_extra_indices_per_batch, seq_len) - global_position_bias = perm_global_position_bias.permute(0, 2, 3, 1) + ] = selected_global_position_bias_to_g # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads) + + global_position_bias_from_g = perm_global_position_bias_from_g.permute(0, 2, 3, 1) # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) - position_bias = torch.cat((global_position_bias, window_position_bias,), dim=-1,) - # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1) + global_position_bias_to_g = perm_global_position_bias_to_g.permute(0, 3, 1, 2) + # (bsz, num_heads, max_num_extra_indices_per_batch, seq_len) + + position_bias = { + "window": torch.cat((global_position_bias_from_g, window_position_bias,), dim=-1,), + "global": global_position_bias_to_g, + } + # window: (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1), + # global: (bsz, num_heads, max_num_extra_indices_per_batch, seq_len) else: - position_bias = window_position_bias + position_bias = {"window": window_position_bias} if position_bias is not None: - attn_weights += position_bias + attn_weights += position_bias["window"] attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability if key_padding_mask is not None: @@ -326,8 +355,9 @@ def forward( max_num_extra_indices_per_batch, seq_len, ] - attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) + if position_bias is not None: + attn_weights += position_bias["global"] attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 if key_padding_mask is not None: attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -10000.0,) @@ -376,9 +406,9 @@ def forward( return outputs -def relative_position_bucket(relative_position, bidirectional=True, 
num_buckets=32, max_distance=128): +def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_exact=16, max_distance=128): """ - Imported from Huggingface transformers + Imported from Huggingface transformers, with some modification https://github.com/huggingface/transformers/blob/a0a027c2ed53b324cf4d0179ceec88d4ff414d47/src/transformers/models/t5/modeling_t5.py#L344 Original description below: Adapted from Mesh Tensorflow: @@ -400,6 +430,7 @@ def relative_position_bucket(relative_position, bidirectional=True, num_buckets= relative_buckets = 0 if bidirectional: num_buckets //= 2 + max_exact //= 2 relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: @@ -407,7 +438,6 @@ def relative_position_bucket(relative_position, bidirectional=True, num_buckets= # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py index b36e9af..cdb68c9 100644 --- a/longformer/longformer_encoder_decoder.py +++ b/longformer/longformer_encoder_decoder.py @@ -1,8 +1,9 @@ +import copy from typing import List, Optional, Tuple, Dict from torch import nn, Tensor from longformer.longformer import LongformerSelfAttention from transformers.modeling_bart import BartConfig, BartForConditionalGeneration -from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration +from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration, T5Stack class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration): @@ -92,6 +93,40 @@ def __init__(self, config): for i, layer in enumerate(self.encoder.block): layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i) + class LongformerT5DecoderStack(T5Stack): + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + past_key_value_states=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.clamp(max=1) + return T5Stack.forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_value_states=past_key_value_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + self.decoder.__class__ = LongformerT5DecoderStack + class LongformerT5Config(T5Config): def __init__( @@ -101,6 +136,7 @@ def __init__( autoregressive: bool = False, attention_mode: str = "sliding_chunks", gradient_checkpointing: bool = False, + long_relative_attention_num_buckets: int = 40, **kwargs ): """ @@ -121,6 +157,7 @@ def __init__( self.autoregressive = autoregressive self.attention_mode = attention_mode self.gradient_checkpointing = gradient_checkpointing + self.long_relative_attention_num_buckets = long_relative_attention_num_buckets assert self.attention_mode in ["tvm", "sliding_chunks", "n2"] diff --git a/scripts/cheatsheet.txt 
b/scripts/cheatsheet.txt index 34ba1ec..d6be311 100644 --- a/scripts/cheatsheet.txt +++ b/scripts/cheatsheet.txt @@ -88,3 +88,17 @@ source /anaconda3/bin/activate torch-xla-nightly # Resume training python scripts/summarization.py --num_workers 12 --save_prefix eval_long16k_nooverlap_large --model_path bart-large-long-16384/ --max_input_len 16368 --batch_size 2 --grad_accum 4 --grad_ckpt --attention_mode sliding_chunks_no_overlap --attention_window 340 --val_every 0.333333333 --debug --resume summarization/run_long16k_nooverlap_large/_ckpt_epoch_3_v1.ckpt --val_percent_check 1.0 --disable_checkpointing + +# Convert model +python scripts/convert_t5_to_longformerencoderdecoder.py --save_model_to /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-base-16384 --base_model t5-base +python scripts/convert_t5_to_longformerencoderdecoder.py --save_model_to /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-16384 --base_model t5-small + + +srun --gpus=4 --nodes=1 python scripts/summarization.py --num_workers 12 \ + --save_dir /net/nfs2.s2-research/haokunl/exp_files/summarization \ + --save_prefix longt5-base-16k \ + --model_path /net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-base-16384 \ + --adafactor --tokenizer t5-base \ + --max_input_len 16384 --batch_size 2 --grad_accum 4 --grad_ckpt \ + --attention_mode sliding_chunks --attention_window 512 \ + --val_every 0.333333333 --debug --val_percent_check 1.0 --disable_checkpointing \ No newline at end of file diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py index 530b44c..9f46366 100644 --- a/scripts/convert_t5_to_longformerencoderdecoder.py +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -2,7 +2,7 @@ import logging import os import copy - +import torch from transformers import T5Tokenizer, T5ForConditionalGeneration from longformer.longformer_encoder_decoder import ( LongformerSelfAttentionForT5, @@ -14,7 +14,7 @@ logging.basicConfig(level=logging.INFO) -def create_long_model(save_model_to, base_model, attention_window, max_pos): +def create_long_model(save_model_to, base_model, attention_window, max_pos, relative_attention_num_buckets): # load base model & tokenizer tokenizer = T5Tokenizer.from_pretrained(base_model, model_max_length=max_pos) model = T5ForConditionalGeneration.from_pretrained(base_model) @@ -30,10 +30,17 @@ def create_long_model(save_model_to, base_model, attention_window, max_pos): config.attention_probs_dropout_prob = config.dropout_rate config.attention_window = [attention_window] * config.num_hidden_layers config.attention_dilation = [1] * config.num_hidden_layers + config.long_relative_attention_num_buckets = relative_attention_num_buckets # modify config in model # HF T5 includes multiple pointers to the config object - model.config = model.encoder.config = model.decoder.config = config + model.config = copy.deepcopy(config) + model.encoder.config = copy.deepcopy(config) + model.encoder.config.use_cache = False + model.encoder.config.is_encoder_decoder = False + model.decoder.config = copy.deepcopy(config) + model.decoder.config.is_decoder = True + model.decoder.config.is_encoder_decoder = False # modify tokenizer tokenizer.model_max_length = max_pos @@ -56,15 +63,29 @@ def create_long_model(save_model_to, base_model, attention_window, max_pos): longformer_self_attn_for_t5.output = self_attn.o if i == 0: - longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias = 
self_attn.relative_attention_bias + half_num_buckets = config.long_relative_attention_num_buckets // 2 + half_t5_buckets = 16 + with torch.no_grad(): + longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[ + :half_num_buckets + ] = self_attn.relative_attention_bias.weight[half_t5_buckets - 1] + longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[ + half_num_buckets: + ] = self_attn.relative_attention_bias.weight[-1] + longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[ + :half_t5_buckets + ] = self_attn.relative_attention_bias.weight[:half_t5_buckets] + longformer_self_attn_for_t5.longformer_self_attn.relative_attention_bias.weight[ + half_num_buckets + 1 : half_num_buckets + half_t5_buckets + ] = self_attn.relative_attention_bias.weight[half_t5_buckets + 1 :] layer.layer[0].SelfAttention = longformer_self_attn_for_t5 # save modified model logger.info(f"saving model to {save_model_to}") model.save_pretrained(save_model_to) - tokenizer.save_pretrained(save_model_to) config.save_pretrained(save_model_to) + tokenizer.save_pretrained(save_model_to) return @@ -83,28 +104,35 @@ def main(): help="attention window size for longformer self attention (one sided)", ) parser.add_argument("--max_pos", type=int, default=4096 * 4, help="maximum encoder positions") + parser.add_argument("--num_pos_buckets", type=int, default=40, help="number of relative position buckets") args = parser.parse_args() if not os.path.exists(args.save_model_to): - os.mkdir(args.save_model_to) + os.makedirs(args.save_model_to) create_long_model( save_model_to=args.save_model_to, base_model=args.base_model, attention_window=args.attention_window, max_pos=args.max_pos, + relative_attention_num_buckets=args.num_pos_buckets, ) tokenizer = T5Tokenizer.from_pretrained(args.save_model_to) + # tokenizer = T5Tokenizer.from_pretrained(args.base_model) model = LongformerT5ForConditionalGeneration.from_pretrained(args.save_model_to) + # model = T5ForConditionalGeneration.from_pretrained(args.base_model) + model.eval() model.config.gradient_checkpointing = True + model.encoder.config.gradient_checkpointing = True + model.decoder.config.gradient_checkpointing = True print("Converted model architecture") print(model) TXT = "A rose is a rose is a" - data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=2048) + data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=args.max_pos) input_ids = data["input_ids"] attention_mask = data["attention_mask"] attention_mask[0, 0:4:2] = 2 diff --git a/scripts/summarization.py b/scripts/summarization.py index 9bf6161..7de8eb6 100644 --- a/scripts/summarization.py +++ b/scripts/summarization.py @@ -17,6 +17,7 @@ from longformer import LongformerEncoderDecoderForConditionalGeneration, LongformerEncoderDecoderConfig +from longformer import LongformerT5ForConditionalGeneration, LongformerT5Config from longformer.sliding_chunks import pad_to_window_size @@ -44,7 +45,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): class SummarizationDataset(Dataset): - def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len): + def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len, task_prefix=""): + self.task_prefix = task_prefix self.hf_dataset = hf_dataset self.tokenizer = tokenizer self.max_input_len = max_input_len @@ -55,8 +57,10 @@ def __len__(self): def __getitem__(self, idx): entry = self.hf_dataset[idx] - input_ids = 
self.tokenizer.encode(entry['article'], truncation=True, max_length=self.max_input_len) - output_ids = self.tokenizer.encode(entry['abstract'], truncation=True, max_length=self.max_output_len) + input_ids = self.tokenizer.encode( + self.task_prefix + entry["article"], truncation=True, max_length=self.max_input_len + ) + output_ids = self.tokenizer.encode(entry["abstract"], truncation=True, max_length=self.max_output_len) if self.tokenizer.bos_token_id is None: # pegasus output_ids = [self.tokenizer.pad_token_id] + output_ids return torch.tensor(input_ids), torch.tensor(output_ids) @@ -78,54 +82,64 @@ def collate_fn(batch): class Summarizer(pl.LightningModule): - def __init__(self, params): super().__init__() self.args = params self.hparams = params self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer, use_fast=True) - if 'long' in self.args.model_path: - config = LongformerEncoderDecoderConfig.from_pretrained(self.args.model_path) - config.attention_dropout = self.args.attention_dropout + if "long" in self.args.model_path: + if "bart" in self.args.model_path: + model_class, config_class = ( + LongformerEncoderDecoderForConditionalGeneration, + LongformerEncoderDecoderConfig, + ) + elif "t5" in self.args.model_path: + model_class, config_class = LongformerT5ForConditionalGeneration, LongformerT5Config + config = config_class.from_pretrained(self.args.model_path) + # config.attention_dropout = self.args.attention_dropout config.gradient_checkpointing = self.args.grad_ckpt config.attention_mode = self.args.attention_mode - config.attention_window = [self.args.attention_window] * config.encoder_layers - self.model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained( - self.args.model_path, config=config) + # config.attention_window = [self.args.attention_window] * config.encoder_layers + self.model = model_class.from_pretrained(self.args.model_path, config=config) else: config = AutoConfig.from_pretrained(self.args.model_path) config.attention_dropout = self.args.attention_dropout - self.model = AutoModelForSeq2SeqLM.from_pretrained( - self.args.model_path, config=config) + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.args.model_path, config=config) self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None def _prepare_input(self, input_ids): attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) attention_mask[input_ids == self.tokenizer.pad_token_id] = 0 - if isinstance(self.model, LongformerEncoderDecoderForConditionalGeneration): - attention_mask[:, 0] = 2 # global attention on one token for all model params to be used, which is important for gradient checkpointing to work - if self.args.attention_mode == 'sliding_chunks': + if isinstance( + self.model, (LongformerEncoderDecoderForConditionalGeneration, LongformerT5ForConditionalGeneration) + ): + attention_mask[ + :, 0 + ] = 2 # global attention on one token for all model params to be used, which is important for gradient checkpointing to work + if self.args.attention_mode == "sliding_chunks": half_padding_mod = self.model.config.attention_window[0] - elif self.args.attention_mode == 'sliding_chunks_no_overlap': + elif self.args.attention_mode == "sliding_chunks_no_overlap": half_padding_mod = self.model.config.attention_window[0] / 2 else: raise NotImplementedError input_ids, attention_mask = pad_to_window_size( # ideally, should be moved inside the LongformerModel - input_ids, attention_mask, half_padding_mod, 
self.tokenizer.pad_token_id) + input_ids, attention_mask, half_padding_mod, self.tokenizer.pad_token_id + ) return input_ids, attention_mask def forward(self, input_ids, output_ids): input_ids, attention_mask = self._prepare_input(input_ids) decoder_input_ids = output_ids[:, :-1] - decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id) + decoder_attention_mask = decoder_input_ids != self.tokenizer.pad_token_id labels = output_ids[:, 1:].clone() outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - use_cache=False,) + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + use_cache=False, + ) lm_logits = outputs[0] if self.args.label_smoothing == 0: # Same behavior as modeling_bart.py, besides ignoring pad_token_id @@ -142,12 +156,15 @@ def forward(self, input_ids, output_ids): def training_step(self, batch, batch_nb): output = self.forward(*batch) loss = output[0] - lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr'] - tensorboard_logs = {'train_loss': loss, 'lr': lr, - 'input_size': batch[0].numel(), - 'output_size': batch[1].numel(), - 'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0} - return {'loss': loss, 'log': tensorboard_logs} + lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]["lr"] + tensorboard_logs = { + "train_loss": loss, + "lr": lr, + "input_size": batch[0].numel(), + "output_size": batch[1].numel(), + "mem": torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0, + } + return {"loss": loss, "log": tensorboard_logs} def validation_step(self, batch, batch_nb): for p in self.model.parameters(): @@ -157,35 +174,41 @@ def validation_step(self, batch, batch_nb): vloss = outputs[0] input_ids, output_ids = batch input_ids, attention_mask = self._prepare_input(input_ids) - generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, - use_cache=True, max_length=self.args.max_output_len, - num_beams=1) + generated_ids = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=True, + max_length=self.args.max_output_len, + num_beams=1, + ) generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True) gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True) - scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=False) + scorer = rouge_scorer.RougeScorer(rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=False) rouge1 = rouge2 = rougel = rougelsum = 0.0 for ref, pred in zip(gold_str, generated_str): score = scorer.score(ref, pred) - rouge1 += score['rouge1'].fmeasure - rouge2 += score['rouge2'].fmeasure - rougel += score['rougeL'].fmeasure - rougelsum += score['rougeLsum'].fmeasure + rouge1 += score["rouge1"].fmeasure + rouge2 += score["rouge2"].fmeasure + rougel += score["rougeL"].fmeasure + rougelsum += score["rougeLsum"].fmeasure rouge1 /= len(generated_str) rouge2 /= len(generated_str) rougel /= len(generated_str) rougelsum /= len(generated_str) - return {'vloss': vloss, - 'rouge1': vloss.new_zeros(1) + rouge1, - 'rouge2': vloss.new_zeros(1) + rouge2, - 'rougeL': vloss.new_zeros(1) + rougel, - 'rougeLsum': vloss.new_zeros(1) + rougelsum, } + return { + "vloss": vloss, + "rouge1": 
vloss.new_zeros(1) + rouge1, + "rouge2": vloss.new_zeros(1) + rouge2, + "rougeL": vloss.new_zeros(1) + rougel, + "rougeLsum": vloss.new_zeros(1) + rougelsum, + } def validation_epoch_end(self, outputs): for p in self.model.parameters(): p.requires_grad = True - names = ['vloss', 'rouge1', 'rouge2', 'rougeL', 'rougeLsum'] + names = ["vloss", "rouge1", "rouge2", "rougeL", "rougeLsum"] metrics = [] for name in names: metric = torch.stack([x[name] for x in outputs]).mean() @@ -195,7 +218,7 @@ def validation_epoch_end(self, outputs): metrics.append(metric) logs = dict(zip(*[names, metrics])) print(logs) - return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs} + return {"avg_val_loss": logs["vloss"], "log": logs, "progress_bar": logs} def test_step(self, batch, batch_nb): return self.validation_step(batch, batch_nb) @@ -221,72 +244,95 @@ def configure_optimizers(self): def _get_dataloader(self, current_dataloader, split_name, is_train): if current_dataloader is not None: return current_dataloader - dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer, - max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len) - sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None - return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None), - num_workers=self.args.num_workers, sampler=sampler, - collate_fn=SummarizationDataset.collate_fn) + dataset = SummarizationDataset( + task_prefix=("Summarize: " if "t5" in self.args.model_path else ""), + hf_dataset=self.hf_datasets[split_name], + tokenizer=self.tokenizer, + max_input_len=self.args.max_input_len, + max_output_len=self.args.max_output_len, + ) + sampler = ( + torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None + ) + return DataLoader( + dataset, + batch_size=self.args.batch_size, + shuffle=(sampler is None), + num_workers=self.args.num_workers, + sampler=sampler, + collate_fn=SummarizationDataset.collate_fn, + ) @pl.data_loader def train_dataloader(self): - self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True) + self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, "train", is_train=True) return self.train_dataloader_object @pl.data_loader def val_dataloader(self): - self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False) + self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, "validation", is_train=False) return self.val_dataloader_object @pl.data_loader def test_dataloader(self): - self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False) + self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, "test", is_train=False) return self.test_dataloader_object def configure_ddp(self, model, device_ids): - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=False - ) + model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=False) return model @staticmethod def add_model_specific_args(parser, root_dir): - parser.add_argument("--save_dir", type=str, default='summarization') - parser.add_argument("--save_prefix", type=str, default='test') + parser.add_argument("--save_dir", type=str, default="summarization") + 
parser.add_argument("--save_prefix", type=str, default="test") parser.add_argument("--batch_size", type=int, default=16, help="Batch size") parser.add_argument("--grad_accum", type=int, default=1, help="number of gradient accumulation steps") - parser.add_argument("--gpus", type=int, default=-1, - help="Number of gpus. 0 for CPU") + parser.add_argument("--gpus", type=int, default=-1, help="Number of gpus. 0 for CPU") parser.add_argument("--warmup", type=int, default=1000, help="Number of warmup steps") parser.add_argument("--lr", type=float, default=0.00003, help="Maximum learning rate") parser.add_argument("--val_every", type=float, default=1.0, help="Number of training steps between validations") - parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used') + parser.add_argument("--val_percent_check", default=1.00, type=float, help="Percent of validation data used") parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers") parser.add_argument("--seed", type=int, default=1234, help="Seed") parser.add_argument("--epochs", type=int, default=5, help="Number of epochs") - parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing") - parser.add_argument("--max_output_len", type=int, default=256, - help="maximum num of wordpieces/summary. Used for training and testing") - parser.add_argument("--max_input_len", type=int, default=512, - help="maximum num of wordpieces/summary. Used for training and testing") - parser.add_argument("--test", action='store_true', help="Test only, no training") - parser.add_argument("--model_path", type=str, default='facebook/bart-base', - help="Path to the checkpoint directory or model name") - parser.add_argument("--tokenizer", type=str, default='facebook/bart-base') - parser.add_argument("--no_progress_bar", action='store_true', help="no progress bar. Good for printing") - parser.add_argument("--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32") - parser.add_argument("--debug", action='store_true', help="debug run") + parser.add_argument("--disable_checkpointing", action="store_true", help="No logging or checkpointing") + parser.add_argument( + "--max_output_len", + type=int, + default=256, + help="maximum num of wordpieces/summary. Used for training and testing", + ) + parser.add_argument( + "--max_input_len", + type=int, + default=512, + help="maximum num of wordpieces/summary. Used for training and testing", + ) + parser.add_argument("--test", action="store_true", help="Test only, no training") + parser.add_argument( + "--model_path", + type=str, + default="facebook/bart-base", + help="Path to the checkpoint directory or model name", + ) + parser.add_argument("--tokenizer", type=str, default="facebook/bart-base") + parser.add_argument("--no_progress_bar", action="store_true", help="no progress bar. Good for printing") + parser.add_argument("--fp32", action="store_true", help="default is fp16. 
Use --fp32 to switch to fp32") + parser.add_argument("--debug", action="store_true", help="debug run") parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from") - parser.add_argument("--from_pretrained", type=str, default=None, - help="Path to a checkpoint to load model weights but not training state") - parser.add_argument('--grad_ckpt', action='store_true', help='Enable gradient checkpointing to save memory') + parser.add_argument( + "--from_pretrained", + type=str, + default=None, + help="Path to a checkpoint to load model weights but not training state", + ) + parser.add_argument("--grad_ckpt", action="store_true", help="Enable gradient checkpointing to save memory") parser.add_argument("--attention_dropout", type=float, default=0.1, help="attention dropout") - parser.add_argument("--attention_mode", type=str, default='sliding_chunks', help="Longformer attention mode") + parser.add_argument("--attention_mode", type=str, default="sliding_chunks", help="Longformer attention mode") parser.add_argument("--attention_window", type=int, default=512, help="Attention window") parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) - parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer") + parser.add_argument("--adafactor", action="store_true", help="Use adafactor optimizer") return parser @@ -303,45 +349,44 @@ def main(args): else: model = Summarizer(args) - model.hf_datasets = nlp.load_dataset('scientific_papers', 'arxiv') + model.hf_datasets = nlp.load_dataset("scientific_papers", "arxiv") - logger = TestTubeLogger( - save_dir=args.save_dir, - name=args.save_prefix, - version=0 # always use version=0 - ) + logger = TestTubeLogger(save_dir=args.save_dir, name=args.save_prefix, version=0) # always use version=0 checkpoint_callback = ModelCheckpoint( filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"), save_top_k=5, verbose=True, - monitor='avg_val_loss', - mode='min', + monitor="avg_val_loss", + mode="min", period=-1, - prefix='' + prefix="", ) print(args) args.dataset_size = 203037 # hardcode dataset size. 
Needed to compute number of steps for the lr scheduler - trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if torch.cuda.is_available() else None, - track_grad_norm=-1, - max_epochs=args.epochs if not args.debug else 100, - max_steps=None if not args.debug else 1, - replace_sampler_ddp=False, - accumulate_grad_batches=args.grad_accum, - val_check_interval=args.val_every if not args.debug else 1, - num_sanity_val_steps=2 if not args.debug else 0, - check_val_every_n_epoch=1 if not args.debug else 1, - val_percent_check=args.val_percent_check, - test_percent_check=args.val_percent_check, - logger=logger, - checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False, - show_progress_bar=not args.no_progress_bar, - use_amp=not args.fp32, amp_level='O2', - resume_from_checkpoint=args.resume_ckpt, - ) + trainer = pl.Trainer( + gpus=args.gpus, + distributed_backend="ddp" if torch.cuda.is_available() else None, + track_grad_norm=-1, + max_epochs=args.epochs if not args.debug else 100, + max_steps=None if not args.debug else 1, + replace_sampler_ddp=False, + accumulate_grad_batches=args.grad_accum, + val_check_interval=args.val_every if not args.debug else 1, + num_sanity_val_steps=2 if not args.debug else 0, + check_val_every_n_epoch=1 if not args.debug else 1, + val_percent_check=args.val_percent_check, + test_percent_check=args.val_percent_check, + logger=logger, + checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False, + show_progress_bar=not args.no_progress_bar, + use_amp=not args.fp32, + amp_level="O2", + resume_from_checkpoint=args.resume_ckpt, + ) if not args.test: trainer.fit(model) trainer.test(model) diff --git a/scripts/temp.py b/scripts/temp.py new file mode 100644 index 0000000..4c7cd3d --- /dev/null +++ b/scripts/temp.py @@ -0,0 +1,53 @@ +import torch +from transformers import T5Tokenizer, T5ForConditionalGeneration +from longformer.longformer_encoder_decoder import ( + LongformerSelfAttentionForT5, + LongformerT5Config, + LongformerT5ForConditionalGeneration, +) + + +tokenizer = T5Tokenizer.from_pretrained("t5-base") +# model = LongformerT5ForConditionalGeneration.from_pretrained( +# "/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-base-16384" +# ) +model = T5ForConditionalGeneration.from_pretrained("t5-base") +model.eval() +model.config.gradient_checkpointing = True +model.encoder.config.gradient_checkpointing = True +model.decoder.config.gradient_checkpointing = True + +TXT = """ +Mayor Bill de Blasio will change a rule that has, for months, created a paradox in New York City’s school reopening plan: Classrooms that had been reopened to students often closed again because school buildings had to shut temporarily whenever two unrelated virus cases were detected. +The mayor announced Monday that he would alter the rule, but he did not explain how. He said the will be outlined in the coming days, but did not commit to making changes this week. +The closure rule has been extremely frustrating for many parents, who have said that every day brings uncertainty about whether their children will be able to attend school the following morning. Many schools have closed multiple times and sometimes have been open for just a few days before the next closure. The rule has also been intensely disruptive for educators, who have been forced to toggle between in-person and online learning with only a few hours’ notice. 
+The controversy over the closure rule has highlighted the enormous difficulties and trade-offs inherent in reopening schools during the pandemic. Mayors and education leaders across the country have scrambled to find ways to return students to classrooms while experimenting with safety protocols in real time. +Closures have accelerated in recent weeks and months, as middle and high school students have returned to their buildings after months of all-remote learning. The vast majority of New York City students — roughly 700,000 out of 1 million — have chosen to learn remotely full time, which means the closure rule did not affect most families. +But the city is giving all families an opportunity to switch from remote learning to classroom instruction for the rest of the school year, so that number may shift. Some students will get full-time instruction, while others will go in a few days a week and learn from home the rest of the time, based on individual school capacity. Families have until the end of the day on Friday to switch. +In recent weeks, some epidemiologists and medical experts have told ProPublica and the education news site Chalkbeat that New York’s two-case rule was arbitrary and had led to unnecessary closures, and called on the mayor to adjust it. +“The way to beat Covid is not by closing schools excessively, but by suppressing transmission both inside and outside of schools,” Dr. Dave A. Chokshi, the city’s health commissioner, said during a news conference on Monday. +The city’s schools have had very low virus transmission in classrooms since they began to reopen last fall. Michael Mulgrew, president of the United Federation of Teachers, has strenuously opposed any changes to the rule for months, arguing that the city’s schools were safe only because of the strict , including the two-case threshold. +“We can’t just say because they’re an inconvenience we don’t want them,” Mr. Mulgrew said of the guidelines during a radio interview last month. +The closure rule was settled last summer during a period of intense turmoil between City Hall and the union, at a moment when it was unclear whether Mr. de Blasio would be able to reopen schools at all. The city and union eventually agreed on a host of safety rules that cleared a path for New York to become the first large school district in America to reopen schools for all grades. +Several of those rules have changed over the last eight months. The mayor said over the summer, when the average citywide test positivity rate was hovering under 1 percent, that the entire school system would shut if the positivity rate hit 3 percent, which it did in November. He closed the school system for several weeks but came under significant pressure from parents and experts to set a different threshold. +When Mr. de Blasio reopened schools for young children and some students with disabilities in December, he said there would no longer be a citywide positivity threshold to shut the school system. +The city is also poised to partially change a rule it set over the summer that mandated six feet of distance between students in classrooms. Last month, the Centers for Disease Control and Prevention said districts should consider to three feet, a standard that Mr. de Blasio said the city would adopt in elementary school classrooms later this month. +That shift rankled the teachers’ union, which has had significant influence over the school reopening process in . 
Though relations between City Hall and the union have been frosty for months, the mayor has tried to maintain some peace with Mr. Mulgrew. +For example, when the city reopened elementary schools late last year despite rising virus cases across the city, Mr. de Blasio announced increased random testing in school buildings, a consistent union priority that experts have supported. +But the city and union have struggled to find a compromise on the two-case rule. For weeks, Mr. de Blasio said a revision to the rule was imminent, but behind the scenes, negotiations between the two sides were stalling. The city and union still do not have an agreement on what the new closure threshold should be. +While the mayor has the power to unilaterally change the rule, City Hall has tried to avoid alienating the union with just a few months left in the school year. The U.F.T. raised many issues with the city over the reopening plan last summer, but it has been more willing to reopen schools than some other teachers’ unions in big cities, including Chicago and Los Angeles. +The fact that all grades of the school system are open means that the union has less leverage now than at any point in the school reopening process. But the union still has enormous influence over how the next school year will unfold. +Mr. de Blasio has said he expects full-time, in-person instruction come September, though it is likely that there will be a remote option for some families into the fall. That goal will rest in part on the union’s cooperation and support, and teachers will no doubt play a crucial role in reaching out to reluctant families and encouraging them to return to classrooms. +""" +LABELS_1 = " new rules safety measures reducing the distance New York " +LABELS_0 = " full-time in-person instruction closing schools City Hall " +data = tokenizer([TXT], return_tensors="pt", padding="max_length", max_length=4096) +input_ids = data["input_ids"] +attention_mask = data["attention_mask"] +labels_1 = tokenizer(LABELS_1, return_tensors="pt").input_ids +labels_0 = tokenizer(LABELS_0, return_tensors="pt").input_ids +with torch.no_grad(): + loss_1 = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, labels=labels_1)[0] + loss_0 = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True, labels=labels_0)[0] + +print("real", loss_1, "false", loss_0) diff --git a/tests/test_t5_short_sequence.py b/tests/test_t5_short_sequence.py index de4627d..c5cb917 100644 --- a/tests/test_t5_short_sequence.py +++ b/tests/test_t5_short_sequence.py @@ -9,7 +9,7 @@ class TestT5ShortSeq(unittest.TestCase): def _run_test(self, INPUT_TEXT, long_model_name_or_path, base_model_name_or_path): tokenizer = T5Tokenizer.from_pretrained(long_model_name_or_path) - model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(long_model_name_or_path) + model = LongformerT5ForConditionalGeneration.from_pretrained(long_model_name_or_path) model.eval() model.config.gradient_checkpointing = True base_model = T5ForConditionalGeneration.from_pretrained(base_model_name_or_path) @@ -20,15 +20,22 @@ def _run_test(self, INPUT_TEXT, long_model_name_or_path, base_model_name_or_path attention_mask = data["attention_mask"] decoder_input_ids = model._shift_right(input_ids[:, :5]) + attention_mask_mixed = data["attention_mask"] * torch.randint(1, 3, data["attention_mask"].size()) + # randomly set some tokens to global, this should not change the output of a short sequence + output = model(input_ids, 
attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False,)[ 0 ].float() + output_mixed = model( + input_ids, attention_mask=attention_mask_mixed, decoder_input_ids=decoder_input_ids, use_cache=False, + )[0].float() expected_output = base_model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False, )[0].float() atol = 1e-4 self.assertTrue(torch.allclose(output, expected_output, atol=atol)) + self.assertTrue(torch.allclose(output_mixed, expected_output, atol=atol)) def test_outout(self): self._run_test( @@ -37,7 +44,7 @@ def test_outout(self): base_model_name_or_path="t5-small", ) self._run_test( - INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkeness.", + INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkness.", long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096", base_model_name_or_path="t5-small", )
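
Note on the new assertion in tests/test_t5_short_sequence.py: multiplying the 0/1 attention mask by torch.randint(1, 3, ...) leaves padding at 0 and randomly marks each non-pad token as 1 (local attention) or 2 (global attention), the usual Longformer mask convention. For a sequence shorter than the attention window, promoting tokens to global attention should not change the output, which is what the added allclose check verifies against the base T5 model. A minimal standalone sketch of the same idea (the checkpoint path, tokenizer, and tolerance below are placeholders, not taken from this patch):

import torch
from transformers import T5Tokenizer
from longformer.longformer_encoder_decoder import LongformerT5ForConditionalGeneration

def local_vs_mixed_outputs(model, tokenizer, text):
    data = tokenizer([text], return_tensors="pt")
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]  # 0 = padding, 1 = local attention
    # Randomly promote non-pad tokens to global attention (value 2);
    # padding stays 0 because 0 * anything == 0.
    mixed_mask = attention_mask * torch.randint(1, 3, attention_mask.size())
    decoder_input_ids = model._shift_right(input_ids[:, :5])
    with torch.no_grad():
        out_local = model(input_ids, attention_mask=attention_mask,
                          decoder_input_ids=decoder_input_ids, use_cache=False)[0]
        out_mixed = model(input_ids, attention_mask=mixed_mask,
                          decoder_input_ids=decoder_input_ids, use_cache=False)[0]
    return out_local, out_mixed

# usage sketch (path is a placeholder for a converted long-T5 checkpoint):
# model = LongformerT5ForConditionalGeneration.from_pretrained("path/to/longt5-small-4096")
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# a, b = local_vs_mixed_outputs(model, tokenizer, "It begins with the Great Hungerer.")
# assert torch.allclose(a, b, atol=1e-4)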
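
Similarly, scripts/temp.py uses the seq2seq loss as a rough ranking signal: the phrases that were blanked out of the article (LABELS_1) should yield a lower loss than unrelated phrases lifted from elsewhere in the text (LABELS_0). A hedged helper along those lines, reusing the script's model, tokenizer, input_ids, and attention_mask (the score_candidates name is illustrative, not part of the patch):

import torch

def score_candidates(model, tokenizer, input_ids, attention_mask, candidates):
    # Lower loss means the model finds that completion more likely.
    scores = {}
    for text in candidates:
        labels = tokenizer(text, return_tensors="pt").input_ids
        with torch.no_grad():
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)[0]
        scores[text] = loss.item()
    return scores

# e.g. score_candidates(model, tokenizer, input_ids, attention_mask, [LABELS_1, LABELS_0])
# should give the "real" label string the smaller loss if the long-attention model tracks t5-base.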