-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests for ExtensionsManager and ExtensionBase.
- Loading branch information
Showing
3 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from unittest import mock | ||
|
||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback | ||
|
||
|
||
class MockExtension(ExtensionBase): | ||
"""A mock ExtensionBase subclass for testing purposes.""" | ||
|
||
def __init__(self, x: int): | ||
super().__init__() | ||
self._x = x | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) | ||
def set_step_index(self, ctx: DenoiseContext): | ||
ctx.step_index = self._x | ||
|
||
|
||
def test_extension_base_callback_registration(): | ||
"""Test that a callback can be successfully registered with an extension.""" | ||
val = 5 | ||
mock_extension = MockExtension(val) | ||
|
||
mock_ctx = mock.MagicMock() | ||
|
||
callbacks = mock_extension.get_callbacks() | ||
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) | ||
assert len(pre_denoise_loop_cbs) == 1 | ||
|
||
# Call the mock callback. | ||
pre_denoise_loop_cbs[0].function(mock_ctx) | ||
|
||
# Confirm that the callback ran. | ||
assert mock_ctx.step_index == val | ||
|
||
|
||
def test_extension_base_empty_callback_type(): | ||
"""Test that an empty list is returned when no callbacks are registered for a given callback type.""" | ||
mock_extension = MockExtension(5) | ||
|
||
# There should be no callbacks registered for POST_DENOISE_LOOP. | ||
callbacks = mock_extension.get_callbacks() | ||
|
||
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) | ||
assert len(post_denoise_loop_cbs) == 0 |
112 changes: 112 additions & 0 deletions
112
tests/backend/stable_diffusion/test_extension_manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback | ||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager | ||
|
||
|
||
class MockExtension(ExtensionBase): | ||
"""A mock ExtensionBase subclass for testing purposes.""" | ||
|
||
def __init__(self, x: int): | ||
super().__init__() | ||
self._x = x | ||
|
||
# Note that order is not specified. It should default to 0. | ||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) | ||
def set_step_index(self, ctx: DenoiseContext): | ||
ctx.step_index = self._x | ||
|
||
|
||
class MockExtensionLate(ExtensionBase): | ||
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback.""" | ||
|
||
def __init__(self, x: int): | ||
super().__init__() | ||
self._x = x | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) | ||
def set_step_index(self, ctx: DenoiseContext): | ||
ctx.step_index = self._x | ||
|
||
|
||
def test_extension_manager_run_callback(): | ||
"""Test that run_callback runs all callbacks for the given callback type.""" | ||
|
||
em = ExtensionsManager() | ||
mock_extension_1 = MockExtension(1) | ||
em.add_extension(mock_extension_1) | ||
|
||
mock_ctx = mock.MagicMock() | ||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) | ||
|
||
assert mock_ctx.step_index == 1 | ||
|
||
|
||
def test_extension_manager_run_callback_no_callbacks(): | ||
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type.""" | ||
em = ExtensionsManager() | ||
mock_ctx = mock.MagicMock() | ||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["extension_1", "extension_2"], | ||
# Regardless of initialization order, we expect MockExtensionLate to run last. | ||
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))], | ||
) | ||
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase): | ||
"""Test that run_callback runs callbacks in the correct order.""" | ||
em = ExtensionsManager() | ||
em.add_extension(extension_1) | ||
em.add_extension(extension_2) | ||
|
||
mock_ctx = mock.MagicMock() | ||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) | ||
|
||
assert mock_ctx.step_index == 2 | ||
|
||
|
||
class MockExtensionStableSort(ExtensionBase): | ||
"""A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value.""" | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000) | ||
def early(self, ctx: DenoiseContext): | ||
pass | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) | ||
def middle(self, ctx: DenoiseContext): | ||
pass | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) | ||
def late(self, ctx: DenoiseContext): | ||
pass | ||
|
||
|
||
def test_extension_manager_stable_sort(): | ||
"""Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to | ||
the ExtensionsManager.""" | ||
|
||
em = ExtensionsManager() | ||
|
||
mock_extension_1 = MockExtensionStableSort() | ||
mock_extension_2 = MockExtensionStableSort() | ||
|
||
em.add_extension(mock_extension_1) | ||
em.add_extension(mock_extension_2) | ||
|
||
expected_order = [ | ||
mock_extension_1.early, | ||
mock_extension_2.early, | ||
mock_extension_1.middle, | ||
mock_extension_2.middle, | ||
mock_extension_1.late, | ||
mock_extension_2.late, | ||
] | ||
|
||
# It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the | ||
# desired behaviour. | ||
assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order |