def gNID(pW_X, pV_X, pX):
    """Compute Generalized Normalized Informational Distance between two encoders.

    The score is ``1 - MI(p(w, v)) / max(MI(p(w, w')), MI(p(v, v')))``, where
    each joint distribution is obtained by coupling two encoders through the
    shared source variable weighted by the prior.

    Args:
        pW_X: first encoder of shape `(|meanings|, |words|)`

        pV_X: second encoder of shape `(|meanings|, |words|)`

        pX: prior over source variables of shape `(|meanings|,)`; a 2-D row
            vector of shape `(1, |meanings|)` is transposed to a column.

    Returns:
        The gNID score computed from the `MI` helper defined elsewhere in
        this module.
    """
    # Bring the prior into column form so it broadcasts across the word axis.
    prior = pX
    if prior.ndim == 1:
        prior = prior[:, None]
    elif prior.shape[0] == 1 and prior.shape[1] > 1:
        prior = prior.T

    # Joint distributions p(x, w) and p(x, v).
    joint_xw = pW_X * prior
    joint_xv = pV_X * prior

    # Self-couplings of each encoder; these form the normaliser.
    joint_ww = joint_xw.T @ pW_X
    joint_vv = joint_xv.T @ pV_X
    # Cross-coupling between the two encoders.
    joint_wv = joint_xw.T @ pV_X

    # NOTE(review): the normaliser is zero when both self-coupling MI values
    # are zero, making the division below undefined — confirm callers never
    # pass two such encoders.
    normaliser = np.max([MI(joint_ww), MI(joint_vv)])
    return 1 - MI(joint_wv) / normaliser