forked from deepspeedai/DeepSpeedExamples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test-bert.py
37 lines (31 loc) · 1.45 KB
/
test-bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from transformers import pipeline
import transformers
import deepspeed
import torch
import os
import argparse
# Smoke test: build a Hugging Face pipeline (fill-mask by default), optionally
# wrap its model with DeepSpeed inference, run it on a fixed sentence, and print
# the result from rank 0 only.
#
# The LOCAL_RANK / WORLD_SIZE environment variables (as exported by a
# distributed launcher) select the GPU and the model-parallel degree. When they
# are absent, the script runs as a single process on --local_rank.
parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=str, default="bert-large-cased",
                    help="hf model name (default: bert-large-cased)")
parser.add_argument("--dtype", type=str, default=None, choices=["fp16", "fp32"],
                    help="fp16 or fp32 (default: fp16 with --triton, otherwise fp32)")
parser.add_argument("--local_rank", type=int, default=0,
                    help="local rank (the LOCAL_RANK env var takes precedence)")
parser.add_argument("--trials", type=int, default=8, help="number of trials")
# Kernel injection and DeepSpeed were historically always on, so they stay on by
# default; the --no-* switches share the same dest and turn them off.
parser.add_argument("--kernel-inject", dest="kernel_inject", action="store_true",
                    default=True, help="inject kernels on (default)")
parser.add_argument("--no-kernel-inject", dest="kernel_inject", action="store_false",
                    help="disable kernel injection")
parser.add_argument("--graphs", action="store_true", help="CUDA Graphs on")
parser.add_argument("--triton", action="store_true", help="triton kernels on")
parser.add_argument("--deepspeed", dest="deepspeed", action="store_true",
                    default=True, help="use deepspeed inference (default)")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false",
                    help="run the plain Hugging Face pipeline as a baseline")
parser.add_argument("--task", type=str, default="fill-mask",
                    help="fill-mask or token-classification")
args = parser.parse_args()

# Reject flag combinations that cannot be honored instead of silently ignoring them.
if args.trials < 1:
    parser.error("--trials must be at least 1")
if args.triton and args.dtype == "fp32":
    parser.error("--triton requires --dtype fp16")
if not args.deepspeed and (args.triton or args.graphs):
    parser.error("--triton and --graphs require DeepSpeed inference")

# Without an explicit --dtype keep the previous behavior: fp16 only for Triton.
if args.dtype is None:
    args.dtype = "fp16" if args.triton else "fp32"
dtype = torch.float16 if args.dtype == "fp16" else torch.float

# A launcher-provided LOCAL_RANK wins over the CLI flag. WORLD_SIZE defaults to 1
# so a plain single-process run does not request multi-GPU model parallelism.
local_rank = int(os.getenv('LOCAL_RANK', str(args.local_rank)))
world_size = int(os.getenv('WORLD_SIZE', '1'))

pipe = pipeline(args.task, model=args.model, device=local_rank)
if args.deepspeed:
    pipe.model = deepspeed.init_inference(
        pipe.model,
        mp_size=world_size,
        dtype=dtype,
        replace_with_kernel_inject=args.kernel_inject,
        use_triton=args.triton,
        enable_cuda_graph=args.graphs,
    )
    # Keep the pipeline's input placement in sync with this rank's GPU.
    pipe.device = torch.device(f'cuda:{local_rank}')
else:
    # Baseline path: apply the requested precision to the plain HF model.
    pipe.model = pipe.model.to(dtype)

# Repeat the same query --trials times; the result of the last run is reported.
for _ in range(args.trials):
    output = pipe("In Autumn the [MASK] fall from the trees.")

# Only one process prints when running distributed.
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(output)