Finetuning #258

Status: Open. Wants to merge 36 commits into base: master.

36 commits
d1fc873
Add finetuning code.
Mar 3, 2019
1fba31f
chmod +x
Mar 3, 2019
dfca3cf
Add finetuning instructions
Mar 3, 2019
9423776
Fix sample generation with batch_size greater than 1.
Mar 3, 2019
3e18729
Add training script with Horovod support
tlkh Mar 18, 2019
ec16bad
Fix typo in train command in README
tlkh Mar 18, 2019
0bad9e4
Added instructions for training using Horovod
tlkh Mar 18, 2019
ef62678
Merge pull request #2 from tlkh/finetuning
nshepperd Mar 19, 2019
c465071
autoformat
Mar 4, 2019
1e32b10
Combine input text files with <|endoftext|> delimiter to ensure there…
Mar 19, 2019
3a3ce65
Write losses to summary file for tensorboard.
Mar 20, 2019
d5b387b
Add learning rate as command line flag.
Mar 20, 2019
b106d0a
Use argparse instead of fire in train.py.
Mar 20, 2019
2044d13
Fix encode.py
Mar 21, 2019
a359a34
Add gradient accumulation with default of 5 minibatches
Mar 21, 2019
8738950
Merge remote-tracking branch 'origin/master' into finetuning
Mar 25, 2019
eda8777
Turn off gradient accumulation by default, it shouldn't be needed.
May 2, 2019
47df6da
Add gradient checkpointing and another optimization necessary to allo…
May 4, 2019
c46ed99
Add "validation" loss calculation.
May 4, 2019
941a762
Add toposort to requirements
Tenoke May 5, 2019
13c5412
Merge pull request #3 from Tenoke/finetuning
May 6, 2019
3985cc7
Add option to use SGD for optimizer
May 14, 2019
7fc2a44
Record learning rate in tensorboard logs
May 14, 2019
a464925
Add text in README for --optimizer flag
May 14, 2019
ae535b6
Reduce default learning rate of train.py.
May 14, 2019
2d4fd0c
Merge remote-tracking branch 'origin/master' into finetuning
May 14, 2019
6a77a7b
New feature: add noise to network inputs to regularize against overre…
May 15, 2019
87fe3d7
Add top-p sampling
May 15, 2019
e99ee37
Add top_p to interactive_conditional_samples.py and generate_uncondit…
May 15, 2019
2b24145
fix typo in top_p
May 15, 2019
6c1f21d
Fix top_p sampling for batch_size>1
May 15, 2019
cca7144
Updated README.md
biranchi2018 Aug 15, 2019
a070f38
Merge pull request #22 from biranchi2018/biranchi2018-patch-1
Aug 27, 2019
50fa3b6
Add note to install cudnn, re https://github.com/nshepperd/gpt-2/issu…
Jun 16, 2019
b7cda3f
Add flag to set encoding for text reading and writing, defaulting to …
Jul 20, 2019
b8cd943
Replace Dockerfile with a fixed one
neo2478 Aug 13, 2020
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
__pycache__
.mypy_cache/
models/
checkpoint
samples
14 changes: 14 additions & 0 deletions Dockerfile
@@ -0,0 +1,14 @@
FROM tensorflow/tensorflow:1.15.2-py3-jupyter

# setup environment language
ENV LANG=C.UTF-8

# copy requirements.txt into image
COPY requirements.txt requirements.txt

# update and upgrade packages and pip and install python libraries
RUN apt-get update && apt-get upgrade -y \
&& apt-get install -y apt-utils \
&& pip3 install --upgrade pip \
&& pip3 install -r requirements.txt \
&& rm requirements.txt
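
As a quick usage note (not part of this diff): the image can be built and run with standard Docker commands; the `gpt-2` tag below is just an example name.

```
docker build -t gpt-2 .
docker run -it gpt-2 bash
```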
9 changes: 0 additions & 9 deletions Dockerfile.cpu

This file was deleted.

18 changes: 0 additions & 18 deletions Dockerfile.gpu

This file was deleted.

55 changes: 55 additions & 0 deletions README.md
@@ -1,3 +1,6 @@

Reference: ["Beginner’s Guide to Retrain GPT-2 (117M) to Generate Custom Text Content"](https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f)

# gpt-2

Code from the paper ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
@@ -30,6 +33,58 @@ See [DEVELOPERS.md](./DEVELOPERS.md)

See [CONTRIBUTORS.md](./CONTRIBUTORS.md)

## Fine tuning on custom datasets

To retrain GPT-2 117M model on a custom text dataset:

```
PYTHONPATH=src ./train.py --dataset <file|directory|glob>
```

If you want to precompute the dataset's encoding for multiple runs, you can instead use:

```
PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz
```

Make sure `cudnn` is installed. [Some have reported](https://github.com/nshepperd/gpt-2/issues/8) that `train.py` runs without it but has worse memory usage and might OOM.

### Gradient Checkpointing

[Gradient checkpointing](https://github.com/openai/gradient-checkpointing) is included to reduce the memory requirements of the model and can be enabled with `--memory_saving_gradients`. The checkpoint tensors are currently chosen manually (and poorly), by simply adding layer 10 to the 'checkpoints' collection in model.py. `--memory_saving_gradients` is enabled by default when training the 345M model.
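
A minimal sketch of how such a 'checkpoints' collection is typically consumed, assuming the `memory_saving_gradients` module from the repository above (the function and variable names here are illustrative, not the exact code in train.py):

```
import tensorflow as tf
import memory_saving_gradients  # from openai/gradient-checkpointing

# Inside the model, selected activations are marked as recomputation points:
#   tf.add_to_collection('checkpoints', h)   # e.g. the output of layer 10

def checkpointed_gradients(loss, var_list):
    # Activations between checkpoints are recomputed during the backward
    # pass instead of being stored, trading compute for memory.
    grads = memory_saving_gradients.gradients(loss, var_list,
                                              checkpoints='collection')
    return list(zip(grads, var_list))
```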

### Validation loss

Set `--val_every` to a number of steps `N > 0`, and a "validation" loss against a fixed sample of the dataset will be calculated every N steps, giving a better sense of training progress. A value of N around 200 is suggested. You can set `--val_dataset` to use a separate validation dataset; otherwise it defaults to a sample from the training dataset (so it is not a true held-out validation loss!).
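
For example (the dataset paths are placeholders):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --val_every 200 --val_dataset /path/to/val.npz
```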

### Optimizer

You can use SGD instead of Adam with `--optimizer sgd`. This also helps conserve memory when training the 345M model. Note that the learning rate needs to be adjusted for SGD, since it lacks Adam's per-parameter gradient normalization (0.0006 seems to be a good value based on some experiments).
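
For example (a rough sketch; the `--learning_rate` flag name is taken from the commit history above, so check `./train.py --help` for the exact spelling):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --optimizer sgd --learning_rate 0.0006
```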

### Multi-GPU (out of date)

To run distributed training on multiple GPUs or machines using Horovod:

```
mpirun -np 4 \
-H localhost:4 \
-bind-to none -map-by slot \
-x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
-x PYTHONPATH=src \
-mca pml ob1 -mca btl ^openib \
/home/jovyan/gpt-2/train-horovod.py --dataset encoded.npz
```

## GPT-2 samples

| WARNING: Samples are unfiltered and may contain offensive content. |
| --- |

While we have not yet released GPT-2 itself, you can see some samples from it in the `gpt-2-samples` folder.
We show unconditional samples with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.
We show conditional samples, with contexts drawn from `WebText`'s test set, with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.

## Citation

Please use the following bibtex entry:
31 changes: 31 additions & 0 deletions encode.py
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Usage:
# PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
# PYTHONPATH=src ./train.py --dataset /path/to/output.npz

import argparse
import numpy as np

import encoder
from load_dataset import load_dataset

parser = argparse.ArgumentParser(
description='Pre-encode text files into tokenized training set.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')
parser.add_argument('in_text', metavar='PATH', type=str, help='Input file, directory, or glob pattern (utf-8 text).')
parser.add_argument('out_npz', metavar='OUT.npz', type=str, help='Output file path')

def main():
args = parser.parse_args()
enc = encoder.get_encoder(args.model_name)
print('Reading files')
chunks = load_dataset(enc, args.in_text, args.combine, encoding=args.encoding)
print('Writing', args.out_npz)
np.savez_compressed(args.out_npz, *chunks)


if __name__ == '__main__':
main()
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ fire>=0.1.3
regex==2017.4.5
requests==2.21.0
tqdm==4.31.1
toposort==1.5
36 changes: 36 additions & 0 deletions src/accumulate.py
@@ -0,0 +1,36 @@
import argparse
import json
import os
import numpy as np
import tensorflow as tf
import time


class AccumulatingOptimizer(object):
def __init__(self, opt, var_list):
self.opt = opt
self.var_list = var_list
self.accum_vars = {tv : tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
for tv in var_list}
self.total_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))
self.count_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))

def reset(self):
updates = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars.values()]
updates.append(self.total_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
updates.append(self.count_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
with tf.control_dependencies(updates):
return tf.no_op()

def compute_gradients(self, loss):
grads = self.opt.compute_gradients(loss, self.var_list)
updates = [self.accum_vars[v].assign_add(g) for (g,v) in grads]
updates.append(self.total_loss.assign_add(loss))
updates.append(self.count_loss.assign_add(1.0))
with tf.control_dependencies(updates):
return tf.no_op()

def apply_gradients(self):
grads = [(g,v) for (v,g) in self.accum_vars.items()]
with tf.control_dependencies([self.opt.apply_gradients(grads)]):
return self.total_loss / self.count_loss
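
A minimal sketch of how this accumulator is typically driven from a training loop (TF1 session style; names such as `sess`, `loss`, and `train_vars` are illustrative and not taken from this diff):

```
import tensorflow as tf
from accumulate import AccumulatingOptimizer  # with PYTHONPATH=src

# Assumes `loss` and `train_vars` come from the surrounding training graph.
opt = AccumulatingOptimizer(
    opt=tf.train.AdamOptimizer(learning_rate=1e-4),
    var_list=train_vars)
opt_reset = opt.reset()                    # zero the gradient buffers
opt_compute = opt.compute_gradients(loss)  # add one minibatch's gradients
opt_apply = opt.apply_gradients()          # apply the summed grads, return mean loss

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(opt_reset)
    for _ in range(5):                     # accumulate over 5 minibatches
        sess.run(opt_compute)
    mean_loss = sess.run(opt_apply)
```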
6 changes: 4 additions & 2 deletions src/generate_unconditional_samples.py
@@ -16,6 +16,7 @@ def sample_model(
length=None,
temperature=1,
top_k=0,
top_p=0.0
):
"""
Run the sample_model
@@ -35,6 +36,8 @@
considered for each step (token), resulting in deterministic completions,
while 40 means 40 words are considered at each step. 0 (default) is a
special setting meaning no restrictions. 40 generally is a good value.
:top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
overriding top_k if set to a value > 0. A good setting is 0.9.
"""
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
@@ -54,7 +57,7 @@
hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'],
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)[:, 1:]

saver = tf.train.Saver()
@@ -72,4 +75,3 @@

if __name__ == '__main__':
fire.Fire(sample_model)

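For reference, nucleus (top-p) sampling as described in the docstring keeps the smallest set of tokens whose cumulative probability exceeds p and renormalizes over that set. A minimal NumPy sketch of the idea (illustrative only; the code in this PR implements it in-graph on TensorFlow logits):

```
import numpy as np

def top_p_filter(logits, p=0.9):
    # Keep the smallest set of tokens whose cumulative probability
    # exceeds p; all other tokens get probability zero.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]               # most probable first
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, p) + 1   # first index past p
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

# Example: sample one token id from the filtered distribution.
# token = np.random.choice(len(logits), p=top_p_filter(logits, p=0.9))
```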
6 changes: 4 additions & 2 deletions src/interactive_conditional_samples.py
Expand Up @@ -16,6 +16,7 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
top_p=0.0
):
"""
Interactively run the model
@@ -34,6 +35,8 @@
considered for each step (token), resulting in deterministic completions,
while 40 means 40 words are considered at each step. 0 (default) is a
special setting meaning no restrictions. 40 generally is a good value.
:top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
overriding top_k if set to a value > 0. A good setting is 0.9.
"""
if batch_size is None:
batch_size = 1
@@ -57,7 +60,7 @@
hparams=hparams, length=length,
context=context,
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)

saver = tf.train.Saver()
@@ -84,4 +87,3 @@

if __name__ == '__main__':
fire.Fire(interact_model)

83 changes: 83 additions & 0 deletions src/load_dataset.py
@@ -0,0 +1,83 @@
import glob
import numpy as np
import os
import tensorflow as tf
import tqdm


def load_dataset(enc, path, combine, encoding=None):
paths = []
if os.path.isfile(path):
# Simple file
paths.append(path)
elif os.path.isdir(path):
# Directory
for (dirpath, _, fnames) in os.walk(path):
for fname in fnames:
paths.append(os.path.join(dirpath, fname))
else:
# Assume glob
paths = glob.glob(path)

token_chunks = []
raw_text = ''
for path in tqdm.tqdm(paths):
if path.endswith('.npz'):
# Pre-encoded
with np.load(path) as npz:
for item in npz.files:
token_chunks.append(npz[item])
else:
# Plain text
with open(path, 'r', encoding=encoding) as fp:
raw_text += fp.read()
if len(raw_text) >= combine:
tokens = np.stack(enc.encode(raw_text))
token_chunks.append(tokens)
raw_text = ''
else:
raw_text += '<|endoftext|>'
if raw_text:
tokens = np.stack(enc.encode(raw_text))
token_chunks.append(tokens)
return token_chunks


def binary_search(f, lo, hi):
if f(lo) or not f(hi):
return None
while hi > lo + 1:
mid = (lo + hi) // 2
if f(mid):
hi = mid
else:
lo = mid
return hi


class Sampler(object):
"""Fairly samples a slice from a set of variable sized chunks.

'Fairly' means that the distribution is the same as sampling from one concatenated chunk,
but without crossing chunk boundaries."""

def __init__(self, chunks, seed=None):
self.chunks = chunks
self.total_size = sum(chunk.shape[0] for chunk in chunks)
self.boundaries = [0]
for i in range(len(chunks)):
self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0])
self.rs = np.random.RandomState(seed=seed)

def sample(self, length):
assert length < self.total_size // len(
self.chunks
), "Dataset files are too small to sample {} tokens at a time".format(
length)
while True:
index = self.rs.randint(0, self.total_size - length - 1)
i = binary_search(lambda j: self.boundaries[j] > index, 0,
len(self.boundaries) - 1) - 1
if self.boundaries[i + 1] > index + length:
within_chunk = index - self.boundaries[i]
return self.chunks[i][within_chunk:within_chunk + length]
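
A minimal sketch of how these pieces fit together (a rough illustration assuming `PYTHONPATH=src`; the dataset path and context length are placeholders, and the real training loop in this PR feeds the sampled tokens into the TensorFlow graph):

```
import numpy as np
import encoder
from load_dataset import load_dataset, Sampler

enc = encoder.get_encoder('117M')

# Read raw text (or a pre-encoded .npz) into a list of token arrays.
chunks = load_dataset(enc, '/path/to/dataset.txt', combine=50000)

# Draw fixed-length training samples without crossing chunk boundaries.
sampler = Sampler(chunks, seed=1)
batch = np.stack([sampler.sample(1024) for _ in range(2)])  # shape [2, 1024]
print(batch.shape)
```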