
Commit

Merge branch 'reduce-memory-consumption' into simplify-io-w-reduce-memory-consumption
trossi committed Sep 26, 2024
2 parents 63df28b + 1157226 commit 855b847
Showing 4 changed files with 109 additions and 66 deletions.
19 changes: 19 additions & 0 deletions hmsc/test/__init__.py
@@ -0,0 +1,19 @@
import numpy as np
import tensorflow as tf


def convert_to_tf(data):
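# Recursively convert numpy arrays, also inside nested lists and dicts, to TensorFlow tensors; any other value passes through unchanged.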
if isinstance(data, np.ndarray):
new = tf.convert_to_tensor(data, dtype=data.dtype)
elif isinstance(data, list):
new = []
for value in data:
new.append(convert_to_tf(value))
elif isinstance(data, dict):
new = {}
for key, value in data.items():
new[key] = convert_to_tf(value)
else:
new = data

return new
34 changes: 20 additions & 14 deletions hmsc/test/test_rl_init.py
@@ -4,6 +4,7 @@
from pytest import approx

from hmsc.utils.import_utils import calculate_GPP
from hmsc.test import convert_to_tf


SEED = 42
@@ -15,6 +16,7 @@ def input_values(rng):
n2 = 3

alpha = rng.random(na)
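# set the last alpha to zero so that the alpha == 0 branch of calculate_W is also covered by the test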
alpha[-1] = 0.0
d12 = rng.random(n1 * n2).reshape(n1, n2)
d22 = rng.random(n2 * n2).reshape(n2, n2)
d22 = 0.5 * (d22 + d22.T)
@@ -25,32 +27,35 @@ def reference_values():
def reference_values():
idD = \
[[ 4.74732221, 1.21841991, 16.81039263, 1.27917015],
[ 2.89727739, 1.04729426, 3.83387891, 1.06836454]]
[ 1. , 1. , 1. , 1. ]]
iDW12 = \
[[[ 1.56552007, 1.92810424, 4.20341649],
[ 0.3454143 , 0.45571594, 0.44127379],
[14.2458611 , 9.3939915 , 10.41141587],
[ 0.38626669, 0.55671448, 0.44182194]],
[[ 0.40960001, 0.59143341, 2.33774095],
[ 0.11340611, 0.18487531, 0.17466852],
[ 2.86327373, 1.37390469, 1.64707451],
[ 0.12931064, 0.24636518, 0.16388959]]]

[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]]
F = \
[[[13.80338756, 9.72259845, 10.89601792],
[ 9.72259845, 7.44538439, 8.4114639 ],
[10.89601792, 8.4114639 , 11.48249436]],
[[ 3.22423038, 1.87731239, 1.82347872],
[ 1.87731239, 1.70253007, 1.46121186],
[ 1.82347872, 1.46121186, 3.64813777]]]

[[ 1. , 0. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 0. , 1. ]]]
iF = \
[[[ 0.90649386, -1.22934882, 0.04036143],
[-1.22934882, 2.44625374, -0.62543626],
[ 0.04036143, -0.62543626, 0.50695045]],
[[ 0.88076752, -0.90416607, -0.07808988],
[-0.90416607, 1.82323176, -0.27833386],
[-0.07808988, -0.27833386, 0.4246276 ]]]

[[ 1. , 0. , 0. ],
[ 0. , 1. , 0. ],
[ 0. , 0. , 1. ]]]
detD = \
[-0.54608872, -0.15196945]
[-0.54608872, 0. ]
return idD, iDW12, F, iF, detD


@@ -59,9 +64,10 @@ def test_calculate_GPP():
tf.keras.utils.set_random_seed(SEED)
tf.config.experimental.enable_op_determinism()

d12, d22, alpha = input_values(rng)
inputs = input_values(rng)
inputs = map(convert_to_tf, inputs)

values = calculate_GPP(d12, d22, alpha)
values = calculate_GPP(*inputs)
values = list(map(lambda a: a.numpy(), values))
names = ['idD', 'iDW12', 'F', 'iF', 'detD']
assert len(names) == len(values)
17 changes: 1 addition & 16 deletions hmsc/test/test_update_z.py
@@ -6,26 +6,11 @@
from pytest import approx

from hmsc.updaters.updateZ import updateZ
from hmsc.test import convert_to_tf


SEED = 42

def convert_to_tf(data):
if isinstance(data, np.ndarray):
new = tf.convert_to_tensor(data, dtype=data.dtype)
elif isinstance(data, list):
new = []
for value in data:
new.append(convert_to_tf(value))
elif isinstance(data, dict):
new = {}
for key, value in data.items():
new[key] = convert_to_tf(value)
else:
new = data

return new


def run_test(input_values, ref_values, *,
tnlib='tf',
105 changes: 69 additions & 36 deletions hmsc/utils/import_utils.py
@@ -102,46 +102,78 @@ def load_model_hyperparams(hmscModel, dataParList, dtype=np.float64):
return dataParams


def eye_like(W):
return tf.eye(*W.shape[-2:], batch_shape=W.shape[:-2], dtype=W.dtype)
def eye_like(tensor):
return tf.eye(*tensor.shape[-2:], batch_shape=tensor.shape[:-2], dtype=tensor.dtype)


def calculate_W(dist, alpha):
assert dist.ndim == 2
assert alpha.ndim == 1
W = dist / alpha[:, None, None]
W[np.isnan(W)] = 0
W = tf.exp(-W)
return W
assert tf.size(alpha) == 1
if alpha == 0.0:
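# with alpha == 0 there is no spatial correlation: W is 1 where dist == 0 (or NaN) and 0 elsewhere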
one = tf.constant(1, dtype=dist.dtype)
zero = tf.constant(0, dtype=dist.dtype)
return tf.where(tfm.logical_or(dist == 0, tfm.is_nan(dist)), one, zero)
return tf.exp(-dist / alpha)


def set_slice(variable, i, tensor):
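# write `tensor` into slice i of `variable` in place, without building a new full-size tensor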
variable.scatter_update(tf.IndexedSlices(tensor[tf.newaxis], tf.constant([i], dtype=tf.int64)))


def calculate_GPP(d12, d22, alpha):
W12 = calculate_W(d12, alpha)
W22 = calculate_W(d22, alpha)

LW22 = tfla.cholesky(W22)
detD = -2*tf.reduce_sum(tfm.log(tfla.diag_part(LW22)), -1)
iW22 = tfla.cholesky_solve(LW22, eye_like(LW22))
del LW22
W12iW22 = tf.matmul(W12, iW22)
del iW22

dD = 1 - tf.einsum("gih,gih->gi", W12iW22, W12)
del W12iW22
detD += tf.reduce_sum(tfm.log(dD), -1)
idD = dD**-1
del dD

iDW12 = tf.einsum("gi,gik->gik", idD, W12)
F = W22 + tf.einsum("gik,gih->gkh", iDW12, W12)
del W12
del W22

LF = tfla.cholesky(F)
detD += 2*tf.reduce_sum(tfm.log(tfla.diag_part(LF)), -1)
iF = tfla.cholesky_solve(LF, eye_like(LF))
del LF
return idD, iDW12, F, iF, detD
assert d12.ndim == 2
assert d22.ndim == 2
assert alpha.ndim == 1
assert d12.dtype == d22.dtype
dtype = d12.dtype
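# preallocate per-alpha output buffers; each iteration of the loop below fills its slice i in place via set_slice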
idD_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], d12.shape[0]], dtype=dtype))
iDW12_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d12.shape], dtype=dtype))
F_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d22.shape], dtype=dtype))
iF_g = tf.Variable(tf.zeros(shape=[alpha.shape[0], *d22.shape], dtype=dtype))
detD_g = tf.Variable(tf.zeros(shape=[alpha.shape[0]], dtype=dtype))
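# process one alpha value at a time, instead of the previous batched computation, so only one set of intermediate matrices is held in memory at once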
for i, a in enumerate(alpha):
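# for this alpha: dD = diag(I - W12 W22^{-1} W12^T), iDW12 = diag(1/dD) W12, F = W22 + W12^T diag(1/dD) W12, iF = F^{-1}, and detD = log det(diag(dD) + W12 W22^{-1} W12^T) via the matrix determinant lemma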
W22 = calculate_W(d22, a)
LW22 = tfla.cholesky(W22)
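# contribution -log det(W22), computed from the Cholesky factor: log det(W22) = 2 * sum(log(diag(LW22)))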
detD = -2*tf.reduce_sum(tfm.log(tfla.diag_part(LW22)), -1)
iW22 = tfla.cholesky_solve(LW22, eye_like(LW22))
del LW22

W12 = calculate_W(d12, a)
W12iW22 = tf.matmul(W12, iW22)
del iW22

dD = 1 - tf.einsum("ih,ih->i", W12iW22, W12)
del W12iW22
detD += tf.reduce_sum(tfm.log(dD), -1)
idD = dD**-1
del dD
set_slice(idD_g, i, idD)

iDW12 = tf.einsum("i,ik->ik", idD, W12)
set_slice(iDW12_g, i, iDW12)

F = W22 + tf.einsum("ik,ih->kh", iDW12, W12)
del W12
del W22
set_slice(F_g, i, F)

LF = tfla.cholesky(F)
detD += 2*tf.reduce_sum(tfm.log(tfla.diag_part(LF)), -1)
set_slice(detD_g, i, detD)
del detD

iF = tfla.cholesky_solve(LF, eye_like(LF))
del LF
set_slice(iF_g, i, iF)
del iF

iDW12_g = iDW12_g.read_value_no_copy()
idD_g = idD_g.read_value_no_copy()
F_g = F_g.read_value_no_copy()
iF_g = iF_g.read_value_no_copy()
detD_g = detD_g.read_value_no_copy()

return idD_g, iDW12_g, F_g, iF_g, detD_g


def load_random_level_hyperparams(hmscModel, dataParList, dtype=np.float64):
@@ -178,9 +210,10 @@ def load_random_level_hyperparams(hmscModel, dataParList, dtype=np.float64):

elif rLPar["spatialMethod"] == "GPP":
nK = int(dataParList["rLPar"][r]["nKnots"][0])
alpha = rLPar["alphapw"][:, 0].astype(dtype)
d12 = dataParList["rLPar"][r]["distMat12"].astype(dtype)
d22 = dataParList["rLPar"][r]["distMat22"].astype(dtype)
alpha = tf.convert_to_tensor(rLPar["alphapw"][:, 0], dtype=dtype)
d12 = tf.convert_to_tensor(dataParList["rLPar"][r]["distMat12"], dtype=dtype)
d22 = tf.convert_to_tensor(dataParList["rLPar"][r]["distMat22"], dtype=dtype)

assert d12.shape == (npVec[r], nK)
assert d22.shape == (nK, nK)

