Skip to content

Commit

Permalink
2056 additional deepgrow workflow events (#2305)
Browse files Browse the repository at this point in the history
* additional deepgrow workflow events

Signed-off-by: Wenqi Li <[email protected]>

* remove the attach method

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Jun 4, 2021
1 parent be87f31 commit c1e06f5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
12 changes: 6 additions & 6 deletions monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
import torch

from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.engines.workflow import Events
from monai.engines.utils import IterationEvents
from monai.transforms import Compose
from monai.utils.enums import CommonKeys


class Interaction:
"""
Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
This implementation is based on:
Sakinis et al., Interactive segmentation of medical images through
Expand Down Expand Up @@ -50,10 +50,6 @@ def __init__(
self.train = train
self.key_probability = key_probability

def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None:
if not engine.has_event_handler(self, Events.ITERATION_STARTED):
engine.add_event_handler(Events.ITERATION_STARTED, self)

def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]):
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
Expand All @@ -62,6 +58,8 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
inputs, _ = engine.prepare_batch(batchdata)
inputs = inputs.to(engine.state.device)

engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)

engine.network.eval()
with torch.no_grad():
if engine.amp:
Expand All @@ -70,6 +68,8 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
else:
predictions = engine.inferer(inputs, engine.network)

engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)

batchdata.update({CommonKeys.PRED: predictions})
batchdata[self.key_probability] = torch.as_tensor(
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
Expand Down
5 changes: 4 additions & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@ class IterationEvents(EventEnum):
`LOSS_COMPLETED` is the Event when `loss(pred, label)` completed.
`BACKWARD_COMPLETED` is the Event when `loss.backward()` completed.
`MODEL_COMPLETED` is the Event when all the model related operations completed.
`INNER_ITERATION_STARTED` is the Event when the iteration has an inner loop and the loop is started.
`INNER_ITERATION_COMPLETED` is the Event when the iteration has an inner loop and the loop is completed.
"""

FORWARD_COMPLETED = "forward_completed"
LOSS_COMPLETED = "loss_completed"
BACKWARD_COMPLETED = "backward_completed"
MODEL_COMPLETED = "model_completed"
INNER_ITERATION_STARTED = "inner_iteration_started"
INNER_ITERATION_COMPLETED = "inner_iteration_completed"


class GanKeys:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_deepgrow_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@
from monai.apps.deepgrow.interaction import Interaction
from monai.data import Dataset
from monai.engines import SupervisedTrainer
from monai.engines.utils import IterationEvents
from monai.transforms import Activationsd, Compose, ToNumpyd


def add_one(engine):
if engine.state.best_metric is -1:
engine.state.best_metric = 0
else:
engine.state.best_metric = engine.state.best_metric + 1


class TestInteractions(unittest.TestCase):
def run_interaction(self, train, compose):
data = []
Expand Down Expand Up @@ -47,9 +55,12 @@ def run_interaction(self, train, compose):
loss_function=loss,
iteration_update=i,
)
engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

engine.run()
self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
self.assertEqual(engine.state.best_metric, 9)

def test_train_interaction(self):
self.run_interaction(train=True, compose=True)
Expand Down

0 comments on commit c1e06f5

Please sign in to comment.