Merge branch 'hpcaitech:main' into main
duanjunwen authored Jun 13, 2024
2 parents 9c7276e + 3bcbba9 commit b37ad1f
Showing 12 changed files with 194 additions and 20 deletions.
6 changes: 4 additions & 2 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,14 +43,17 @@ class CudaAttentionBackend(AttentionBackend):
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""

def __init__(self, use_flash_attn: bool):
def __init__(self, use_flash_attn: bool = False):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
token_nums = kwargs.get("token_nums", -1)

from flash_attn import flash_attn_varlen_func

attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
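
Note: this hunk moves the `flash_attn` import from module scope into `prefill`, so importing the backend no longer hard-requires flash-attn. Below is a minimal sketch of that deferred-import pattern, assuming flash-attn 2 is installed whenever `use_flash_attn=True`; it is not the actual `CudaAttentionBackend`.

```python
import torch


def prefill_attention(q, k, v, cu_seqlens, max_seqlen, use_flash_attn: bool = False):
    """Sketch: import flash-attn only when the flash path is actually taken."""
    if use_flash_attn:
        # The import now fails at first use instead of at module import time,
        # so environments without flash-attn can still load the backend.
        from flash_attn import flash_attn_varlen_func

        # q/k/v are packed as (total_tokens, num_heads, head_dim);
        # cu_seqlens holds int32 per-sequence offsets into the packed tokens.
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            causal=True,
        )
    # In the real backend, the Triton context_attention_unpadded kernel
    # handles this branch instead.
    raise NotImplementedError("non-flash prefill path is not sketched here")
```
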
2 changes: 0 additions & 2 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -200,8 +200,6 @@ def forward(

self.pre_attention_backend.decode(
attn_metadata,
cos=cos_sin[0],
sin=cos_sin[1],
q_len=q_len,
)
attn_output = self.attention_backend.decode(
4 changes: 2 additions & 2 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -114,7 +114,7 @@ def llama_model_forward(

elif use_cuda_kernel:
if can_use_flash_attn2(inputmetadata.dtype):
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

hidden_dim = self._cos_cached.size(-1)
total_length = hidden_states.size(0)
@@ -265,7 +265,7 @@ def __init__(
mlp_dproj: ParallelModule = None,
process_group: ProcessGroup = None,
):
"""A Unified Layer for
"""Replacement of LlamaMLP layer.
Args:
config (LlamaConfig): Holding the Llama model config.
2 changes: 2 additions & 0 deletions colossalai/inference/utils.py
@@ -152,6 +152,8 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
return False

try:
from flash_attn import flash_attn_varlen_func # noqa

return True
except ImportError:
        logger.warning("flash_attn2 is not installed; the Triton flash attention kernel will be used instead.")
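
For context, `can_use_flash_attn2` acts as a capability probe that callers branch on (as `llama_model_forward` does above). A rough standalone equivalent of the probe is sketched below; the fp16/bf16 dtype restriction is an assumption here, not copied from the source.

```python
import torch


def flash_attn2_available(dtype: torch.dtype) -> bool:
    """Sketch of the probe pattern: cheap dtype check first, then an import check."""
    if dtype not in (torch.float16, torch.bfloat16):  # assumed dtype restriction
        return False
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        return True
    except ImportError:
        # flash-attn 2 is not installed; the caller falls back to the Triton kernel.
        return False
```
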
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/attn.py
@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return max_seqlen_in_batch, cu_seqlens, indices


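
This hunk fixes the stuttered `torch.torch.int32` spelling; the same cleanup appears in `nopadding_llama.py` above and in the kv-cache memcpy test below. For reference, here is a small worked example of what `get_pad_info` computes from a padding mask (standalone sketch, values chosen arbitrarily).

```python
import torch
import torch.nn.functional as F

# Padding mask for a batch of 3 sequences (1 = real token, 0 = padding).
padding_mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
])

seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)            # tensor([3, 2, 4])
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() # positions of real tokens
max_seqlen_in_batch = seqlens_in_batch.max().item()                       # 4
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5, 9], dtype=torch.int32) -- per-sequence offsets into the packed tokens
```
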
2 changes: 2 additions & 0 deletions colossalai/zero/gemini/chunk/manager.py
@@ -25,6 +25,7 @@ def __init__(
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ def __init__(
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

def register_tensor(
self,
7 changes: 2 additions & 5 deletions colossalai/zero/gemini/chunk/utils.py
@@ -21,6 +21,7 @@ def init_chunk_manager(
hidden_dim: Optional[int] = None,
reuse_fp16_chunk: bool = True,
verbose: bool = False,
max_prefetch: int = 0,
**kwargs,
) -> ChunkManager:
if hidden_dim:
@@ -51,9 +52,5 @@
)
dist.barrier()

chunk_manager = ChunkManager(
config_dict,
init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
return chunk_manager
5 changes: 2 additions & 3 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -104,9 +104,7 @@ def __init__(
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
)
else:
# some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ def __init__(
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
max_prefetch=max_prefetch,
)
self.gemini_manager = GeminiManager(
placement_policy,
19 changes: 15 additions & 4 deletions colossalai/zero/gemini/gemini_hook.py
@@ -5,6 +5,7 @@

import torch

from colossalai.accelerator import get_accelerator
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
@@ -54,10 +55,20 @@ def pre_op(self, params):
)

# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
if self._gemini_manager.chunk_manager._prefetch_stream is not None:
            # When prefetch happens for the first time there is no dist.Work to wait on,
            # so the optimizer may not have finished its computation on the default stream yet
            # and we could otherwise prefetch outdated chunks.
            #
            # In every later call, self._gemini_manager.wait_chunks has already synced with the
            # default stream via dist.Work.wait(), so this line makes no difference.
self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream())

with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)

# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()
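
The block above is the substance of the prefetch change: async chunk fetches are now issued on `_prefetch_stream`, and the `wait_stream` call orders that stream after whatever is still queued on the default stream, which matters on the first prefetch. A minimal standalone sketch of the pattern follows, with a hypothetical `fetch_chunk` callable standing in for `ChunkManager.access_chunk(..., async_access=True)`.

```python
import torch


def prefetch_on_side_stream(chunks, prefetch_stream, fetch_chunk):
    """Sketch: issue chunk fetches on a side stream without racing the default stream.

    `fetch_chunk` is a hypothetical callable that starts an async copy/collective
    for one chunk and returns a handle (or None).
    """
    # Make the side stream wait for work already queued on the current stream,
    # e.g. an optimizer step that is still producing the data we want to prefetch.
    prefetch_stream.wait_stream(torch.cuda.current_stream())

    handles = []
    with torch.cuda.stream(prefetch_stream):
        for chunk in chunks:
            handle = fetch_chunk(chunk)  # kernels/copies enqueue on prefetch_stream
            if handle is not None:
                handles.append((chunk, handle))
    return handles  # the caller waits on these before actually using the chunks
```
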
2 changes: 1 addition & 1 deletion tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
@@ -26,7 +26,7 @@ def prepare_data(
num_tokens = torch.sum(context_lengths).item()

max_seq_len_in_batch = context_lengths.max()
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))

kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
161 changes: 161 additions & 0 deletions tests/test_infer/test_models/test_custom_model.py
@@ -0,0 +1,161 @@
import os
import random

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer

import colossalai
import colossalai.inference.modeling.policy as policy
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# NOTE: To test a model with the inference engine, you need to provide the path to your
# local pretrained model weights in the MODEL_MAP dictionary
MODEL_MAP = {
"baichuan": {
"model": AutoModelForCausalLM,
"tokenizer": AutoTokenizer,
"policy": policy.NoPaddingBaichuanModelInferPolicy,
"model_name_or_path": "baichuan-inc/Baichuan2-13B-Base", # provide the path to local model weights
},
"llama": {
"model": LlamaForCausalLM,
"tokenizer": LlamaTokenizer,
"policy": policy.NoPaddingLlamaModelInferPolicy,
"model_name_or_path": "meta-llama/Llama-2-70b-hf",
},
}

MODELS_TO_TEST = ["llama", "baichuan"] # Specify the models to test


@parameterize("model", MODELS_TO_TEST)
@parameterize("prompt_template", [None, "model_specific"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
model_path = MODEL_MAP[model]["model_name_or_path"]
if not os.path.exists(model_path):
pytest.skip(
f"There is no local model address included for {model}, please replace this address with a valid one."
)

if prompt_template == "model_specific":
prompt_template = model

model_config = MODEL_MAP[model]

kwargs1 = {
"model": model,
"use_engine": True,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": model_config["policy"](),
"use_cuda_kernel": use_cuda_kernel,
}

kwargs2 = {
"model": model,
"use_engine": False,
"prompt_template": prompt_template,
"do_sample": do_sample,
"policy": None,
"use_cuda_kernel": use_cuda_kernel,
}

colossal_tp_1_output = run_engine(1, **kwargs1)
colossal_tp_2_output = run_engine(2, **kwargs1)
transformer_tp_1_output = run_engine(1, **kwargs2)

for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


def run_engine(world_size, **kwargs):
manager = Manager()
result_list = manager.list([-1] * world_size) # Create a shared list
spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
return result_list[0]


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

if ret:
ret[rank] = func_to_run(**kwargs)
else:
func_to_run(**kwargs)


def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
setup_seed(20)
model_config = MODEL_MAP[model]
model_name_or_path = model_config["model_name_or_path"]
tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
model = model.eval()

inputs = [
"Introduce some landmarks in Paris:",
]

output_len = 38

if do_sample:
top_p = 0.5
top_k = 50
else:
top_p = None
top_k = None

if use_engine:
inference_config = InferenceConfig(
max_output_len=output_len,
prompt_template=prompt_template,
use_cuda_kernel=use_cuda_kernel,
tp_size=dist.get_world_size(),
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return outputs


def setup_seed(seed):
torch.manual_seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


if __name__ == "__main__":
test_model()
2 changes: 2 additions & 0 deletions tests/test_shardformer/test_model/test_shard_qwen2.py
@@ -217,13 +217,15 @@ def check_qwen2_3d(rank, world_size, port):


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2():
spawn(check_qwen2, 4)


@pytest.mark.skipif(transformers.__version__ < "4.39.1", reason="Requires transformers version 4.39.1 or later")
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_qwen2_3d():
