From 486e5230e8f75dd54e9dd69731116260d4d08355 Mon Sep 17 00:00:00 2001 From: Joakim Loxdal Date: Thu, 17 Oct 2024 14:09:46 +0200 Subject: [PATCH] Add test to test deterministic simulation when seed is given --- tests/test_mal_simulator.py | 64 ++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/test_mal_simulator.py b/tests/test_mal_simulator.py index 82ce23f6..4e1a94b6 100644 --- a/tests/test_mal_simulator.py +++ b/tests/test_mal_simulator.py @@ -1,11 +1,12 @@ """Test MalSimulator class""" +import pytest + from maltoolbox.attackgraph import AttackGraph, Attacker from malsim.sims.mal_simulator import MalSimulator from malsim.scenario import load_scenario, create_simulator_from_scenario from malsim.sims import MalSimulatorSettings - def test_malsimulator(corelang_lang_graph, model): attack_graph = AttackGraph(corelang_lang_graph, model) MalSimulator(corelang_lang_graph, model, attack_graph) @@ -990,6 +991,67 @@ def test_default_settings_defender_observation_false_negatives(): assert not defender_obs_state[node_index] +def test_default_settings_defender_observation_false_negatives_seed(): + """Test default MalSimulator with false negative rates""" + + def run_simulation_with_seed( + seed, + expected_attacker_obs, + expected_defender_obs + ): + sim.reset(seed=seed) + attacker_agent_id = next(iter(sim.get_attacker_agents())) + defender_agent_id = next(iter(sim.get_defender_agents())) + + # Get an uncompromised step + user_3_compromise = sim.attack_graph.get_node_by_full_name( + 'User:3:compromise') + + # Let the attacker compromise User:3:compromise + actions = { + attacker_agent_id: (1, sim._id_to_index[user_3_compromise.id]), + defender_agent_id: (0, None)} + + obs, _, _, _, _ = sim.step(actions) + + actual_attacker_obs_state = obs[attacker_agent_id]['observed_state'] + for index, state in enumerate(actual_attacker_obs_state): + assert state == expected_attacker_obs[index] + + actual_defender_obs_state = obs[defender_agent_id]['observed_state'] + for index, state in enumerate(actual_defender_obs_state): + assert state == expected_defender_obs[index] + + sim, _ = create_simulator_from_scenario( + 'tests/testdata/scenarios/traininglang_fp_fn_scenario.yml', + ) + + expected_attacker_obs_100 = \ + [-1, 0, 1, 0, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1] + expected_defender_obs_100 = \ + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] + seed = 100 + + # Make sure the seed makes the obs state deterministic + # in terms of false positives + for _ in range(100): + # running it many times with same seed + # yields same observations + run_simulation_with_seed( + seed, + expected_attacker_obs_100, + expected_defender_obs_100 + ) + + with pytest.raises(AssertionError): + # running it once with different seed + # yields assertion error (different observations) + run_simulation_with_seed( + 1337, + expected_attacker_obs_100, + expected_defender_obs_100 + ) + def test_defender_observation_false_positives_negatives(): """Test default MalSimulator with false negative and positive rates"""