From 715021750a52b3ee5f756eb508df894c00c02483 Mon Sep 17 00:00:00 2001 From: jsadler2 Date: Mon, 17 Aug 2020 16:20:48 -0500 Subject: [PATCH] [#31] just updating cell state, not hidden state --- river_dl/RGCN.py | 39 ++++++++------------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/river_dl/RGCN.py b/river_dl/RGCN.py index 774c5a7..cfcbadc 100644 --- a/river_dl/RGCN.py +++ b/river_dl/RGCN.py @@ -35,13 +35,6 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None): w_initializer = tf.random_normal_initializer(stddev=0.02, seed=rand_seed) - # was Wg1 - self.W_graph_h = self.add_weight(shape=[hidden_size, hidden_size], - initializer=w_initializer, - name='W_graph_h') - # was bg1 - self.b_graph_h = self.add_weight(shape=[hidden_size], - initializer='zeros', name='b_graph_h') # was Wg2 self.W_graph_c = self.add_weight(shape=[hidden_size, hidden_size], initializer=w_initializer, @@ -50,17 +43,6 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None): self.b_graph_c = self.add_weight(shape=[hidden_size], initializer='zeros', name='b_graph_c') - # was Wa1 - self.W_h_cur = self.add_weight(shape=[hidden_size, hidden_size], - initializer=w_initializer, - name='W_h_cur') - # was Wa2 - self.W_h_prev = self.add_weight(shape=[hidden_size, hidden_size], - initializer=w_initializer, - name='W_h_prev') - # was ba - self.b_h = self.add_weight(shape=[hidden_size], initializer='zeros', - name='b_h') # was Wc1 self.W_c_cur = self.add_weight(shape=[hidden_size, hidden_size], @@ -101,42 +83,37 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None): @tf.function def call(self, inputs, **kwargs): graph_size = self.A.shape[0] - hidden_state_prev, cell_state_prev = (tf.zeros([graph_size, + hidden_state, cell_state_prev = (tf.zeros([graph_size, self.hidden_size]), tf.zeros([graph_size, self.hidden_size])) out = [] n_steps = inputs.shape[1] for t in range(n_steps): - h_graph = tf.nn.tanh(tf.matmul(self.A, tf.matmul(hidden_state_prev, - self.W_graph_h) - + self.b_graph_h)) c_graph = tf.nn.tanh(tf.matmul(self.A, tf.matmul(cell_state_prev, self.W_graph_c) + self.b_graph_c)) - seq, state = self.lstm(inputs[:, t, :], states=[hidden_state_prev, + seq, state = self.lstm(inputs[:, t, :], states=[hidden_state, cell_state_prev]) - hidden_state_cur, cell_state_cur = state + hidden_state, cell_state_cur = state - h_update = tf.nn.sigmoid(tf.matmul(hidden_state_cur, self.W_h_cur) - + tf.matmul(h_graph, self.W_h_prev) - + self.b_h) c_update = tf.nn.sigmoid(tf.matmul(cell_state_cur, self.W_c_cur) + tf.matmul(c_graph, self.W_c_prev) + self.b_c) if self.flow_in_temp: - out_pred_q = tf.matmul(h_update, self.W_out_flow) + self.b_out_flow - out_pred_t = tf.matmul(tf.concat([h_update, out_pred_q], axis=1), + out_pred_q = tf.matmul(hidden_state, self.W_out_flow) +\ + self.b_out_flow + out_pred_t = tf.matmul(tf.concat([hidden_state, out_pred_q], + axis=1), self.W_out_temp) + self.b_out_temp out_pred = tf.concat([out_pred_t, out_pred_q], axis=1) else: - out_pred = tf.matmul(h_update, self.W_out) + self.b_out + out_pred = tf.matmul(hidden_state, self.W_out) + self.b_out out.append(out_pred) - hidden_state_prev = h_update cell_state_prev = c_update out = tf.stack(out) out = tf.transpose(out, [1, 0, 2])