From 08fbd8276241f4795d969400f37c6065f894f959 Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Tue, 24 Sep 2024 15:03:35 +0300
Subject: [PATCH 1/8] Simplify code

---
 hmsc/test/test_rl_init.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/hmsc/test/test_rl_init.py b/hmsc/test/test_rl_init.py
index 36bd7fc..f40079d 100644
--- a/hmsc/test/test_rl_init.py
+++ b/hmsc/test/test_rl_init.py
@@ -59,9 +59,9 @@ def test_calculate_GPP():
     tf.keras.utils.set_random_seed(SEED)
     tf.config.experimental.enable_op_determinism()
 
-    d12, d22, alpha = input_values(rng)
+    inputs = input_values(rng)
 
-    values = calculate_GPP(d12, d22, alpha)
+    values = calculate_GPP(*inputs)
     values = list(map(lambda a: a.numpy(), values))
     names = ['idD', 'iDW12', 'F', 'iF', 'detD']
     assert len(names) == len(values)

From 79fe282cdbb19729649e3def7fd01137ad745c8c Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Wed, 25 Sep 2024 10:56:59 +0300
Subject: [PATCH 2/8] Include zero alpha value in test

---
 hmsc/test/test_rl_init.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/hmsc/test/test_rl_init.py b/hmsc/test/test_rl_init.py
index f40079d..56ed018 100644
--- a/hmsc/test/test_rl_init.py
+++ b/hmsc/test/test_rl_init.py
@@ -15,6 +15,7 @@ def input_values(rng):
     n2 = 3
 
     alpha = rng.random(na)
+    alpha[-1] = 0.0
    d12 = rng.random(n1 * n2).reshape(n1, n2)
     d22 = rng.random(n2 * n2).reshape(n2, n2)
     d22 = 0.5 * (d22 + d22.T)
@@ -25,32 +26,35 @@ def reference_values():
     idD = \
 [[ 4.74732221,  1.21841991, 16.81039263,  1.27917015],
- [ 2.89727739,  1.04729426,  3.83387891,  1.06836454]]
+ [ 1.        ,  1.        ,  1.        ,  1.        ]]
     iDW12 = \
 [[[ 1.56552007,  1.92810424,  4.20341649],
   [ 0.3454143 ,  0.45571594,  0.44127379],
   [14.2458611 ,  9.3939915 , 10.41141587],
   [ 0.38626669,  0.55671448,  0.44182194]],
- [[ 0.40960001,  0.59143341,  2.33774095],
-  [ 0.11340611,  0.18487531,  0.17466852],
-  [ 2.86327373,  1.37390469,  1.64707451],
-  [ 0.12931064,  0.24636518,  0.16388959]]]
+
+ [[ 0.        ,  0.        ,  0.        ],
+  [ 0.        ,  0.        ,  0.        ],
+  [ 0.        ,  0.        ,  0.        ],
+  [ 0.        ,  0.        ,  0.        ]]]
     F = \
 [[[13.80338756,  9.72259845, 10.89601792],
   [ 9.72259845,  7.44538439,  8.4114639 ],
   [10.89601792,  8.4114639 , 11.48249436]],
- [[ 3.22423038,  1.87731239,  1.82347872],
-  [ 1.87731239,  1.70253007,  1.46121186],
-  [ 1.82347872,  1.46121186,  3.64813777]]]
+
+ [[ 1.        ,  0.        ,  0.        ],
+  [ 0.        ,  1.        ,  0.        ],
+  [ 0.        ,  0.        ,  1.        ]]]
     iF = \
 [[[ 0.90649386, -1.22934882,  0.04036143],
   [-1.22934882,  2.44625374, -0.62543626],
   [ 0.04036143, -0.62543626,  0.50695045]],
- [[ 0.88076752, -0.90416607, -0.07808988],
-  [-0.90416607,  1.82323176, -0.27833386],
-  [-0.07808988, -0.27833386,  0.4246276 ]]]
+
+ [[ 1.        ,  0.        ,  0.        ],
+  [ 0.        ,  1.        ,  0.        ],
+  [ 0.        ,  0.        ,  1.        ]]]
     detD = \
-[-0.54608872, -0.15196945]
+[-0.54608872,  0.        ]
 
     return idD, iDW12, F, iF, detD
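Note: the reference values added above for the zero-alpha slot follow from the limit of the exponential kernel: as alpha -> 0+, exp(-d/alpha) tends to 1 where d = 0 and to 0 where d > 0, so the knot matrix W22 (zero diagonal) becomes the identity and the cross matrix W12 (strictly positive distances) becomes all zeros. That gives idD = 1, iDW12 = 0, F = iF = I and detD = log det I = 0, which is what the expected arrays encode. A self-contained numerical check of the limit (illustrative only, not part of the patch series):

    import numpy as np

    d = np.array([[0.0, 0.7],
                  [0.7, 0.0]])
    for alpha in [1e-1, 1e-3, 1e-6]:
        # off-diagonal entries decay towards 0, the zero-distance diagonal stays at 1
        print(alpha, np.exp(-d / alpha))
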
From e0c73b85fb241885f63d22283367fe099d6e2022 Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Wed, 25 Sep 2024 10:19:22 +0300
Subject: [PATCH 3/8] Move tf conversion utility function

---
 hmsc/test/__init__.py      | 19 +++++++++++++++++++
 hmsc/test/test_update_z.py | 17 +----------------
 2 files changed, 20 insertions(+), 16 deletions(-)
 create mode 100644 hmsc/test/__init__.py

diff --git a/hmsc/test/__init__.py b/hmsc/test/__init__.py
new file mode 100644
index 0000000..efa2bc8
--- /dev/null
+++ b/hmsc/test/__init__.py
@@ -0,0 +1,19 @@
+import numpy as np
+import tensorflow as tf
+
+
+def convert_to_tf(data):
+    if isinstance(data, np.ndarray):
+        new = tf.convert_to_tensor(data, dtype=data.dtype)
+    elif isinstance(data, list):
+        new = []
+        for value in data:
+            new.append(convert_to_tf(value))
+    elif isinstance(data, dict):
+        new = {}
+        for key, value in data.items():
+            new[key] = convert_to_tf(value)
+    else:
+        new = data
+
+    return new
diff --git a/hmsc/test/test_update_z.py b/hmsc/test/test_update_z.py
index cfb44b2..8740718 100644
--- a/hmsc/test/test_update_z.py
+++ b/hmsc/test/test_update_z.py
@@ -6,26 +6,11 @@ from pytest import approx
 
 from hmsc.updaters.updateZ import updateZ
+from hmsc.test import convert_to_tf
 
 SEED = 42
 
-def convert_to_tf(data):
-    if isinstance(data, np.ndarray):
-        new = tf.convert_to_tensor(data, dtype=data.dtype)
-    elif isinstance(data, list):
-        new = []
-        for value in data:
-            new.append(convert_to_tf(value))
-    elif isinstance(data, dict):
-        new = {}
-        for key, value in data.items():
-            new[key] = convert_to_tf(value)
-    else:
-        new = data
-
-    return new
-
 def run_test(input_values, ref_values, *, tnlib='tf',

From e9d71a5de1a3a4c77cfa015b0b1488a5fedf6dcd Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Wed, 25 Sep 2024 10:42:46 +0300
Subject: [PATCH 4/8] Reduce memory consumption by looping over batches

---
 hmsc/utils/import_utils.py | 82 +++++++++++++++++++++++---------------
 1 file changed, 49 insertions(+), 33 deletions(-)

diff --git a/hmsc/utils/import_utils.py b/hmsc/utils/import_utils.py
index 1fa8126..f94e1c8 100644
--- a/hmsc/utils/import_utils.py
+++ b/hmsc/utils/import_utils.py
@@ -108,39 +108,54 @@ def eye_like(W):
 
 def calculate_W(dist, alpha):
     assert dist.ndim == 2
-    assert alpha.ndim == 1
-    W = dist / alpha[:, None, None]
-    W[np.isnan(W)] = 0
-    W = tf.exp(-W)
-    return W
+    assert tf.size(alpha) == 1
+    assert alpha != 0.0
+    return tf.exp(-dist / alpha)
 
 
 def calculate_GPP(d12, d22, alpha):
-    W12 = calculate_W(d12, alpha)
-    W22 = calculate_W(d22, alpha)
-
-    LW22 = tfla.cholesky(W22)
-    detD = -2*tf.reduce_sum(tfm.log(tfla.diag_part(LW22)), -1)
-    iW22 = tfla.cholesky_solve(LW22, eye_like(LW22))
-    del LW22
-    W12iW22 = tf.matmul(W12, iW22)
-    del iW22
-
-    dD = 1 - tf.einsum("gih,gih->gi", W12iW22, W12)
-    del W12iW22
-    detD += tf.reduce_sum(tfm.log(dD), -1)
-    idD = dD**-1
-    del dD
-
-    iDW12 = tf.einsum("gi,gik->gik", idD, W12)
-    F = W22 + tf.einsum("gik,gih->gkh", iDW12, W12)
-    del W12
-    del W22
-
-    LF = tfla.cholesky(F)
-    detD += 2*tf.reduce_sum(tfm.log(tfla.diag_part(LF)), -1)
-    iF = tfla.cholesky_solve(LF, eye_like(LF))
-    del LF
+    assert alpha.ndim == 1
+    idD_i = []
+    iDW12_i = []
+    F_i = []
+    iF_i = []
+    detD_i = []
+    for a in alpha:
+        W12 = calculate_W(d12, a)
+        W22 = calculate_W(d22, a)
+
+        LW22 = tfla.cholesky(W22)
+        detD = -2*tf.reduce_sum(tfm.log(tfla.diag_part(LW22)), -1)
+        iW22 = tfla.cholesky_solve(LW22, eye_like(LW22))
+        del LW22
+        W12iW22 = tf.matmul(W12, iW22)
+        del iW22
+
+        dD = 1 - tf.einsum("ih,ih->i", W12iW22, W12)
+        del W12iW22
+        detD += tf.reduce_sum(tfm.log(dD), -1)
+        idD = dD**-1
+        del dD
+        idD_i.append(idD)
+
+        iDW12 = tf.einsum("i,ik->ik", idD, W12)
+        iDW12_i.append(iDW12)
+        F = W22 + tf.einsum("ik,ih->kh", iDW12, W12)
+        del W12
+        del W22
+        F_i.append(F)
+
+        LF = tfla.cholesky(F)
+        detD += 2*tf.reduce_sum(tfm.log(tfla.diag_part(LF)), -1)
+        detD_i.append(detD)
+        iF = tfla.cholesky_solve(LF, eye_like(LF))
+        iF_i.append(iF)
+        del LF
+    idD = tf.stack(idD_i)
+    iDW12 = tf.stack(iDW12_i)
+    F = tf.stack(F_i)
+    iF = tf.stack(iF_i)
+    detD = tf.stack(detD_i)
 
     return idD, iDW12, F, iF, detD
 
 
@@ -178,9 +193,10 @@ def load_random_level_hyperparams(hmscModel, dataParList, dtype=np.float64):
 
         elif rLPar["spatialMethod"] == "GPP":
             nK = int(dataParList["rLPar"][r]["nKnots"][0])
-            alpha = rLPar["alphapw"][:, 0].astype(dtype)
-            d12 = dataParList["rLPar"][r]["distMat12"].astype(dtype)
-            d22 = dataParList["rLPar"][r]["distMat22"].astype(dtype)
+            alpha = tf.convert_to_tensor(rLPar["alphapw"][:, 0], dtype=dtype)
+            d12 = tf.convert_to_tensor(dataParList["rLPar"][r]["distMat12"], dtype=dtype)
+            d22 = tf.convert_to_tensor(dataParList["rLPar"][r]["distMat22"], dtype=dtype)
+
             assert d12.shape == (npVec[r], nK)
             assert d22.shape == (nK, nK)
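Note: the calculate_GPP rewrite above replaces one computation batched over every alpha value with a Python loop that handles a single alpha at a time, so only one (n1, n2) kernel and its intermediates are alive at once instead of an (na, n1, n2) stack; the per-alpha results are still collected and stacked at the end, a copy that a later patch in this series removes. A minimal standalone sketch of the loop-versus-broadcast trade-off, with made-up shapes and function names rather than the project's code:

    import tensorflow as tf

    def kernels_batched(dist, alphas):
        # Broadcasting materialises all len(alphas) kernels at once: shape (na, n1, n2).
        return tf.exp(-dist / alphas[:, None, None])

    def kernels_looped(dist, alphas):
        # One (n1, n2) kernel is alive per iteration; tf.stack assembles the result.
        out = []
        for a in alphas:
            out.append(tf.exp(-dist / a))
        return tf.stack(out)

    dist = tf.random.uniform((1000, 100), dtype=tf.float64)
    alphas = tf.constant([0.5, 1.0, 2.0], dtype=tf.float64)
    assert kernels_batched(dist, alphas).shape == kernels_looped(dist, alphas).shape
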
From d23c8d7ba7c4d17134de6ebe2f094a5ac208f30b Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Wed, 25 Sep 2024 10:43:46 +0300
Subject: [PATCH 5/8] Use tf inputs

---
 hmsc/test/test_rl_init.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/hmsc/test/test_rl_init.py b/hmsc/test/test_rl_init.py
index 56ed018..b046cc2 100644
--- a/hmsc/test/test_rl_init.py
+++ b/hmsc/test/test_rl_init.py
@@ -4,6 +4,7 @@ from pytest import approx
 
 from hmsc.utils.import_utils import calculate_GPP
+from hmsc.test import convert_to_tf
 
 SEED = 42
 
@@ -64,6 +65,7 @@ def test_calculate_GPP():
     tf.config.experimental.enable_op_determinism()
 
     inputs = input_values(rng)
+    inputs = map(convert_to_tf, inputs)
 
     values = calculate_GPP(*inputs)
     values = list(map(lambda a: a.numpy(), values))

From c1951640816423731176eec443df96a79048e756 Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Wed, 25 Sep 2024 14:21:54 +0300
Subject: [PATCH 6/8] Fix zero alpha case

---
 hmsc/utils/import_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/hmsc/utils/import_utils.py b/hmsc/utils/import_utils.py
index f94e1c8..512847b 100644
--- a/hmsc/utils/import_utils.py
+++ b/hmsc/utils/import_utils.py
@@ -109,7 +109,10 @@ def eye_like(W):
 def calculate_W(dist, alpha):
     assert dist.ndim == 2
     assert tf.size(alpha) == 1
-    assert alpha != 0.0
+    if alpha == 0.0:
+        one = tf.constant(1, dtype=dist.dtype)
+        zero = tf.constant(0, dtype=dist.dtype)
+        return tf.where(tfm.logical_or(dist == 0, tfm.is_nan(dist)), one, zero)
     return tf.exp(-dist / alpha)
 
 

From a21bd2360c22efb0874bc1feb8ff58b00bb8a866 Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Thu, 26 Sep 2024 09:17:46 +0300
Subject: [PATCH 7/8] Reduce memory consumption from stacking tensors

---
 hmsc/utils/import_utils.py | 52 ++++++++++++++++++++++++--------------
 1 file changed, 33 insertions(+), 19 deletions(-)

diff --git a/hmsc/utils/import_utils.py b/hmsc/utils/import_utils.py
index 512847b..f4153ba 100644
--- a/hmsc/utils/import_utils.py
+++ b/hmsc/utils/import_utils.py
@@ -116,21 +116,29 @@ def calculate_W(dist, alpha):
     return tf.exp(-dist / alpha)
 
 
+def set_slice(variable, i, tensor):
+    variable.scatter_update(tf.IndexedSlices(tensor[tf.newaxis], tf.constant([i], dtype=tf.int64)))
+
+
 def calculate_GPP(d12, d22, alpha):
+    assert d12.ndim == 2
+    assert d22.ndim == 2
     assert alpha.ndim == 1
+    assert d12.dtype == d22.dtype
+    dtype = d12.dtype
-    idD_i = []
-    iDW12_i = []
-    F_i = []
-    iF_i = []
-    detD_i = []
-    for a in alpha:
-        W12 = calculate_W(d12, a)
+    idD_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], d12.shape[0]], dtype=dtype))
+    iDW12_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d12.shape], dtype=dtype))
+    F_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d22.shape], dtype=dtype))
+    iF_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d22.shape], dtype=dtype))
+    detD_g = tf.Variable(tf.zeros(shape=[alpha.shape[0]], dtype=dtype))
+    for i, a in enumerate(alpha):
         W22 = calculate_W(d22, a)
-
         LW22 = tfla.cholesky(W22)
         detD = -2*tf.reduce_sum(tfm.log(tfla.diag_part(LW22)), -1)
         iW22 = tfla.cholesky_solve(LW22, eye_like(LW22))
         del LW22
+
+        W12 = calculate_W(d12, a)
         W12iW22 = tf.matmul(W12, iW22)
         del iW22
 
@@ -139,27 +147,33 @@ def calculate_GPP(d12, d22, alpha):
         detD += tf.reduce_sum(tfm.log(dD), -1)
         idD = dD**-1
         del dD
-        idD_i.append(idD)
+        set_slice(idD_g, i, idD)
 
         iDW12 = tf.einsum("i,ik->ik", idD, W12)
-        iDW12_i.append(iDW12)
+        set_slice(iDW12_g, i, iDW12)
+
         F = W22 + tf.einsum("ik,ih->kh", iDW12, W12)
         del W12
         del W22
-        F_i.append(F)
+        set_slice(F_g, i, F)
 
         LF = tfla.cholesky(F)
         detD += 2*tf.reduce_sum(tfm.log(tfla.diag_part(LF)), -1)
-        detD_i.append(detD)
+        set_slice(detD_g, i, detD)
+        del detD
+
         iF = tfla.cholesky_solve(LF, eye_like(LF))
-        iF_i.append(iF)
         del LF
-    idD = tf.stack(idD_i)
-    iDW12 = tf.stack(iDW12_i)
-    F = tf.stack(F_i)
-    iF = tf.stack(iF_i)
-    detD = tf.stack(detD_i)
-    return idD, iDW12, F, iF, detD
+        set_slice(iF_g, i, iF)
+        del iF
+
+    iDW12_g = iDW12_g.read_value_no_copy()
+    idD_g = idD_g.read_value_no_copy()
+    F_g = F_g.read_value_no_copy()
+    iF_g = iF_g.read_value_no_copy()
+    detD_g = detD_g.read_value_no_copy()
+
+    return idD_g, iDW12_g, F_g, iF_g, detD_g
 
 
 def load_random_level_hyperparams(hmscModel, dataParList, dtype=np.float64):
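Note: tf.stack copies every per-alpha slice into a freshly allocated output, so the list of slices and the stacked result briefly coexist. The set_slice helper above instead preallocates each output as a tf.Variable and overwrites one slice per iteration via scatter_update, and read_value_no_copy hands the buffer back as a tensor without another copy. A rough standalone sketch of the same write-in-place pattern (toy shapes and a hypothetical fill_rows helper; it uses the public read_value for simplicity):

    import tensorflow as tf

    def fill_rows(rows, n):
        # Preallocate the output once and write each row in place,
        # instead of collecting rows in a list and stacking them.
        out = tf.Variable(tf.zeros([len(rows), n], dtype=tf.float64))
        for i, row in enumerate(rows):
            out.scatter_update(
                tf.IndexedSlices(row[tf.newaxis], tf.constant([i], dtype=tf.int64)))
        return out.read_value()

    rows = [tf.fill([4], tf.constant(float(i), tf.float64)) for i in range(3)]
    print(fill_rows(rows, 4))
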
From 1157226e9036561cc913b80a83882cabc6d05c64 Mon Sep 17 00:00:00 2001
From: Tuomas Rossi
Date: Thu, 26 Sep 2024 09:19:18 +0300
Subject: [PATCH 8/8] Clarify variable name

---
 hmsc/utils/import_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/hmsc/utils/import_utils.py b/hmsc/utils/import_utils.py
index f4153ba..e90a9b7 100644
--- a/hmsc/utils/import_utils.py
+++ b/hmsc/utils/import_utils.py
@@ -102,8 +102,8 @@ def load_model_hyperparams(hmscModel, dataParList, dtype=np.float64):
     return dataParams
 
 
-def eye_like(W):
-    return tf.eye(*W.shape[-2:], batch_shape=W.shape[:-2], dtype=W.dtype)
+def eye_like(tensor):
+    return tf.eye(*tensor.shape[-2:], batch_shape=tensor.shape[:-2], dtype=tensor.dtype)
 
 
 def calculate_W(dist, alpha):