diff --git a/docs/_source/api/derivations.rst b/docs/_source/api/derivations.rst index 6637daf9..4d6f5b2f 100644 --- a/docs/_source/api/derivations.rst +++ b/docs/_source/api/derivations.rst @@ -10,6 +10,10 @@ means that they must depend on the same factors in the same order, have the same window width, same window stride, and the same starting trial. However, :class:`.ElseLevel` is compatible with any derivation, as long as other levels in the same factor have compatible derivations. +For every combination of levels in the factors that a derived factor +depends on, there must be exactly one matching derived level, taking +into account that a level created with :class:`.ElseLevel` matches a +combination that is not matched by other levels. .. class:: sweetpea.Derivation() diff --git a/example_programs/Paper.py b/example_programs/Paper.py index 87f9e002..8c076263 100644 --- a/example_programs/Paper.py +++ b/example_programs/Paper.py @@ -32,7 +32,7 @@ def one_diff(colors, words): return words[0] == words[-1] def both_diff(colors, words): - return not one_diff(colors, words) + return (colors[0] != colors[-1]) and (words[0] != words[-1]) one = DerivedLevel("one", Transition(one_diff, [color, word])) both = DerivedLevel("both", Transition(both_diff, [color, word])) diff --git a/sweetpea/_internal/derivation_processor.py b/sweetpea/_internal/derivation_processor.py index 35542267..6578a4ae 100644 --- a/sweetpea/_internal/derivation_processor.py +++ b/sweetpea/_internal/derivation_processor.py @@ -57,8 +57,9 @@ def generate_derivations(block: Block) -> List[Derivation]: accum = [] for factor in derived_factors: according_level: Dict[Tuple[Any, ...], DerivedLevel] = {} + # every level must have the same cross product, so we can get it once: + cross_product: List[Tuple[Level, ...]] = factor.levels[0].get_dependent_cross_product() for level in factor.levels: - cross_product: List[Tuple[Level, ...]] = level.get_dependent_cross_product() valid_tuples: List[Tuple[Level, ...]] = [] for level_tuple in cross_product: args = [(level.name if not isinstance(level, BeforeStart) else None) for level in level_tuple] @@ -98,6 +99,16 @@ def generate_derivations(block: Block) -> List[Derivation]: block.variables_per_trial()) level_index = block.first_variable_for_level(factor, level) accum.append(Derivation(level_index, shifted_indices, factor)) + # check that everything in the cross product is covered by some level + for level_tuple in cross_product: + if level_tuple not in according_level: + in_crossing = block.factor_in_crossing(factor) + maybe_crossing = "crossed " if in_crossing else "" + args = [(level.name if not isinstance(level, BeforeStart) else None) for level in level_tuple] + if level.window.width != 1: + args = list(chunk_dict(args, level.window.width)) # type: ignore + block.errors.add(f"No level in {maybe_crossing}factor" + f" '{factor.name}' has a precicate that matches '{args}'.") return accum @staticmethod diff --git a/sweetpea/tests/test_derivation_processor.py b/sweetpea/tests/test_derivation_processor.py index ad582661..952fb668 100644 --- a/sweetpea/tests/test_derivation_processor.py +++ b/sweetpea/tests/test_derivation_processor.py @@ -86,6 +86,28 @@ def test_generate_derivations_should_produce_warning_if_some_level_is_unreachabl block.show_errors() assert capsys.readouterr().out == "WARNING: No matches to the factor 'congruent?' predicate for level\n 'dum'.\n" +def test_generate_derivations_should_produce_error_if_some_combination_is_uncovered(capsys): + local_con_factor = Factor("congruent?", [ + DerivedLevel("con", WithinTrial(op.eq, [color, text])), + DerivedLevel("dum", WithinTrial(lambda c, t: c=='red' and t=='blue', [color, text])) + ]) + block = CrossBlock([color, text, local_con_factor], + [color, text], + [Reify(local_con_factor)]) + block.show_errors() + assert capsys.readouterr().out == "No level in factor 'congruent?' has a precicate that matches '['blue', 'red']'.\n" + +def test_generate_derivations_should_produce_error_if_some_transition_is_uncovered(capsys): + local_repeats_factor = Factor("repeats?", [ + DerivedLevel("yes", Transition(lambda colors, texts: colors[0] == colors[-1] and texts[-1] == texts[0], [color, text])), + DerivedLevel("no", Transition(lambda colors, texts: colors[0] != colors[-1] and texts[-1] != texts[0], [color, text])) + ]) + block = CrossBlock([color, text, local_repeats_factor], + [color, text], + [Reify(local_repeats_factor)]) + block.show_errors() + assert "No level in factor 'repeats?' has a precicate that matches '[{-1: 'red', 0: 'blue'}, {-1: 'red', 0: 'red'}]'.\n" in capsys.readouterr().out + def test_generate_derivations_within_trial(): assert DerivationProcessor.generate_derivations(blk) == [ Derivation(4, [[0, 2], [1, 3]], con_factor),