Skip to content

Commit

Permalink
• emcomposition.py
Browse files Browse the repository at this point in the history
  - ALL -> WEIGHTED
  - MAX_INDICATOR -> ARG_MAX
  - more docstring edits
  - add check that ARG_MAX is not used with enable_learning
• test_emcomposition.py
  - add test_softmax_choice
  • Loading branch information
jdcpni committed Oct 7, 2024
1 parent dcf9211 commit 932e140
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/composition/test_emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ def test_memory_fill(start, memory_fill):
def test_softmax_choice(self):
for softmax_choice in [pnl.WEIGHTED, pnl.ARG_MAX]:
em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]],
softmax_choice=softmax_choice
)
softmax_choice=softmax_choice,
enable_learning=False)
result = em.run(inputs={em.query_input_nodes[0]:[[0,1,0]]})
if softmax_choice == pnl.WEIGHTED:
np.testing.assert_allclose(result, [[0.21330295, 0.77339411, 0.21330295]])
if softmax_choice == pnl.ARG_MAX:
np.testing.assert_allclose(result, [[.1, 1, .1]])

with pytest.raises(pnl.ComponentError) as error_text:
em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]],
softmax_choice=pnl.ARG_MAX)
assert ("The ARG_MAX option for the 'softmax_choice' arg of 'EM_Composition-2' can not be used "
"when 'enable_learning' is set to True; use WEIGHTED or set 'enable_learning' to False."
in str(error_text.value))

@pytest.mark.pytorch
class TestExecution:

Expand Down

0 comments on commit 932e140

Please sign in to comment.