From d9b673eb394c79648d9d7f060bbbea0e3123a6f9 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 16 Jun 2019 10:18:09 -0700 Subject: [PATCH] Cap gen length if prefix to prevent OOB (#38) --- gpt_2_simple/gpt_2.py | 13 +++++++------ setup.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py index b97370c..b0e44a8 100644 --- a/gpt_2_simple/gpt_2.py +++ b/gpt_2_simple/gpt_2.py @@ -378,9 +378,6 @@ def generate(sess, if prefix == '': prefix = None - if prefix: - context = tf.placeholder(tf.int32, [batch_size, None]) - CHECKPOINT_DIR = 'checkpoint' SAMPLE_DIR = 'samples' @@ -391,11 +388,17 @@ def generate(sess, with open(os.path.join(checkpoint_path, 'hparams.json')) as f: hparams.override_from_dict(json.load(f)) + if prefix: + context = tf.placeholder(tf.int32, [batch_size, None]) + context_tokens = enc.encode(prefix) + assert len(context_tokens) < length + np.random.seed(seed) tf.set_random_seed(seed) output = sample.sample_sequence( - hparams=hparams, length=length, + hparams=hparams, + length=min(length, 1023 - (len(context_tokens) if prefix else 0)), start_token=enc.encoder['<|endoftext|>'] if not prefix else None, context=context if prefix else None, batch_size=batch_size, @@ -404,8 +407,6 @@ def generate(sess, if destination_path: f = open(destination_path, 'w') - if prefix: - context_tokens = enc.encode(prefix) generated = 0 gen_texts = [] while generated < nsamples: diff --git a/setup.py b/setup.py index d5dc8df..1357311 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( name='gpt_2_simple', packages=['gpt_2_simple'], # this must be the same as the name above - version='0.5.1', + version='0.5.2', description="Python package to easily retrain OpenAI's GPT-2 " \ "text-generating model on new texts.", long_description=long_description,