upgrade to TensorFlow 2 #262

Open · wants to merge 1 commit into master
14 changes: 7 additions & 7 deletions src/generate_unconditional_samples.py
@@ -44,16 +44,16 @@ def sample_model(
     enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
     with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
-        hparams.override_from_dict(json.load(f))
+        hparams.update(json.load(f))
 
     if length is None:
-        length = hparams.n_ctx
-    elif length > hparams.n_ctx:
-        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+        length = hparams['n_ctx']
+    elif length > hparams['n_ctx']:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams['n_ctx'])
 
-    with tf.Session(graph=tf.Graph()) as sess:
+    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
         np.random.seed(seed)
-        tf.set_random_seed(seed)
+        tf.compat.v1.set_random_seed(seed)
 
         output = sample.sample_sequence(
             hparams=hparams, length=length,
@@ -62,7 +62,7 @@ def sample_model(
             temperature=temperature, top_k=top_k, top_p=top_p
         )[:, 1:]
 
-        saver = tf.train.Saver()
+        saver = tf.compat.v1.train.Saver()
         ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
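Note on the hunks above: the script keeps its original graph-and-session flow; the diff only reroutes the removed TF1 entry points (Session, set_random_seed, train.Saver) through tf.compat.v1. A minimal standalone sketch of that pattern under TF 2.x, illustrative only and not part of this PR:

import numpy as np
import tensorflow as tf

# Entering a compat.v1 Session also installs its graph as the default graph,
# so graph-mode ops can still be built inside the block, as sample_model() does.
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)                    # same graph-level seeding call as the diff
    x = tf.compat.v1.placeholder(tf.float32, [None])   # graph-mode input
    y = x * 2.0
    print(sess.run(y, feed_dict={x: [1.0, 2.0]}))      # [2. 4.]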
16 changes: 8 additions & 8 deletions src/interactive_conditional_samples.py
@@ -47,25 +47,25 @@ def interact_model(
     enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
     with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
-        hparams.override_from_dict(json.load(f))
+        hparams.update(json.load(f))
 
     if length is None:
-        length = hparams.n_ctx // 2
-    elif length > hparams.n_ctx:
-        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+        length = hparams['n_ctx'] // 2
+    elif length > hparams['n_ctx']:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams['n_ctx'])
 
-    with tf.Session(graph=tf.Graph()) as sess:
-        context = tf.placeholder(tf.int32, [batch_size, None])
+    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
         np.random.seed(seed)
-        tf.set_random_seed(seed)
+        tf.compat.v1.set_random_seed(seed)
         output = sample.sample_sequence(
             hparams=hparams, length=length,
             context=context,
             batch_size=batch_size,
             temperature=temperature, top_k=top_k, top_p=top_p
         )
 
-        saver = tf.train.Saver()
+        saver = tf.compat.v1.train.Saver()
         ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
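The override_from_dict to update change in both sampling scripts pairs with the new default_hparams() in src/model.py below, which returns a plain dict instead of the removed tf.contrib.training.HParams, so reads switch from attribute access to subscripting. A small illustrative sketch, with the JSON string standing in for a real models/<model_name>/hparams.json:

import json

hparams = {'n_vocab': 0, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}
hparams.update(json.loads('{"n_vocab": 50257, "n_ctx": 1024}'))   # stands in for json.load(f)
length = hparams['n_ctx'] // 2        # replaces hparams.n_ctx // 2
print(hparams['n_vocab'], length)     # 50257 512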
87 changes: 43 additions & 44 deletions src/model.py
@@ -1,39 +1,38 @@
 import numpy as np
 import tensorflow as tf
-from tensorflow.contrib.training import HParams
 
 def default_hparams():
-    return HParams(
-        n_vocab=0,
-        n_ctx=1024,
-        n_embd=768,
-        n_head=12,
-        n_layer=12,
-    )
+    return {
+        'n_vocab' : 0,
+        'n_ctx' : 1024,
+        'n_embd' : 768,
+        'n_head' : 12,
+        'n_layer' : 12,
+    }
 
 def shape_list(x):
     """Deal with dynamic shape in tensorflow cleanly."""
     static = x.shape.as_list()
-    dynamic = tf.shape(x)
+    dynamic = tf.shape(input=x)
     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
 
 def softmax(x, axis=-1):
-    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
+    x = x - tf.reduce_max(input_tensor=x, axis=axis, keepdims=True)
     ex = tf.exp(x)
-    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
+    return ex / tf.reduce_sum(input_tensor=ex, axis=axis, keepdims=True)
 
 def gelu(x):
     return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
 
 def norm(x, scope, *, axis=-1, epsilon=1e-5):
     """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
-    with tf.variable_scope(scope):
-        n_state = x.shape[-1].value
-        g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
-        b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
-        u = tf.reduce_mean(x, axis=axis, keepdims=True)
-        s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
-        x = (x - u) * tf.rsqrt(s + epsilon)
+    with tf.compat.v1.variable_scope(scope):
+        n_state = x.shape[-1]
+        g = tf.compat.v1.get_variable('g', [n_state], initializer=tf.compat.v1.constant_initializer(1))
+        b = tf.compat.v1.get_variable('b', [n_state], initializer=tf.compat.v1.constant_initializer(0))
+        u = tf.reduce_mean(input_tensor=x, axis=axis, keepdims=True)
+        s = tf.reduce_mean(input_tensor=tf.square(x-u), axis=axis, keepdims=True)
+        x = (x - u) * tf.math.rsqrt(s + epsilon)
         x = x*g + b
         return x
 
@@ -48,10 +47,10 @@ def merge_states(x):
     return tf.reshape(x, start + [a*b])
 
 def conv1d(x, scope, nf, *, w_init_stdev=0.02):
-    with tf.variable_scope(scope):
+    with tf.compat.v1.variable_scope(scope):
         *start, nx = shape_list(x)
-        w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
-        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
+        w = tf.compat.v1.get_variable('w', [1, nx, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev))
+        b = tf.compat.v1.get_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0))
         c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
         return c
 
@@ -68,17 +67,17 @@ def attention_mask(nd, ns, *, dtype):
 
 def attn(x, scope, n_state, *, past, hparams):
     assert x.shape.ndims == 3  # Should be [batch, sequence, features]
-    assert n_state % hparams.n_head == 0
+    assert n_state % hparams['n_head'] == 0
     if past is not None:
         assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
 
     def split_heads(x):
         # From [batch, sequence, features] to [batch, heads, sequence, features]
-        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
+        return tf.transpose(a=split_states(x, hparams['n_head']), perm=[0, 2, 1, 3])
 
     def merge_heads(x):
         # Reverse of split_heads
-        return merge_states(tf.transpose(x, [0, 2, 1, 3]))
+        return merge_states(tf.transpose(a=x, perm=[0, 2, 1, 3]))
 
     def mask_attn_weights(w):
         # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
@@ -91,14 +90,14 @@ def mask_attn_weights(w):
     def multihead_attn(q, k, v):
         # q, k, v have shape [batch, heads, sequence, features]
         w = tf.matmul(q, k, transpose_b=True)
-        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
+        w = w * tf.math.rsqrt(tf.cast(v.shape[-1], w.dtype))
 
         w = mask_attn_weights(w)
         w = softmax(w)
         a = tf.matmul(w, v)
         return a
 
-    with tf.variable_scope(scope):
+    with tf.compat.v1.variable_scope(scope):
         c = conv1d(x, 'c_attn', n_state*3)
         q, k, v = map(split_heads, tf.split(c, 3, axis=2))
         present = tf.stack([k, v], axis=1)
@@ -113,62 +112,62 @@ def multihead_attn(q, k, v):
 
 
 def mlp(x, scope, n_state, *, hparams):
-    with tf.variable_scope(scope):
-        nx = x.shape[-1].value
+    with tf.compat.v1.variable_scope(scope):
+        nx = x.shape[-1]
         h = gelu(conv1d(x, 'c_fc', n_state))
         h2 = conv1d(h, 'c_proj', nx)
         return h2
 
 
 def block(x, scope, *, past, hparams):
-    with tf.variable_scope(scope):
-        nx = x.shape[-1].value
+    with tf.compat.v1.variable_scope(scope):
+        nx = x.shape[-1]
         a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
         x = x + a
         m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
         x = x + m
         return x, present
 
 def past_shape(*, hparams, batch_size=None, sequence=None):
-    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
+    return [batch_size, hparams['n_layer'], 2, hparams['n_head'], sequence, hparams['n_embd'] // hparams['n_head']]
 
 def expand_tile(value, size):
     """Add a new axis of given size."""
-    value = tf.convert_to_tensor(value, name='value')
+    value = tf.convert_to_tensor(value=value, name='value')
     ndims = value.shape.ndims
     return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)
 
 def positions_for(tokens, past_length):
-    batch_size = tf.shape(tokens)[0]
-    nsteps = tf.shape(tokens)[1]
+    batch_size = tf.shape(input=tokens)[0]
+    nsteps = tf.shape(input=tokens)[1]
     return expand_tile(past_length + tf.range(nsteps), batch_size)
 
 
 def model(hparams, X, past=None, scope='model', reuse=False):
-    with tf.variable_scope(scope, reuse=reuse):
+    with tf.compat.v1.variable_scope(scope, reuse=reuse):
         results = {}
         batch, sequence = shape_list(X)
 
-        wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
-                             initializer=tf.random_normal_initializer(stddev=0.01))
-        wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
-                             initializer=tf.random_normal_initializer(stddev=0.02))
-        past_length = 0 if past is None else tf.shape(past)[-2]
+        wpe = tf.compat.v1.get_variable('wpe', [hparams['n_ctx'], hparams['n_embd']],
+                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.01))
+        wte = tf.compat.v1.get_variable('wte', [hparams['n_vocab'], hparams['n_embd']],
+                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.02))
+        past_length = 0 if past is None else tf.shape(input=past)[-2]
         h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
 
         # Transformer
         presents = []
-        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
-        assert len(pasts) == hparams.n_layer
+        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams['n_layer']
+        assert len(pasts) == hparams['n_layer']
         for layer, past in enumerate(pasts):
             h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
             presents.append(present)
         results['present'] = tf.stack(presents, axis=1)
         h = norm(h, 'ln_f')
 
         # Language model loss.  Do tokens <n predict token n?
-        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
+        h_flat = tf.reshape(h, [batch*sequence, hparams['n_embd']])
         logits = tf.matmul(h_flat, wte, transpose_b=True)
-        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
+        logits = tf.reshape(logits, [batch, sequence, hparams['n_vocab']])
         results['logits'] = logits
         return results
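Two recurring changes in this file: tf.compat.v1.variable_scope and tf.compat.v1.get_variable preserve the TF1 name-based variable model that the checkpoint restore in the sampling scripts depends on, and x.shape[-1].value becomes x.shape[-1] because TensorShape entries are plain ints (or None) in TF 2.x. A quick illustrative check of the latter, not part of the PR:

import tensorflow as tf

x = tf.zeros([2, 10, 768])
n_state = x.shape[-1]    # a plain int under TF 2.x; TF1 required x.shape[-1].value
print(n_state)           # 768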
22 changes: 11 additions & 11 deletions src/sample.py
@@ -10,15 +10,15 @@ def top_k_logits(logits, k):
     def _top_k():
         values, _ = tf.nn.top_k(logits, k=k)
         min_values = values[:, -1, tf.newaxis]
-        return tf.where(
+        return tf.compat.v1.where(
             logits < min_values,
             tf.ones_like(logits, dtype=logits.dtype) * -1e10,
             logits,
         )
     return tf.cond(
-       tf.equal(k, 0),
-       lambda: logits,
-       lambda: _top_k(),
+       pred=tf.equal(k, 0),
+       true_fn=lambda: logits,
+       false_fn=lambda: _top_k(),
     )
 
@@ -30,10 +30,10 @@ def top_p_logits(logits, p):
     indices = tf.stack([
         tf.range(0, batch),
         # number of indices to include
-        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
+        tf.maximum(tf.reduce_sum(input_tensor=tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
     ], axis=-1)
     min_values = tf.gather_nd(sorted_logits, indices)
-    return tf.where(
+    return tf.compat.v1.where(
         logits < min_values,
         tf.ones_like(logits) * -1e10,
         logits,
@@ -48,23 +48,23 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
         context = tf.fill([batch_size, 1], start_token)
 
     def step(hparams, tokens, past=None):
-        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
+        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)
 
-        logits = lm_output['logits'][:, :, :hparams.n_vocab]
+        logits = lm_output['logits'][:, :, :hparams['n_vocab']]
         presents = lm_output['present']
         presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
         return {
             'logits': logits,
             'presents': presents,
         }
 
-    with tf.name_scope('sample_sequence'):
+    with tf.compat.v1.name_scope('sample_sequence'):
         def body(past, prev, output):
             next_outputs = step(hparams, prev, past=past)
-            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
+            logits = next_outputs['logits'][:, -1, :] / tf.cast(temperature, dtype=tf.float32)
             logits = top_k_logits(logits, k=top_k)
             logits = top_p_logits(logits, p=top_p)
-            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
+            samples = tf.random.categorical(logits=logits, num_samples=1, dtype=tf.int32)
             return [
                 next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                 samples,
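The sampling-loop changes swap removed TF1 ops for current equivalents: tf.to_float(t) becomes tf.cast(t, dtype=tf.float32), tf.multinomial becomes tf.random.categorical (with output_dtype renamed to dtype), and the three-argument tf.where calls are pinned to tf.compat.v1.where, which keeps the TF1 selection semantics. A standalone sketch of the sampling replacements with illustrative values, not taken from the PR:

import tensorflow as tf

logits = tf.constant([[0.1, 1.5, 3.2]])                      # [batch=1, vocab=3]
temperature = 0.7
scaled = logits / tf.cast(temperature, dtype=tf.float32)     # replaces tf.to_float(temperature)
samples = tf.random.categorical(logits=scaled, num_samples=1, dtype=tf.int32)  # replaces tf.multinomial
print(samples.shape)                                         # (1, 1): one sampled token id per row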