diff --git a/tf_metrics/__init__.py b/tf_metrics/__init__.py
index 65e12bc..16ab3e0 100644
--- a/tf_metrics/__init__.py
+++ b/tf_metrics/__init__.py
@@ -137,7 +137,7 @@ def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
 
 def safe_div(numerator, denominator):
     """Safe division, return 0 if denominator is 0"""
-    numerator, denominator = tf.to_float(numerator), tf.to_float(denominator)
+    numerator, denominator = tf.cast(numerator, dtype=tf.float32), tf.cast(denominator, dtype=tf.float32)
     zeros = tf.zeros_like(numerator, dtype=numerator.dtype)
     denominator_is_zero = tf.equal(denominator, zeros)
     return tf.where(denominator_is_zero, zeros, numerator / denominator)
@@ -149,7 +149,7 @@ def pr_re_fbeta(cm, pos_indices, beta=1):
     neg_indices = [i for i in range(num_classes) if i not in pos_indices]
     cm_mask = np.ones([num_classes, num_classes])
     cm_mask[neg_indices, neg_indices] = 0
-    diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask))
+    diag_sum = tf.reduce_sum(tf.linalg.diag_part(cm * cm_mask))
 
     cm_mask = np.ones([num_classes, num_classes])
     cm_mask[:, neg_indices] = 0
@@ -196,7 +196,7 @@ def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro',
             fbetas.append(fbeta)
             cm_mask = np.zeros([num_classes, num_classes])
             cm_mask[idx, :] = 1
-            n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
+            n_golds.append(tf.cast(tf.reduce_sum(cm * cm_mask), dtype=tf.float32))
 
     if average == 'macro':
         pr = tf.reduce_mean(precisions)