Skip to content

Commit

Permalink
Fix return types decision agents, make KeyboardAgent return nodes, ad…
Browse files Browse the repository at this point in the history
…apt cli to that
  • Loading branch information
mrkickling committed Jan 17, 2025
1 parent 697dc5f commit 2fe03ed
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
9 changes: 5 additions & 4 deletions malsim/agents/decision_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""A decision agent is a heuristic agent"""

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from abc import ABC, abstractmethod

if TYPE_CHECKING:
from ..sims import MalSimAgentView
from maltoolbox.attackgraph import AttackGraphNode

class DecisionAgent(ABC):

Expand All @@ -14,7 +15,7 @@ def get_next_action(
self,
agent: MalSimAgentView,
**kwargs
) -> tuple[int, int]: ...
) -> Optional[AttackGraphNode]: ...

class PassiveAgent(DecisionAgent):
def __init__(self, info):
Expand All @@ -24,5 +25,5 @@ def get_next_action(
self,
agent: MalSimAgentView,
**kwargs
) -> tuple[int, int]:
return (0, None)
) -> Optional[AttackGraphNode]:
return None
10 changes: 8 additions & 2 deletions malsim/agents/keyboard_input.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional

from .decision_agent import DecisionAgent
from ..sims import MalSimAgentView

if TYPE_CHECKING:
from maltoolbox.attackgraph import AttackGraphNode

logger = logging.getLogger(__name__)
null_action = []

Expand All @@ -16,7 +22,7 @@ def get_next_action(
self,
agent: MalSimAgentView,
**kwargs
) -> tuple:
) -> Optional[AttackGraphNode]:
"""Compute action from action_surface"""

def valid_action(user_input: str) -> bool:
Expand Down Expand Up @@ -59,4 +65,4 @@ def get_action_object(user_input: str) -> tuple:
if index is not None else 'wait'
)

return (1, index_to_node[index]) if index is not None else (0, None)
return index_to_node[index] if index is not None else None
4 changes: 2 additions & 2 deletions malsim/agents/searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def select_next_target(
previous_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
) -> Optional[int]:
"""Select a target from attack surface
by going through the target queue"""

Expand Down Expand Up @@ -154,7 +154,7 @@ def select_next_target(
previous_target: int,
targets: Union[List[int], Deque[int]],
attack_surface: Set[int],
) -> int:
) -> Optional[int]:
if previous_target in attack_surface:
return previous_target

Expand Down

0 comments on commit 2fe03ed

Please sign in to comment.