Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow support to reuse sessions/models #140

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gpt_2_simple/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .gpt_2 import *
from .gpt_2 import *
21 changes: 11 additions & 10 deletions gpt_2_simple/gpt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
file_name : str
name of file to get e.g. "hparams.json"
sub_dir: str
subdirectory inside which to get and copy locally eg. "models/124M"
subdirectory inside which to get and copy locally eg. "models/124M"
no trailing slash
url_base : str
Start of URL location specifying server and any base directories no
Start of URL location specifying server and any base directories no
trailing slash
e.g. "https://storage.googleapis.com/gpt-2"
"""
Expand All @@ -55,7 +55,7 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
f.write(chunk)
pbar.update(DOWNLOAD_CHUNK_SIZE)


def download_gpt2(model_dir='models', model_name='124M'):
"""Downloads the GPT-2 model into the current directory
Expand All @@ -67,8 +67,8 @@ def download_gpt2(model_dir='models', model_name='124M'):
parent directory of model to download

model_name : str
name of the GPT-2 model to download.
As of 22 May 2019 one of "124M" or "355M" but may later include other
name of the GPT-2 model to download.
As of 22 May 2019 one of "124M" or "355M" but may later include other
model sizes

Adapted from https://github.com/openai/gpt-2/blob/master/download_model.py
Expand Down Expand Up @@ -102,7 +102,7 @@ def start_tf_sess(threads=-1, server=None):

if server is not None:
return tf.compat.v1.Session(target=server.target, config=config)

return tf.compat.v1.Session(config=config)


Expand Down Expand Up @@ -315,7 +315,7 @@ def sample_batch():

if steps:
steps = int(steps)

try:
while True:
if steps > 0 and counter == (counter_base + steps):
Expand Down Expand Up @@ -362,7 +362,8 @@ def load_gpt2(sess,
checkpoint_dir="checkpoint",
model_name=None,
model_dir='models',
multi_gpu=False):
multi_gpu=False,
reuse=False):
"""Loads the model checkpoint or existing model into a TensorFlow session
for repeated predictions.
"""
Expand All @@ -382,7 +383,7 @@ def load_gpt2(sess,
if multi_gpu:
gpus = get_available_gpus()

output = model.model(hparams=hparams, X=context, gpus=gpus)
output = model.model(hparams=hparams, X=context, gpus=gpus, reuse=reuse)

ckpt = tf.train.latest_checkpoint(checkpoint_path)
saver = tf.compat.v1.train.Saver(allow_empty=True)
Expand Down Expand Up @@ -661,7 +662,7 @@ def cmd():
)

# Explicit arguments

parser.add_argument(
'--mode', help='Mode for using the CLI (either "finetune" or "generate") [Required]', nargs='?')
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions gpt_2_simple/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def __init__(self, n_vocab, n_ctx, n_embd, n_head, n_layer):
self.n_embd = n_embd
self.n_head = n_head
self.n_layer = n_layer

def override_from_dict(self, param_dict):
try:
self.n_vocab = param_dict['n_vocab']
except:
except:
pass
try:
self.n_ctx = param_dict['n_ctx']
Expand Down