From 932e1408ea75b7d58e945b93bd0468ba34323d37 Mon Sep 17 00:00:00 2001 From: jdcpni Date: Sun, 6 Oct 2024 20:57:33 -0400 Subject: [PATCH] =?UTF-8?q?=E2=80=A2=20emcomposition.py=20=20=20-=20ALL=20?= =?UTF-8?q?->=20WEIGHTED=20=20=20-=20MAX=5FINDICATOR=20->=20ARG=5FMAX=20?= =?UTF-8?q?=20=20-=20more=20docstring=20edits=20=20=20-=20add=20check=20th?= =?UTF-8?q?at=20ARG=5FMAX=20is=20not=20used=20with=20enable=5Flearning=20?= =?UTF-8?q?=E2=80=A2=20test=5Femcomposition.py=20=20=20-=20add=20test=5Fso?= =?UTF-8?q?ftmax=5Fchoice?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/composition/test_emcomposition.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py index ceb94001de..73807aa442 100644 --- a/tests/composition/test_emcomposition.py +++ b/tests/composition/test_emcomposition.py @@ -235,14 +235,21 @@ def test_memory_fill(start, memory_fill): def test_softmax_choice(self): for softmax_choice in [pnl.WEIGHTED, pnl.ARG_MAX]: em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]], - softmax_choice=softmax_choice - ) + softmax_choice=softmax_choice, + enable_learning=False) result = em.run(inputs={em.query_input_nodes[0]:[[0,1,0]]}) if softmax_choice == pnl.WEIGHTED: np.testing.assert_allclose(result, [[0.21330295, 0.77339411, 0.21330295]]) if softmax_choice == pnl.ARG_MAX: np.testing.assert_allclose(result, [[.1, 1, .1]]) + with pytest.raises(pnl.ComponentError) as error_text: + em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]], + softmax_choice=pnl.ARG_MAX) + assert ("The ARG_MAX option for the 'softmax_choice' arg of 'EM_Composition-2' can not be used " + "when 'enable_learning' is set to True; use WEIGHTED or set 'enable_learning' to False." + in str(error_text.value)) + @pytest.mark.pytorch class TestExecution: