
Commit cefa4cb

FIX #62; ADD GPT-J example #63
1 parent 949ffef · commit cefa4cb

File tree: 3 files changed (+257 -9 lines)

bminf/scheduler/__init__.py  (+44 -8)
@@ -1,12 +1,24 @@
 import torch
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Set
 from cpm_kernels.library import cudart
+from typing_extensions import TypedDict

-def calc_fixed_layers(total_layers : int, max_fixed : int):
+class ParameterInfo(TypedDict):
+    shape : torch.Size
+    dtype : torch.dtype
+
+
+class SchedLayerInfo(TypedDict):
+    parameters : Dict[str, torch.Tensor]
+    evt : torch.cuda.Event
+    unused : bool
+    id : int
+
+def calc_fixed_layers(total_layers : int, max_fixed : int) -> List[int]:
     max_fixed = min(max_fixed, total_layers)
     scheduled_layers = total_layers - max_fixed
     vals = [(i + 1) * scheduled_layers // total_layers for i in range(total_layers)]
-    ret = []
+    ret : List[int] = []
     last_v = 0
     for i, v in enumerate(vals):
         if v == last_v:
@@ -19,16 +31,23 @@ def pin_layer(m : torch.nn.Module):
     for param in m.parameters():
         with torch.no_grad():
             param.data = param.data.pin_memory()
+    for buf in m.buffers():
+        with torch.no_grad():
+            buf.data = buf.data.pin_memory()
     return m

-def transfer_layers(m_src : torch.nn.Module, m_dst : dict):
+def transfer_layers(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]):
     with torch.no_grad():
         for name, param in m_src.named_parameters():
             assert name in m_dst
             # copy to device buffer
             m_dst[name].copy_(param, non_blocking=True)
+        for name, buf in m_src.named_buffers():
+            assert name in m_dst
+            m_dst[name].copy_(buf, non_blocking=True)
+

-def swap_params(m_src : torch.nn.Module, m_dst : dict):
+def swap_params(m_src : torch.nn.Module, m_dst : Dict[str, torch.Tensor]):
     with torch.no_grad():
         for name, param in m_src.named_parameters():
             assert name in m_dst
@@ -37,6 +56,13 @@ def swap_params(m_src : torch.nn.Module, m_dst : dict):
             tmp = m_dst[name].data
             m_dst[name].data = param.data
             param.data = tmp
+        for name, buf in m_src.named_buffers():
+            assert name in m_dst
+
+            # swap memory info
+            tmp = m_dst[name].data
+            m_dst[name].data = buf.data
+            buf.data = tmp

 class OpDeviceLayer(torch.autograd.Function):
     @staticmethod
@@ -176,8 +202,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
         self._device = device_id
         self._num_layers = len(layers)

-        self._fixed_layers = set()
-        self._sched_layers = []
+        self._fixed_layers : Set[int] = set()
+        self._sched_layers : List[SchedLayerInfo] = []
         self._layers = []
         self._active_layers = {}

@@ -193,6 +219,8 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
         total_size = 0
         for param in layers[0].parameters():
             total_size += param.numel() * param.storage().element_size()
+        for buf in layers[0].buffers():
+            total_size += buf.numel() * buf.storage().element_size()

         total_layers = free_mem // total_size
         if total_layers < 2:
@@ -217,7 +245,12 @@ def __init__(self, layers : List[torch.nn.Module], device_id : int, memory_limit
             if i not in self._fixed_layers:
                 self._active_layers[i] = len(self._sched_layers)
                 self._sched_layers.append({
-                    "parameters": { name: param.cuda() for name, param in layers[i].named_parameters()},
+                    "parameters": {
+                        name: param.cuda() for name, param in (
+                            list(layers[i].named_parameters())
+                            + list(layers[i].named_buffers())
+                        )
+                    },
                     "evt": torch.cuda.Event(),
                     "id": i,
                     "unused": True
@@ -375,6 +408,9 @@ def __iter__(self):
         for sched in self._scheds:
             for layer in sched:
                 yield layer
+
+    def __len__(self):
+        return len(self.layers)

     def forward(self, x, *args, **kwargs):
         for sched in self._scheds:
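
A note on the scheduler fix above: torch.nn.Module.named_parameters() does not yield registered buffers (GPT-J's attention layers, for example, register a causal-mask bias buffer), so a scheduler that only pins, transfers, and swaps parameters leaves those tensors behind when a layer moves between CPU and GPU. The added buffers() / named_buffers() passes close that gap. A minimal sketch of the distinction, using a hypothetical DummyBlock module that is not part of this commit:

import torch

class DummyBlock(torch.nn.Module):
    # Hypothetical layer for illustration only: one parameter-carrying submodule
    # plus one registered buffer, the kind of tensor the scheduler used to skip.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.register_buffer("bias_mask", torch.ones(8, 8, dtype=torch.bool))

block = DummyBlock()
print([name for name, _ in block.named_parameters()])  # ['proj.weight', 'proj.bias'] -- no buffer here
print([name for name, _ in block.named_buffers()])     # ['bias_mask']

Handling both collections is what lets pin_layer, transfer_layers, and swap_params move a whole block, buffers included.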

example/huggingface/gpt-j.ipynb  (new file, +211)
@@ -0,0 +1,211 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# GPT-J 6B"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Load model and tokenizer from HuggingFace Hub\n",
+    "\n",
+    "GPT-J is loaded in fp32 mode by default which takes about 24GB CPU memory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
+    "\n",
+    "model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Use BMInf wrapper for low-resource inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import bminf\n",
+    "with torch.cuda.device(0):\n",
+    "    model = bminf.wrapper(model, quantization=False, memory_limit=8 << 30) # 8GB"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. See the GPU usage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "|===========================================================================|\n",
+      "| PyTorch CUDA memory summary, device ID 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| CUDA OOMs: 0 | cudaMalloc retries: 0 |\n",
+      "|===========================================================================|\n",
+      "| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Allocated memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n",
+      "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
+      "| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Active memory | 9297 MB | 9297 MB | 9297 MB | 0 B |\n",
+      "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
+      "| from small pool | 1 MB | 1 MB | 1 MB | 0 B |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| GPU reserved memory | 9298 MB | 9298 MB | 9298 MB | 0 B |\n",
+      "| from large pool | 9296 MB | 9296 MB | 9296 MB | 0 B |\n",
+      "| from small pool | 2 MB | 2 MB | 2 MB | 0 B |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Non-releasable memory | 710656 B | 18400 KB | 34800 KB | 34106 KB |\n",
+      "| from large pool | 0 B | 16384 KB | 32768 KB | 32768 KB |\n",
+      "| from small pool | 710656 B | 2032 KB | 2032 KB | 1338 KB |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Allocations | 125 | 125 | 125 | 0 |\n",
+      "| from large pool | 72 | 72 | 72 | 0 |\n",
+      "| from small pool | 53 | 53 | 53 | 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Active allocs | 125 | 125 | 125 | 0 |\n",
+      "| from large pool | 72 | 72 | 72 | 0 |\n",
+      "| from small pool | 53 | 53 | 53 | 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| GPU reserved segments | 65 | 65 | 65 | 0 |\n",
+      "| from large pool | 64 | 64 | 64 | 0 |\n",
+      "| from small pool | 1 | 1 | 1 | 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Non-releasable allocs | 1 | 2 | 3 | 2 |\n",
+      "| from large pool | 0 | 1 | 2 | 2 |\n",
+      "| from small pool | 1 | 1 | 1 | 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Oversize allocations | 0 | 0 | 0 | 0 |\n",
+      "|---------------------------------------------------------------------------|\n",
+      "| Oversize GPU segments | 0 | 0 | 0 | 0 |\n",
+      "|===========================================================================|\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(torch.cuda.memory_summary())"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Run generation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    }
+   ],
+   "source": [
+    "prompt = \"To be or not to be, that\"\n",
+    "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+    "gen_tokens = model.generate(\n",
+    "    input_ids.cuda(),\n",
+    "    do_sample=True,\n",
+    "    temperature=0.9,\n",
+    "    max_length=20\n",
+    ")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 5. Get the generated text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['To be or not to be, that is the question — that has been the question, and still']"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tokenizer.batch_decode(gen_tokens)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "29d71688ffbe7d005e79abd80e578fa5cab2d2c2e11d1955de002b95fcc7229b"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
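
For reference, the notebook above condenses into the following plain-script sketch. It reuses only calls that appear in the notebook cells (bminf.wrapper with quantization=False and memory_limit=8 << 30); as the notebook notes, the fp32 checkpoint needs about 24 GB of CPU memory, and the memory summary shows roughly 9.3 GB allocated on the GPU under the 8 GB budget.

import torch
import bminf
from transformers import AutoTokenizer, AutoModelForCausalLM

# 1. Load model and tokenizer (fp32 by default, ~24 GB of CPU memory).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

# 2. Wrap with BMInf so layers are scheduled between CPU and GPU
#    within an 8 GB device-memory budget.
with torch.cuda.device(0):
    model = bminf.wrapper(model, quantization=False, memory_limit=8 << 30)

# 3. Generate as with any HuggingFace causal LM.
prompt = "To be or not to be, that"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
gen_tokens = model.generate(input_ids.cuda(), do_sample=True, temperature=0.9, max_length=20)
print(tokenizer.batch_decode(gen_tokens)[0])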

requirements.txt  (+2 -1)
@@ -1,2 +1,3 @@
 torch
-cpm_kernels>=1.0.9
+cpm_kernels>=1.0.9
+typing_extensions
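
The added typing_extensions requirement backs the TypedDict import in the scheduler: typing.TypedDict only entered the standard library in Python 3.8, so the backport keeps older interpreters working. A small illustration, with a hypothetical LayerSlot standing in for the SchedLayerInfo class from the diff above:

from typing_extensions import TypedDict

class LayerSlot(TypedDict):
    # Declares the expected shape of a per-layer bookkeeping dict for type checkers;
    # at runtime instances are still plain dicts.
    id: int
    unused: bool

slot: LayerSlot = {"id": 0, "unused": True}
print(type(slot))  # <class 'dict'>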
