Skip to content

Commit

Permalink
add missing check for coverage of all combinations in a derived factor
Browse files Browse the repository at this point in the history
  • Loading branch information
mflatt committed Apr 20, 2024
1 parent 39c8fc7 commit 0353c89
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion example_programs/Paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def one_diff(colors, words):
return words[0] == words[-1]

def both_diff(colors, words):
return not one_diff(colors, words)
return (colors[0] != colors[-1]) and (words[0] != words[-1])

one = DerivedLevel("one", Transition(one_diff, [color, word]))
both = DerivedLevel("both", Transition(both_diff, [color, word]))
Expand Down
13 changes: 12 additions & 1 deletion sweetpea/_internal/derivation_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def generate_derivations(block: Block) -> List[Derivation]:
accum = []
for factor in derived_factors:
according_level: Dict[Tuple[Any, ...], DerivedLevel] = {}
# every level must have the same cross product, so we can get it once:
cross_product: List[Tuple[Level, ...]] = factor.levels[0].get_dependent_cross_product()
for level in factor.levels:
cross_product: List[Tuple[Level, ...]] = level.get_dependent_cross_product()
valid_tuples: List[Tuple[Level, ...]] = []
for level_tuple in cross_product:
args = [(level.name if not isinstance(level, BeforeStart) else None) for level in level_tuple]
Expand Down Expand Up @@ -98,6 +99,16 @@ def generate_derivations(block: Block) -> List[Derivation]:
block.variables_per_trial())
level_index = block.first_variable_for_level(factor, level)
accum.append(Derivation(level_index, shifted_indices, factor))
# check that everything in the cross product is covered by some level
for level_tuple in cross_product:
if level_tuple not in according_level:
in_crossing = block.factor_in_crossing(factor)
maybe_crossing = "crossed " if in_crossing else ""
args = [(level.name if not isinstance(level, BeforeStart) else None) for level in level_tuple]
if level.window.width != 1:
args = list(chunk_dict(args, level.window.width)) # type: ignore
block.errors.add(f"No level in {maybe_crossing}factor"
f" '{factor.name}' has a precicate that matches '{args}'.")
return accum

@staticmethod
Expand Down
22 changes: 22 additions & 0 deletions sweetpea/tests/test_derivation_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ def test_generate_derivations_should_produce_warning_if_some_level_is_unreachabl
block.show_errors()
assert capsys.readouterr().out == "WARNING: No matches to the factor 'congruent?' predicate for level\n 'dum'.\n"

def test_generate_derivations_should_produce_error_if_some_combination_is_uncovered(capsys):
local_con_factor = Factor("congruent?", [
DerivedLevel("con", WithinTrial(op.eq, [color, text])),
DerivedLevel("dum", WithinTrial(lambda c, t: c=='red' and t=='blue', [color, text]))
])
block = CrossBlock([color, text, local_con_factor],
[color, text],
[Reify(local_con_factor)])
block.show_errors()
assert capsys.readouterr().out == "No level in factor 'congruent?' has a precicate that matches '['blue', 'red']'.\n"

def test_generate_derivations_should_produce_error_if_some_transition_is_uncovered(capsys):
local_repeats_factor = Factor("repeats?", [
DerivedLevel("yes", Transition(lambda colors, texts: colors[0] == colors[-1] and texts[-1] == texts[0], [color, text])),
DerivedLevel("no", Transition(lambda colors, texts: colors[0] != colors[-1] and texts[-1] != texts[0], [color, text]))
])
block = CrossBlock([color, text, local_repeats_factor],
[color, text],
[Reify(local_repeats_factor)])
block.show_errors()
assert "No level in factor 'repeats?' has a precicate that matches '[{-1: 'red', 0: 'blue'}, {-1: 'red', 0: 'red'}]'.\n" in capsys.readouterr().out

def test_generate_derivations_within_trial():
assert DerivationProcessor.generate_derivations(blk) == [
Derivation(4, [[0, 2], [1, 3]], con_factor),
Expand Down

0 comments on commit 0353c89

Please sign in to comment.