Skip to content

Commit f7a5a6f

Browse files
Reinforce PR (#1288)
* - Add reinforce - Add reinforce leave one out - Add model weight sharing via pointers - Add online dataset * - Finish testing reinforce - Add example files - TODO: Add the whole online loop instructions once vllm fork is pushed * - add readme on how to install vllm for online inference during training * unify filename convention * improve docs * add reinforce link to post-training readme --------- Co-authored-by: Quentin Anthony <[email protected]>
1 parent 8900d05 commit f7a5a6f

10 files changed

+759
-11
lines changed

megatron/data/data_utils.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from megatron.data.blendable_dataset import BlendableDataset
2525
from megatron.data.gpt2_dataset import GPT2Dataset
2626
from megatron.data.pairwise_dataset import PairwiseDataset
27+
from megatron.data.online_dataset import OnlineDataset
2728
from megatron.data.samplers import DistributedBatchSampler
2829

2930

@@ -532,7 +533,56 @@ def build_train_valid_test_data_loaders(neox_args):
532533
pipe_load = True
533534

534535
# Data loader only on rank 0 of each model parallel group.
535-
if mpu.get_model_parallel_rank() == 0 and pipe_load:
536+
if (
537+
pipe_load
538+
and (neox_args.dataset_impl == "online")
539+
and (mpu.get_model_parallel_rank() == 0)
540+
):
541+
# Can skip most of the work...
542+
train_iters = neox_args.train_iters
543+
eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
544+
test_iters = neox_args.eval_iters
545+
# Build datasets...
546+
print(
547+
f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}"
548+
)
549+
train_datasets = OnlineDataset(
550+
leave_one_out=neox_args.reinforce_leave_one_out,
551+
data_split="train",
552+
num_samples=train_iters * neox_args.train_batch_size,
553+
seq_length=neox_args.seq_length,
554+
dataserver_ips=neox_args.online_dataserver_ips,
555+
dataserver_ports=neox_args.online_dataserver_ports,
556+
)
557+
valid_datasets = OnlineDataset(
558+
leave_one_out=neox_args.reinforce_leave_one_out,
559+
data_split="valid",
560+
num_samples=eval_iters * neox_args.train_batch_size,
561+
seq_length=neox_args.seq_length,
562+
dataserver_ips=neox_args.online_dataserver_ips,
563+
dataserver_ports=neox_args.online_dataserver_ports,
564+
)
565+
test_datasets = OnlineDataset(
566+
leave_one_out=neox_args.reinforce_leave_one_out,
567+
data_split="test",
568+
num_samples=test_iters * neox_args.train_batch_size,
569+
seq_length=neox_args.seq_length,
570+
dataserver_ips=neox_args.online_dataserver_ips,
571+
dataserver_ports=neox_args.online_dataserver_ports,
572+
)
573+
# print length of datasets
574+
# Build dataloaders.
575+
train_dataloader = make_data_loader(train_datasets, neox_args=neox_args)
576+
valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args)
577+
test_dataloader = make_data_loader(test_datasets, neox_args=neox_args)
578+
579+
# Flags to know if we need to do training/validation/testing.
580+
do_train = train_dataloader is not None and neox_args.train_iters > 0
581+
do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
582+
do_test = test_dataloader is not None and neox_args.eval_iters > 0
583+
# Need to broadcast num_tokens and num_type_tokens.
584+
flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
585+
elif mpu.get_model_parallel_rank() == 0 and pipe_load:
536586
# Number of train/valid/test samples.
537587
if neox_args.train_iters is not None:
538588
train_iters = neox_args.train_iters

megatron/data/online_dataset.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2024, EleutherAI
2+
# This file is based on code by the authors denoted below and has been modified from its original version.
3+
#
4+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
"""Online dataset."""
19+
from typing import Union, List
20+
21+
import numpy as np
22+
import torch
23+
import torch.utils.data
24+
import socket
25+
import pickle
26+
from megatron.mpu.initialize import get_data_parallel_rank
27+
28+
29+
class OnlineDataset(torch.utils.data.Dataset):
    """Dataset that pulls (prefix, completions, rewards) samples from a remote
    data server over a raw TCP socket, for online RL (REINFORCE) training.

    Each ``__getitem__`` call pops one buffered sample; when the buffer is
    empty, a fresh batch is fetched from the server.  Rewards are baselined
    either with a leave-one-out mean (RLOO) or with a moving average over
    recent batches.
    """

    def __init__(
        self,
        num_samples,
        seq_length,
        leave_one_out=False,
        data_split="train",
        dataserver_ips: Union[str, List[str]] = "localhost",
        dataserver_ports: Union[int, List[int]] = 10000,
    ):
        """
        Args:
            num_samples: nominal dataset length (the real number of samples is
                decided by the online trainer; see ``__len__``).
            seq_length: model sequence length; samples are padded and truncated
                to ``seq_length + 1`` tokens.
            leave_one_out: if True, baseline each reward with the mean of the
                *other* rewards in its group (RLOO); otherwise subtract a
                moving average of recent batch-mean rewards.
            data_split: split name sent to the server ("train"/"valid"/"test").
            dataserver_ips: single hostname/IP, or one per data-parallel rank.
            dataserver_ports: single port, or one per data-parallel rank.
        """
        self.num_samples = num_samples
        self.global_rank = get_data_parallel_rank()
        self.leave_one_out = leave_one_out
        # Rolling window of recent per-batch mean rewards (moving-average baseline).
        self.reward_buffer = []
        # FIFO of [prefix_tokens, completion_tokens, reward, raw_reward] entries.
        self.online_batching_data = []
        self.data_split = data_split
        self.seq_length = seq_length
        self.dataserver_ips = dataserver_ips
        self.dataserver_ports = dataserver_ports

    def __len__(self):
        # Dummy value since the effective length is decided by the Online Trainer.
        return self.num_samples

    def update_online_batches(self):
        """Fetch one batch from the data server, baseline the rewards, and
        append flattened per-completion samples to ``self.online_batching_data``."""
        if isinstance(self.dataserver_ips, str):
            ipaddr = self.dataserver_ips
        else:
            # One server address per data-parallel rank.
            ipaddr = self.dataserver_ips[self.global_rank]
        if isinstance(self.dataserver_ports, int):
            # simply add over the global rank
            port = self.dataserver_ports
        else:
            # in case we want to use different ports for different ranks, e.g. per machine sampling
            port = self.dataserver_ports[self.global_rank]
        print(f"Connecting to {ipaddr}:{port}")
        # Context manager guarantees the socket is closed even if recv raises
        # (the original leaked the fd on error).
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((ipaddr, port))
            s.send(self.data_split.encode())
            payload = b""
            while True:
                chunk = s.recv(4096)
                if not chunk:
                    break
                payload += chunk
        # NOTE(security): pickle.loads on socket data is only acceptable because
        # the data server is a trusted, user-launched process; never expose this
        # port to untrusted peers.
        batch_data = pickle.loads(payload)
        print(f"Received {len(batch_data)} samples from the server.")
        for sample in batch_data:
            rewards = sample["rewards"]
            if self.leave_one_out:
                # RLOO: baseline each reward with the mean of all *other* rewards.
                baselined = [
                    r - np.mean([rewards[j] for j in range(len(rewards)) if j != i])
                    for i, r in enumerate(rewards)
                ]
                # Keep the unnormalized rewards for metrics.
                sample["raw_rewards"] = rewards
                sample["rewards"] = baselined
            else:
                moving_average = 0
                if len(self.reward_buffer) > 0:
                    moving_average = np.mean(self.reward_buffer)
                self.reward_buffer.append(np.mean(rewards))
                # Bound the window to the last 100 batch means.
                if len(self.reward_buffer) > 100:
                    self.reward_buffer.pop(0)
                # For metrics...
                sample["raw_rewards"] = rewards
                sample["rewards"] = [r - moving_average for r in rewards]
            for i in range(len(sample["completions"])):
                self.online_batching_data.append(
                    [
                        sample["prefix"],
                        sample["completions"][i],
                        sample["rewards"][i],
                        sample["raw_rewards"][i],
                    ]
                )

    def __getitem__(self, idx):
        """Return one sample as numpy arrays.

        ``text`` is prefix + completion; ``label`` masks prefix positions with
        -100 so only completion tokens contribute to the loss.  Both are padded
        *and truncated* to ``seq_length + 1`` tokens (+1 because of causal
        masking).  ``idx`` is ignored: samples are consumed FIFO.
        """
        if len(self.online_batching_data) == 0:
            self.update_online_batches()
        batch = self.online_batching_data.pop(0)
        text = batch[0] + batch[1]
        label = [-100 for _ in batch[0]] + batch[1]
        target_len = self.seq_length + 1
        if len(text) < target_len:
            text = text + [0] * (target_len - len(text))
            label = label + [-100] * (target_len - len(label))
        else:
            # Bug fix: over-long samples previously passed through unpadded and
            # untruncated, producing ragged tensors that break batch collation.
            text = text[:target_len]
            label = label[:target_len]
        return {
            "text": np.array(text, dtype=np.int64),
            "label": np.array(label, dtype=np.int64),
            "reward": np.array([batch[2]], dtype=np.float32),
            "raw_reward": np.array([batch[3]], dtype=np.float32),
        }

megatron/model/weight_server.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import Union, List
2+
3+
import torch
4+
import socket
5+
import pickle
6+
7+
8+
def send_tensor(state_dict_key, data, sock, end: bool):
    """Send the CUDA IPC sharing metadata for one tensor over ``sock``.

    The tensor's storage is exported with ``_share_cuda_()`` so the receiving
    process can map the same GPU memory without copying the data itself.

    Args:
        state_dict_key: name of this parameter in the state dict.
        data: CUDA tensor whose storage handle is shared.
        sock: connected socket to the consumer process.
        end: True when this is the final tensor of the transfer.
    """
    storage = data.storage()
    # _share_cuda_() returns the IPC handle fields in this exact order; the
    # receiver rebuilds the tensor from them.
    (
        storage_device,
        storage_handle,
        storage_size_bytes,
        storage_offset_bytes,
        ref_counter_handle,
        ref_counter_offset,
        event_handle,
        event_sync_required,
    ) = storage._share_cuda_()
    meta = {
        "state_dict_key": state_dict_key,
        "dtype": data.dtype,
        "tensor_size": data.shape,
        "tensor_stride": data.stride(),
        "tensor_offset": data.storage_offset(),  # !Not sure about this one.
        "storage_cls": type(storage),
        "storage_device": storage_device,
        "storage_handle": storage_handle,
        "storage_size_bytes": storage_size_bytes,
        "storage_offset_bytes": storage_offset_bytes,
        "requires_grad": False,
        "ref_counter_handle": ref_counter_handle,
        "ref_counter_offset": ref_counter_offset,
        "event_handle": event_handle,
        "event_sync_required": event_sync_required,
        "end": end,
    }
    sock.send(pickle.dumps(meta))
43+
44+
def send_state_dict(state_dict, sock):
    """Stream every tensor of ``state_dict`` over ``sock``.

    The last tensor is flagged (``end=True``) so the receiver knows the
    transfer is complete; after each tensor we wait for a short ack before
    sending the next one.
    """
    keys = list(state_dict.keys())
    last = len(keys) - 1
    for i, key in enumerate(keys):
        print(key)
        send_tensor(key, state_dict[key], sock, i == last)
        # Block until the consumer acknowledges this tensor.
        sock.recv(4096)
50+
51+
52+
def start_server(model, ports: Union[int, List[int]] = 6000):
    """Serve ``model``'s state dict to a single client over a localhost socket.

    Blocks until one client connects, streams every tensor's CUDA IPC handle
    via ``send_state_dict``, then closes both sockets.

    Args:
        model: module whose ``state_dict()`` is shared.
        ports: base port (actual port is ``ports + global_rank``) when an int,
            or an explicit per-rank list of ports.
    """
    global_rank = torch.distributed.get_rank()
    # isinstance instead of `type(...) == int` (idiomatic, and consistent with
    # the port handling in OnlineDataset).
    if isinstance(ports, int):
        # Single base port: offset by rank so each GPU serves on its own port.
        port = ports + global_rank
    else:
        port = ports[global_rank]
    # `with` ensures the listening socket is closed even on error (the
    # original never closed it at all).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Allow quick restarts without waiting for TIME_WAIT to expire.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("localhost", port))
        s.listen(1)
        conn, _addr = s.accept()
        try:
            send_state_dict(model.state_dict(), conn)
        finally:
            conn.close()

megatron/neox_arguments/neox_args.py

+47-4
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,28 @@ class NeoXArgsModel(NeoXArgsTemplate):
502502
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
503503
"""
504504

505+
serve_model_weights: bool = False
506+
"""
507+
If true, serve model weight pointers over a socket connection
508+
"""
509+
510+
weight_server_port: Union[int, List[int]] = 6000
511+
"""
512+
Port(s) to serve model weights over
513+
If an integer is provided, the port for each GPU will be 6000 + global rank
514+
If a list is provided, the ports will be used in order, e.g. rank0 will be weight_server_port[0]
515+
"""
516+
517+
online_dataserver_ips: Union[str, List[str]] = "localhost"
518+
"""
519+
ip addresses to connect to for online data serving, defaults to localhost
520+
"""
521+
522+
online_dataserver_ports: Union[int, List[int]] = 10000
523+
"""
524+
Port(s) to connect to for online data serving, defaults to 10000
525+
"""
526+
505527
te_columnparallel: bool = False
506528
"""
507529
Use TransformerEngine for RowParallelLinear layer.
@@ -1132,14 +1154,14 @@ class NeoXArgsTraining(NeoXArgsTemplate):
11321154
warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets
11331155
"""
11341156

1135-
dataset_impl: Literal["gpt2", "pairwise"] = "gpt2"
1157+
dataset_impl: Literal["gpt2", "pairwise", "online"] = "gpt2"
11361158
"""
1137-
Dataset implementation, can be one of "gpt2" or "pairwise"
1159+
Dataset implementation, can be one of "gpt2", "pairwise", or "online"
11381160
"""
11391161

1140-
train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
1162+
train_impl: Literal["normal", "dpo", "rm", "kto", "reinforce"] = "normal"
11411163
"""
1142-
Training implementation, can be one of "normal", "dpo", "kto", or "rm"
1164+
Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm"
11431165
"""
11441166

11451167
dpo_fp32: bool = True
@@ -1184,6 +1206,27 @@ class NeoXArgsTraining(NeoXArgsTemplate):
11841206
Beta value for KTO
11851207
"""
11861208

1209+
fp32_reinforce: bool = True
1210+
"""
1211+
Whether to cast logits to fp32 for Reinforce loss calculation.
1212+
"""
1213+
1214+
kl_impl: Literal["abs", "mse", "kl", "full"] = "mse"
1215+
"""
1216+
KL divergence implementation, can be one of "abs", "mse", "kl", or "full"
1217+
"""
1218+
1219+
kl_div_beta: float = 0.1
1220+
"""
1221+
Beta value for KL divergence in Reinforce loss calculation.
1222+
"""
1223+
1224+
reinforce_leave_one_out: bool = False
1225+
"""
1226+
Whether to use reinforce leave one out for training
1227+
(from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118)
1228+
"""
1229+
11871230
allow_chopped: bool = True
11881231
"""
11891232
WARNING: if your packing impl is packed, this is ignored.

0 commit comments

Comments
 (0)