From eb0d08751ece2fd31f7d4968d736c66f00de4043 Mon Sep 17 00:00:00 2001 From: Bethany Lusch Date: Sat, 12 Oct 2019 17:16:11 -0500 Subject: [PATCH] avoid checking shape of TF objects: problem in tf1.14, fixes #5 --- networkarch.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/networkarch.py b/networkarch.py index 1a95db5..9fcb798 100644 --- a/networkarch.py +++ b/networkarch.py @@ -294,7 +294,6 @@ def varying_multiply(y, omegas, delta_t, num_real, num_complex_pairs): Side effects: None """ - k = y.shape[1] complex_list = [] # first, Jordan blocks for each pair of complex conjugate eigenvalues @@ -406,12 +405,15 @@ def omega_net_apply(phase, keep_prob, params, ycoords, weights, biases): for j in np.arange(params['num_complex_pairs']): temp_name = 'OC%d_' % (j + 1) ind = 2 * j + pair_of_columns = ycoords[:, ind:ind + 2] + radius_of_pair = tf.reduce_sum(tf.square(pair_of_columns), axis=1, keep_dims=True) omegas.append( - omega_net_apply_one(phase, keep_prob, params, ycoords[:, ind:ind + 2], weights, biases, temp_name)) + omega_net_apply_one(phase, keep_prob, params, radius_of_pair, weights, biases, temp_name)) for j in np.arange(params['num_real']): temp_name = 'OR%d_' % (j + 1) ind = 2 * params['num_complex_pairs'] + j - omegas.append(omega_net_apply_one(phase, keep_prob, params, ycoords[:, ind], weights, biases, temp_name)) + one_column = ycoords[:, ind] + omegas.append(omega_net_apply_one(phase, keep_prob, params, one_column[:, np.newaxis], weights, biases, temp_name)) return omegas @@ -434,17 +436,8 @@ def omega_net_apply_one(phase, keep_prob, params, ycoords, weights, biases, name Side effects: None """ - if len(ycoords.shape) == 1: - ycoords = ycoords[:, np.newaxis] - if ycoords.shape[1] == 2: - # complex conjugate pair - input = tf.reduce_sum(tf.square(ycoords), axis=1, keep_dims=True) - - else: - input = ycoords - - omegas = encoder_apply_one_shift(input, weights, biases, params['act_type'], params['batch_flag'], phase, + omegas = encoder_apply_one_shift(ycoords, weights, biases, params['act_type'], params['batch_flag'], phase, keep_prob=keep_prob, name=name, num_encoder_weights=params['num_omega_weights']) return omegas