diff --git a/demo/multiprocs.py b/demo/multiprocs.py
new file mode 100644
index 00000000..cef1572e
--- /dev/null
+++ b/demo/multiprocs.py
@@ -0,0 +1,69 @@
+import os
+import json
+import fire
+
+from loguru import logger
+
+from magic_pdf.pipe.UNIPipe import UNIPipe
+from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
+import concurrent.futures
+import time
+
+import magic_pdf.model as model_config
+model_config.__use_inside_model__ = True
+
+
+
+def process(pdf_list):
+    current_script_dir = os.path.dirname(os.path.abspath(__file__))
+    pid = os.getpid()
+    for index, pdf_path in enumerate(pdf_list):
+        if index == 1:
+            # Start timing only after the pipe has processed the first PDF,
+            # since model instances are created while handling that first one
+            # and creating them is very time consuming.
+            start = time.perf_counter()
+        try:
+            pdf_bytes = open(pdf_path, "rb").read()
+            model_json = []  # pass an empty list as model_json to parse with the built-in model
+            jso_useful_key = {"_pdf_type": "", "model_list": model_json}
+            local_image_dir = os.path.join(current_script_dir, str(pid), 'images')
+            image_dir = str(os.path.basename(local_image_dir))
+            image_writer = DiskReaderWriter(local_image_dir)
+            pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
+            pipe.pipe_classify()
+            """If no valid model data was passed in, parse with the built-in model."""
+            if len(model_json) == 0:
+                if model_config.__use_inside_model__:
+                    pipe.pipe_analyze()
+                else:
+                    logger.error("need model list input")
+                    exit(1)
+            pipe.pipe_parse()
+            md_content = pipe.pipe_mk_markdown(image_dir, drop_mode="none")
+            # TODO: post-process md_content
+        except Exception as e:
+            logger.exception(e)
+    end = time.perf_counter()
+    logger.warning(f'>>>>> {pid} spent {(end-start):.2f} s on processing {len(pdf_list) - 1} PDFs')
+
+
+def file_ext(file_path):
+    return os.path.splitext(file_path)[-1]
+
+
+def main(file_path: str, max_proc_num: int=4):
+    with open(file_path) as f:
+        file_paths = f.readlines()
+    pdf_file_list = [_path.strip() for _path in file_paths if file_ext(_path.strip()) == '.pdf']
+    proc_file_list = []
+    files_per_proc = (len(pdf_file_list) + max_proc_num - 1) // max_proc_num
+    for i in range(0, len(pdf_file_list), files_per_proc):
+        file_list_per_proc = pdf_file_list[i : i + files_per_proc]
+        proc_file_list.append(file_list_per_proc)
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        executor.map(process, proc_file_list)
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
\ No newline at end of file
diff --git a/magic_pdf/model/mfr_cudagraph.py b/magic_pdf/model/mfr_cudagraph.py
new file mode 100644
index 00000000..65134c97
--- /dev/null
+++ b/magic_pdf/model/mfr_cudagraph.py
@@ -0,0 +1,899 @@
+from typing import Optional, Tuple, Union
+import torch
+from torch import nn
+import os
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+import argparse
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+
+class PatchedMBartLearnedPositionalEmbedding(nn.Module):
+
+    def __init__(self, origin: nn.Module):
+        super().__init__()
+        self.offset = origin.offset
+        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
+        self.embedding.weight.data = origin.weight.data
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids` shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(0, seq_len, dtype=torch.long,
device=self.embedding.weight.device + ) + positions += past_key_values_length + positions = positions.expand(bsz, -1) + + return self.embedding(positions + self.offset) + + +class PatchedMBartDecoder(nn.Module): + def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): + super().__init__() + self.origin = origin + self.kvlen = kvlen + + self.config = origin.config + self.embed_tokens = origin.embed_tokens + self.embed_scale = origin.embed_scale + self._use_flash_attention_2 = origin._use_flash_attention_2 + self.embed_positions = origin.embed_positions + self.counting_context_weight = getattr(origin, 'counting_context_weight', None) + self.layernorm_embedding = origin.layernorm_embedding + self.layers = origin.layers + self.layer_norm = origin.layer_norm + + self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + run_origin = False + if past_key_values is None: + run_origin = True + elif past_key_values[0][0].size(-2) < attention_mask.size(-1): + run_origin = True + + if run_origin: + return self.origin( + input_ids=input_ids, + attention_mask=attention_mask, + count_pred=count_pred, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is 
not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.patched_embed_positions(input, self.kvlen) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + # TODO: add counting context weight to hidden_states + if count_pred is not None: + count_context_weight = self.counting_context_weight(count_pred) + hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1) + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class PatchedMBartAttention(nn.Module): + + def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): + super().__init__() + self.embed_dim = origin.embed_dim + self.num_heads = origin.num_heads + self.dropout = origin.dropout + self.head_dim = origin.head_dim + self.config = origin.config + + self.scaling = origin.scaling + self.is_decoder = origin.is_decoder + self.is_causal = origin.is_causal + + self.k_proj = origin.k_proj + self.v_proj = origin.v_proj + self.q_proj = origin.q_proj + self.out_proj = origin.out_proj + self.kvlen = kvlen + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # 
cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if past_key_value[0].size(-2) < attention_mask.size(-1): + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0][:, :, self.kvlen[None]] = key_states + past_key_value[1][:, :, self.kvlen[None]] = value_states + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = attn_weights + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + # attn_output = self.out_proj(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PatchedMBartSqueezeAttention(nn.Module): + + def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): + super().__init__() + self.embed_dim = origin.embed_dim + self.num_heads = origin.num_heads + self.dropout = origin.dropout + self.head_dim = origin.head_dim + self.squeeze_head_dim=origin.squeeze_head_dim + self.config = origin.config + + self.scaling = origin.scaling + self.is_decoder = origin.is_decoder + self.scaling = origin.scaling + + self.q_proj = origin.q_proj + self.k_proj = origin.k_proj + self.v_proj = origin.v_proj + self.out_proj = origin.out_proj + self.kvlen = kvlen + + def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous() + + def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) + + if past_key_value[0].size(-2) < attention_mask.size(-1): + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0][:, :, self.kvlen[None]] = key_states + past_key_value[1][:, :, self.kvlen[None]] = value_states + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + # self_attention + key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim) + value_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*value_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + +def patch_model(model: nn.Module, kvlen: torch.LongTensor): + for name, child in model.named_children(): + cls_name = type(child).__name__ + if cls_name == 'MBartAttention': + patched_child = PatchedMBartAttention(child, kvlen) + model.register_module(name, patched_child) + elif cls_name == 'MBartSqueezeAttention': + patched_child = PatchedMBartSqueezeAttention(child, kvlen) + model.register_module(name, patched_child) + else: + patch_model(child, kvlen) + + cls_name = type(model).__name__ + if cls_name == 'CustomMBartDecoder': + model = PatchedMBartDecoder(model, kvlen) + return model + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def get_graph_key(batch_size: int, kvlens: int): + batch_size = next_power_of_2(batch_size) + kvlens = next_power_of_2(kvlens) + + batch_size = max(8, batch_size) + kvlens = max(32, kvlens) + + return batch_size, kvlens + + +class GraphRunnerImpl: + + def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict): + self.model = model + self.graph = graph + self.input_buffers = input_buffers + self.output_buffers = output_buffers + + @staticmethod + def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int): + input_ids = input_buffers['input_ids'][:batch_size] + attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens] + encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size] + kvlen=input_buffers['kvlen'] + + past_key_values = [] + for past_key_value in input_buffers['past_key_values']: + k0 = past_key_value[0][:batch_size, :, :kvlens] + v0 = past_key_value[1][:batch_size, :, :kvlens] + k1 = past_key_value[2][:batch_size] + v1 = past_key_value[3][:batch_size] + past_key_values.append((k0, v0, k1, v1)) + + input_buffers = dict( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + kvlen=kvlen, + ) + return input_buffers + + @staticmethod + def fill_input_buffers( + input_buffer: dict, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + ): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + + input_buffer['input_ids'][:batch_size] = input_ids + + if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr(): + input_buffer['attention_mask'].fill_(0) + input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask + input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states + + if past_key_values is not None: + for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values): + idx = 0 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx] + idx = 1 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx] + + idx = 2 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size] = kv[idx] + idx = 3 + if buf_kv[idx].data_ptr() != 
kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size] = kv[idx] + + input_buffer['kvlen'].fill_(kvlens - 1) + + @classmethod + @torch.inference_mode() + def capture(cls, + model: nn.Module, + input_buffers: dict, + pool, + warmup: bool = False, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None,): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + + graph_key = get_graph_key(batch_size, kvlens) + batch_size = graph_key[0] + kvlens = graph_key[1] + + input_buffers = cls.extract_input_buffers(input_buffers, + batch_size=batch_size, + kvlens=kvlens) + cls.fill_input_buffers(input_buffers, + input_ids, + attention_mask, + encoder_hidden_states, + past_key_values) + + input_ids = input_buffers['input_ids'] + attention_mask = input_buffers['attention_mask'] + encoder_hidden_states = input_buffers['encoder_hidden_states'] + past_key_values = input_buffers['past_key_values'] + + if warmup: + # warmup + model( + input_ids=input_ids, + attention_mask=attention_mask, + count_pred=count_pred, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + pool=pool): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + count_pred=count_pred, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + output_buffers = dict( + last_hidden_state=outputs['last_hidden_state'], + past_key_values=outputs['past_key_values'], + ) + + return GraphRunnerImpl(model, graph, input_buffers, output_buffers) + + def __call__(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + self.fill_input_buffers(self.input_buffers, + input_ids, + attention_mask, + 
encoder_hidden_states, + past_key_values) + + self.graph.replay() + + last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size] + + past_key_values = [] + for past_key_value in self.output_buffers['past_key_values']: + k0 = past_key_value[0][:batch_size, :, :kvlens] + v0 = past_key_value[1][:batch_size, :, :kvlens] + k1 = past_key_value[2][:batch_size] + v1 = past_key_value[3][:batch_size] + past_key_values.append((k0, v0, k1, v1)) + + outputs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + past_key_values=past_key_values, + ) + return outputs + +class GraphRunner(nn.Module): + + def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'): + super().__init__() + + self.kvlen = torch.tensor(0, dtype=torch.long, device=device) + model = patch_model(model.to(dtype), self.kvlen) + self.model = model + self.max_batchs = max_batchs + self.max_kvlens = max_kvlens + self.device = device + + self.input_buffers = None + + self.impl_map = dict() + self.graph_pool_handle = torch.cuda.graph_pool_handle() + self.warmuped = False + + def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype): + max_batchs = self.max_batchs + max_kvlens = self.max_kvlens + device = self.device + config = self.model.config + + d_model = config.d_model + decoder_layers = config.decoder_layers + num_heads = config.decoder_attention_heads + + head_dim = d_model // num_heads + self_attn = self.model.layers[0].self_attn + qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim) + + input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device) + attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device) + encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device) + + past_key_values = [] + for _ in range(decoder_layers): + k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device) + v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device) + k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device) + v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device) + + past_key_values.append((k0, v0, k1, v1)) + + self.input_buffers = dict( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + kvlen=self.kvlen + ) + + @torch.inference_mode() + def forward(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size, qlens = input_ids.size() + kvlens = attention_mask.size(1) + + eager_mode = False + + if qlens != 1: + eager_mode = True + + if past_key_values is None: + eager_mode = True + else: + for past_key_value in past_key_values: + if past_key_value is None: + eager_mode = True + 
+                    break
+
+        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
+            eager_mode = True
+
+        if eager_mode:
+            return self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,)
+
+        # create buffers if they do not exist yet.
+        if self.input_buffers is None:
+            encoder_kvlens = encoder_hidden_states.size(1)
+            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
+
+        graph_key = get_graph_key(batch_size, kvlens)
+        if graph_key not in self.impl_map:
+            warmup = False
+            if not self.warmuped:
+                warmup = True
+                self.warmuped = True
+            impl = GraphRunnerImpl.capture(
+                self.model,
+                self.input_buffers,
+                self.graph_pool_handle,
+                warmup=warmup,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            self.impl_map[graph_key] = impl
+        impl = self.impl_map[graph_key]
+
+        ret = impl(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            count_pred=count_pred,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return ret
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index 6c5b9d18..d0f5a621 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -3,6 +3,7 @@
 import time
 
 from magic_pdf.libs.Constants import *
+from .mfr_cudagraph import GraphRunner
 from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -63,8 +64,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model.to(_device_)
-    model.eval()
+    model = model.to(_device_)
+    model.eval()
+    if 'cuda' in _device_:
+        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256, device=_device_)
+        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
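For context, one possible way to drive demo/multiprocs.py: the script reads a text file with one PDF path per line, splits the list evenly across worker processes, and fire.Fire(main) exposes main(file_path, max_proc_num=4) on the command line. The sketch below is a hypothetical driver; "pdf_list.txt" and the PDF paths are placeholders, not files shipped with this change.

# Hypothetical driver for demo/multiprocs.py; list file and PDF paths are assumptions.
import subprocess

pdf_paths = ["/data/pdfs/paper_a.pdf", "/data/pdfs/paper_b.pdf"]
with open("pdf_list.txt", "w") as f:
    f.write("\n".join(pdf_paths) + "\n")

# Equivalent to: python demo/multiprocs.py pdf_list.txt --max_proc_num 4
subprocess.run(["python", "demo/multiprocs.py", "pdf_list.txt", "--max_proc_num", "4"], check=True)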
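The GraphRunner added above follows the standard CUDA-graph pattern: inputs, outputs, and KV caches live in fixed pre-allocated buffers, one decoding step is captured per (batch size, kv-length) bucket returned by get_graph_key, and subsequent steps only copy fresh data into those buffers and replay the captured graph, falling back to eager execution for prefill steps and shapes larger than the buffers. A minimal, self-contained sketch of that capture/replay pattern follows; the Linear layer and tensor shapes are illustrative stand-ins, not the patched MBart decoder.

# Minimal CUDA-graph capture/replay sketch (toy module and shapes are assumptions).
import torch
from torch import nn

device = "cuda"
model = nn.Linear(64, 64).to(device).eval()

# Static input buffer: its storage must stay at the same address across replays.
static_x = torch.zeros(8, 64, device=device)

# Warm up on a side stream before capture, as recommended for CUDA graphs.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.inference_mode():
    model(static_x)
torch.cuda.current_stream().wait_stream(side)

# Capture one forward pass; the returned tensor becomes the static output buffer.
graph = torch.cuda.CUDAGraph()
with torch.inference_mode(), torch.cuda.graph(graph):
    static_y = model(static_x)

# Serving a new request: copy inputs into the captured buffer and replay the graph.
static_x.copy_(torch.randn(8, 64, device=device))
graph.replay()
result = static_y.clone()  # read results out of the reused output buffer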