Skip to content

Commit

Permalink
KLD fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanjgallagher committed Nov 9, 2021
1 parent f0e25c6 commit 61f81ef
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions shifterator/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def get_entropy_type_scores(p_1, p_2, base, alpha):
score_1 = -1 * p_1 ** (alpha - 1) / (alpha - 1)
if p_2 > 0:
score_2 = -1 * p_2 ** (alpha - 1) / (alpha - 1)

return score_1, score_2


Expand Down
19 changes: 9 additions & 10 deletions shifterator/shifts.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,10 @@ def __init__(
# Check that KLD is well defined
types_1 = set(type2freq_1.keys())
types_2 = set(type2freq_2.keys())
if len(types_1.symmetric_difference(types_2)) > 0:
if len(types_2.difference(types_1)) > 0:
err = (
"There are types that appear in either type2freq_1 or "
+ "type2freq_2 but not the other: the KL divergence is not "
+ "well defined"
"There are types that appear in type2freq_2 but not type2freq_1:"
+ "the KL divergence is not well-defined"
)
raise ValueError(err)

Expand All @@ -263,16 +262,16 @@ def __init__(
type2freq_2 = type2freq_2.copy()
type2p_1 = entropy.get_relative_freqs(type2freq_1)
type2p_2 = entropy.get_relative_freqs(type2freq_2)

# Get surprisal scores
type2s_1 = {t: p * -1 * entropy.log(p, base) for t, p in type2p_1.items()}
type2s_2 = {t: p * -1 * entropy.log(p, base) for t, p in type2p_2.items()}
type2s_1, type2s_2 = entropy.get_entropy_scores(type2p_1, type2p_2, base, alpha=1)

# Initialize shift
super().__init__(
type2freq_1=type2p_2,
type2freq_2=type2p_2,
type2score_1=type2s_1,
type2score_2=type2s_2,
type2freq_1=type2freq_2,
type2freq_2=type2freq_2,
type2score_1=type2s_2,
type2score_2=type2s_1,
handle_missing_scores="error",
stop_lens=None,
stop_words=None,
Expand Down

0 comments on commit 61f81ef

Please sign in to comment.