Support Ascend NPU
zhangsibo1129 committed Sep 16, 2023
1 parent aa153d5 commit 5419ac1
Showing 8 changed files with 61 additions and 2 deletions.
41 changes: 41 additions & 0 deletions README.md
@@ -157,6 +157,18 @@ python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device xpu
```
Vicuna-7B can run on an Intel Arc A770 16GB.

#### Ascend NPU (Huawei AI Processor)
Install the [Ascend PyTorch Adapter](https://github.com/Ascend/pytorch). Set the CANN environment variables:
```
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```

Use `--device npu` to enable NPU acceleration.
```
python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device npu
```
Vicuna-7B/13B can run on a single Ascend 910B NPU (60GB).
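
If the device is not picked up, a quick sanity check is to ask PyTorch whether the NPU backend is available (a minimal sketch, assuming `torch_npu` is installed, exposes the usual `torch.npu` helpers, and the CANN environment above has been sourced):
```
# Expected to print True once the driver, CANN toolkit, and torch_npu are set up
python3 -c "import torch; import torch_npu; print(torch.npu.is_available())"
```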

#### Not Enough Memory
If you do not have enough memory, you can enable 8-bit compression by adding `--load-8bit` to the commands above.
This can reduce memory usage by around half, with slightly degraded model quality.
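
For example, 8-bit compression combines with the NPU setup above (the same flag works with the other devices as well):
```
python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device npu --load-8bit
```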
@@ -301,6 +313,35 @@ Tips:
- If you run into out-of-memory errors due to "FSDP Warning: When using FSDP, it is efficient and recommended... ", see the solutions [here](https://github.com/huggingface/transformers/issues/24724#issuecomment-1645189539).
- If you run into out-of-memory errors during model saving, see the solutions [here](https://github.com/pytorch/pytorch/issues/98823).

### Fine-tuning Vicuna-7B with Local NPUs

You can use the following command to train Vicuna-7B with 8 x Ascend 910B NPUs (60GB each). Use `--nproc_per_node` to specify the number of NPUs.
```bash
torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py \
    --model_name_or_path ~/vicuna-7b-v1.5-16k \
    --data_path data/dummy_conversation.json \
    --fp16 True \
    --output_dir output_vicuna \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True
```
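
Before launching `torchrun`, it is worth confirming that the CANN environment is loaded in the current shell and that all eight NPUs are visible (a minimal pre-flight sketch, assuming `torch_npu` is installed and exposes `torch.npu.device_count`):
```bash
# Load the CANN environment (toolkit path from the setup section above)
# and confirm that torch_npu reports 8 devices before starting distributed training.
source /usr/local/Ascend/ascend-toolkit/set_env.sh
python3 -c "import torch; import torch_npu; print(torch.npu.device_count())"
```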

### Other models and LoRA support
More instructions to train other models (e.g., FastChat-T5) and use LoRA are in [docs/training.md](docs/training.md).

2 changes: 2 additions & 0 deletions fastchat/model/compression.py
@@ -193,6 +193,8 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()

for name in model.state_dict():
if name not in linear_weights:
10 changes: 9 additions & 1 deletion fastchat/model/model_adapter.py
@@ -206,6 +206,13 @@ def load_model(
warnings.warn(
"Intel Extension for PyTorch is not installed, but is required for xpu inference."
)
elif device == "npu":
kwargs = {"torch_dtype": torch.float16}
# Try to load torch_npu; although it looks unused, importing it adds Ascend NPU support to torch
try:
import torch_npu
except ImportError:
warnings.warn("Ascend Extension for PyTorch is not installed.")
else:
raise ValueError(f"Invalid device: {device}")

@@ -288,6 +295,7 @@ def load_model(
if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
"mps",
"xpu",
"npu",
):
model.to(device)

@@ -369,7 +377,7 @@ def add_model_args(parser):
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "mps", "xpu"],
choices=["cpu", "cuda", "mps", "xpu", "npu"],
default="cuda",
help="The device type",
)
2 changes: 2 additions & 0 deletions fastchat/model/model_codet5p.py
@@ -104,3 +104,5 @@ def __call__(
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
2 changes: 2 additions & 0 deletions fastchat/model/model_falcon.py
@@ -136,3 +136,5 @@ def generate_stream_falcon(
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
2 changes: 2 additions & 0 deletions fastchat/serve/inference.py
@@ -263,6 +263,8 @@ def generate_stream(
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()


class ChatIO(abc.ABC):
2 changes: 1 addition & 1 deletion fastchat/serve/launch_all_serve.py
@@ -66,7 +66,7 @@
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "mps", "xpu"],
choices=["cpu", "cuda", "mps", "xpu", "npu"],
default="cuda",
help="The device type",
)
2 changes: 2 additions & 0 deletions fastchat/serve/model_worker.py
@@ -370,6 +370,8 @@ def get_embeddings(self, params):
torch.cuda.empty_cache()
if self.device == "xpu":
torch.xpu.empty_cache()
if self.device == "npu":
torch.npu.empty_cache()
except torch.cuda.OutOfMemoryError as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
