Example: DeepSpeed deferred init with opt-30b (#2419)
* load from checkpoint

* Deepspeed deferred init

* DeepSpeed with opt-30b

* Update examples/large_models/deepspeed/opt/Readme.md

Co-authored-by: Hamid Shojanazeri <[email protected]>

* fix lint

* fix lint

* move files

* move files

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Hamid Shojanazeri <[email protected]>
3 people authored Jun 24, 2023
1 parent a77a150 commit 603e89f
Showing 9 changed files with 90 additions and 114 deletions.
46 changes: 46 additions & 0 deletions examples/large_models/deepspeed/Readme.md
@@ -0,0 +1,46 @@
# Loading large HuggingFace models on multiple GPUs

This document describes how to serve large HuggingFace (HF) models on multiple GPUs using DeepSpeed. The example uses `facebook/opt-30b`.

To run this example, DeepSpeed must be installed. It has been added to `requirements.txt`, which can be bundled during model packaging.


```bash
pip install deepspeed
```
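
For reference, the `requirements.txt` added at `examples/large_models/deepspeed/requirements.txt` in this commit pins both dependencies:

```
transformers>=4.30.1
deepspeed>=0.9.4
```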

### Step 1: Download model

```bash
python ../utils/Download_model.py --model_path model --model_name facebook/opt-30b --revision main
```

The script prints the path where the model is downloaded, as shown below.

`opt/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546`
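
This snapshot path is what the handler later reads as `model_path` from `model-config.yaml` (see the config change further down in this commit). A minimal sketch of the relevant handler section, with the path adjusted to your own download location:

```yaml
handler:
    model_name: "facebook/opt-30b"
    model_path: "/home/ubuntu/serve/examples/large_models/deepspeed/opt/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"
    max_length: 50
    max_new_tokens: 10
    manual_seed: 40
```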

### Step 2: Generate mar or tgz file

```bash
torch-model-archiver --model-name opt --version 1.0 --handler custom_handler.py --extra-files ds-config.json -r requirements.txt --config-file opt/model-config.yaml --archive-format tgz
```

### Step 3: Add the tgz file to model store

```bash
mkdir model_store
mv opt.tar.gz model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --model-store model_store --models opt.tar.gz
```

### Step 5: Run inference

```bash
curl "http://localhost:8080/predictions/opt" -T sample_text.txt
```
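
`sample_text.txt` holds the prompt sent to the model; any short piece of text works. A hypothetical example of creating it:

```bash
echo "Today the weather is really nice and I am planning on going for a walk" > sample_text.txt
```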
examples/large_models/deepspeed/opt/custom_handler.py
@@ -3,7 +3,7 @@
 
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from ts.context import Context
 from ts.handler_utils.distributed.deepspeed import get_ds_engine
@@ -36,15 +36,21 @@ def initialize(self, ctx: Context):
         model_dir = ctx.system_properties.get("model_dir")
         self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
         self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
+        model_name = ctx.model_yaml_config["handler"]["model_name"]
+        model_path = ctx.model_yaml_config["handler"]["model_path"]
         seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
         torch.manual_seed(seed)
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left")
+        logger.info("Model %s loading tokenizer", ctx.model_name)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch.float16
-        )
-        self.model.eval()
+        config = AutoConfig.from_pretrained(model_name)
+        with torch.device("meta"):
+            self.model = AutoModelForCausalLM.from_config(
+                config, torch_dtype=torch.float16
+            )
+        self.model = self.model.eval()
 
         ds_engine = get_ds_engine(self.model, ctx)
         self.model = ds_engine.module
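
The key change above is deferred initialization: the model is instantiated on the meta device from its config, so no weights are materialized until DeepSpeed loads and shards them from the checkpoint. A standalone sketch of the pattern (illustrative, outside the handler):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build only the model structure; tensors on the "meta" device hold no data,
# so a 30B-parameter model can be declared without host or GPU memory.
config = AutoConfig.from_pretrained("facebook/opt-30b")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# DeepSpeed later materializes the real weights from the files listed in
# checkpoints.json via deepspeed.init_inference(..., checkpoint=...).
```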
examples/large_models/deepspeed/opt/ds-config.json
@@ -2,6 +2,6 @@
   "dtype": "torch.float16",
   "replace_with_kernel_inject": true,
   "tensor_parallel": {
-    "tp_size": 2
+    "tp_size": 4
   }
 }
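
After this change the complete `ds-config.json` for four-way tensor parallelism reads roughly as below; note that `tp_size` is kept in step with `nproc-per-node` and the number of `deviceIds` in `model-config.yaml`, which this commit also raises to four.

```json
{
  "dtype": "torch.float16",
  "replace_with_kernel_inject": true,
  "tensor_parallel": {
    "tp_size": 4
  }
}
```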
94 changes: 0 additions & 94 deletions examples/large_models/deepspeed/opt/Readme.md

This file was deleted.

7 changes: 5 additions & 2 deletions examples/large_models/deepspeed/opt/model-config.yaml
@@ -6,16 +6,19 @@ responseTimeout: 1200
 parallelType: "tp"
 deviceType: "gpu"
 # example of user specified GPU deviceIds
-deviceIds: [2,3] # setting CUDA_VISIBLE_DEVICES
+deviceIds: [0,1,2,3] # setting CUDA_VISIBLE_DEVICES
 
 torchrun:
-    nproc-per-node: 2
+    nproc-per-node: 4
 
 # TorchServe Backend parameters
 deepspeed:
     config: ds-config.json
+    checkpoint: checkpoints.json
 
 handler:
+    model_name: "facebook/opt-30b"
+    model_path: "/home/ubuntu/serve/examples/large_models/deepspeed/opt/model/models--facebook--opt-30b/snapshots/ceea0a90ac0f6fae7c2c34bcb40477438c152546"
     max_length: 50
     max_new_tokens: 10
     manual_seed: 40
2 changes: 0 additions & 2 deletions examples/large_models/deepspeed/opt/requirements.txt

This file was deleted.

2 changes: 2 additions & 0 deletions examples/large_models/deepspeed/requirements.txt
@@ -0,0 +1,2 @@
transformers>=4.30.1
deepspeed>=0.9.4
31 changes: 23 additions & 8 deletions ts/handler_utils/distributed/deepspeed.py
@@ -1,15 +1,29 @@
+import json
 import logging
 import os
+from pathlib import Path
 
 import deepspeed
 
 from ts.context import Context
 
 
+def create_checkpoints_json(model_path, checkpoints_json):
+    checkpoint_files = file_list = [
+        str(entry)
+        for entry in Path(model_path).rglob("*.[bp][it][n]")
+        if entry.is_file()
+    ]
+    data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+    print(f"Creating deepspeed checkpoint file {checkpoints_json}")
+    json.dump(data, open(checkpoints_json, "w"))
+
+
 def get_ds_engine(model, ctx: Context):
     model_dir = ctx.system_properties.get("model_dir")
-    ds_config = None
-    checkpoint = None
+    ds_config, checkpoint = None, None
+    model_path = ctx.model_yaml_config["handler"]["model_path"]
+
     if "deepspeed" in ctx.model_yaml_config:
         # config: the deepspeed config json file path.
         # deepspeed config parameters:
@@ -23,17 +37,18 @@ def get_ds_engine(model, ctx: Context):
                     f"{ctx.model_name} has no deepspeed config file {ds_config}"
                 )
 
-        if "checkpoint" in ctx.model_yaml_config:
+        if "checkpoint" in ctx.model_yaml_config["deepspeed"]:
             checkpoint = os.path.join(
                 model_dir, ctx.model_yaml_config["deepspeed"]["checkpoint"]
             )
-            if not os.path.exists(checkpoint):
-                raise ValueError(
-                    f"{ctx.model_name} has no deepspeed checkpoint file {checkpoint}"
-                )
+            create_checkpoints_json(model_path, checkpoint)
+
         logging.debug("Creating DeepSpeed engine")
         ds_engine = deepspeed.init_inference(
-            model, config=ds_config, checkpoint=checkpoint
+            model,
+            config=ds_config,
+            base_dir=model_path,
+            checkpoint=checkpoint,
         )
         return ds_engine
     else:
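
For reference, `create_checkpoints_json` scans `model_path` for the downloaded weight files (e.g. `*.bin`) and writes a DeepSpeed checkpoint descriptor of roughly this shape (file names are illustrative):

```json
{
  "type": "ds_model",
  "checkpoints": [
    "/path/to/model/snapshots/<hash>/pytorch_model-00001-of-00007.bin",
    "/path/to/model/snapshots/<hash>/pytorch_model-00002-of-00007.bin"
  ],
  "version": 1.0
}
```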
