diff --git a/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb b/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb index c29c42b9..b9b1ccc4 100644 --- a/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb +++ b/tensorflow_v2/notebooks/2_BasicModels/logistic_regression.ipynb @@ -109,7 +109,7 @@ " # Clip prediction values to avoid log(0) error.\n", " y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)\n", " # Compute cross-entropy.\n", - " return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))\n", + " return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred),1))\n", "\n", "# Accuracy metric.\n", "def accuracy(y_pred, y_true):\n",