diff --git a/packages/learning/src/sac/sac.ts b/packages/learning/src/sac/sac.ts index c9a218a1..c4b64336 100644 --- a/packages/learning/src/sac/sac.ts +++ b/packages/learning/src/sac/sac.ts @@ -8,7 +8,7 @@ function mlp( x: tf.Tensor2D, hiddenSizes: number[], activation: ActivationIdentifier, - outputActivation: ActivationIdentifier, + outputActivation: ActivationIdentifier | undefined, ) { for (const h of hiddenSizes) { x = tf.layers @@ -74,7 +74,11 @@ function mlpGaussianPolicy( const pi = tf.add(mu, tf.mul(tf.randomNormal(std.shape), std)) const logPi = gaussianLikelihood(a, mu, logstdClipped) - return { mu, pi, logPi } + return { + mu: mu as tf.Tensor2D, + pi: pi as tf.Tensor2D, + logPi: logPi as tf.Tensor2D, + } } /* @@ -142,13 +146,14 @@ function mlpActorCritic( } = applySquashingFunction(mu, pi, logPi) const actionScale = actionSpace - mu = tf.mul(mu, actionScale) - pi = tf.mul(pi, actionScale) + const muScaled = tf.mul(muSquashed, actionScale) + const piScaled = tf.mul(piSquashed, actionScale) - const vfMlp = (x: tf.Tensor2D) => tf.squeeze(mlp(x, [...hiddenSizes, 1], activation, null), 1) + const vfMlp = (x: tf.Tensor2D) => + tf.squeeze(mlp(x, [...hiddenSizes, 1], activation, undefined), [1]) const q1 = vfMlp(tf.concat([x, a], 1)) const q2 = vfMlp(tf.concat([x, a], 1)) - return { mu, pi, logPi, q1, q2 } + return { mu: muScaled, pi: piScaled, logPi: logPiSquashed, q1, q2 } }