diff --git a/malsim/agents/decision_agent.py b/malsim/agents/decision_agent.py index eeca5b1a..4c8ec939 100644 --- a/malsim/agents/decision_agent.py +++ b/malsim/agents/decision_agent.py @@ -1,11 +1,12 @@ """A decision agent is a heuristic agent""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from abc import ABC, abstractmethod if TYPE_CHECKING: from ..sims import MalSimAgentView + from maltoolbox.attackgraph import AttackGraphNode class DecisionAgent(ABC): @@ -14,7 +15,7 @@ def get_next_action( self, agent: MalSimAgentView, **kwargs - ) -> tuple[int, int]: ... + ) -> Optional[AttackGraphNode]: ... class PassiveAgent(DecisionAgent): def __init__(self, info): @@ -24,5 +25,5 @@ def get_next_action( self, agent: MalSimAgentView, **kwargs - ) -> tuple[int, int]: - return (0, None) + ) -> Optional[AttackGraphNode]: + return None diff --git a/malsim/agents/keyboard_input.py b/malsim/agents/keyboard_input.py index 6e057ba7..4aec4c87 100644 --- a/malsim/agents/keyboard_input.py +++ b/malsim/agents/keyboard_input.py @@ -1,7 +1,13 @@ +from __future__ import annotations import logging +from typing import TYPE_CHECKING, Optional + from .decision_agent import DecisionAgent from ..sims import MalSimAgentView +if TYPE_CHECKING: + from maltoolbox.attackgraph import AttackGraphNode + logger = logging.getLogger(__name__) null_action = [] @@ -16,7 +22,7 @@ def get_next_action( self, agent: MalSimAgentView, **kwargs - ) -> tuple: + ) -> Optional[AttackGraphNode]: """Compute action from action_surface""" def valid_action(user_input: str) -> bool: @@ -59,4 +65,4 @@ def get_action_object(user_input: str) -> tuple: if index is not None else 'wait' ) - return (1, index_to_node[index]) if index is not None else (0, None) + return index_to_node[index] if index is not None else None