Make unit tests work again (#3575)
This PR adjusts the unit tests in the `tests` directory so that
they no longer throw errors.

I've removed two tests that were obsoleted by the shift to latent nodes;
`test_graph_execution_state.py` and `test_invoker.py` had been failing with
this validation error:

```
TypeError: InvocationServices.__init__() missing 2 required positional arguments: 'boards' and 'board_images'
```
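The fix, reflected in the updated test fixtures below, is to pass the two newly required services when constructing `InvocationServices`. A minimal sketch of the mock-services pattern used in `tests/nodes/test_graph_execution_state.py` (services the test invocations never call are stubbed with `None`; the real fixture uses a small `TestEventService` helper for `events`, simplified to `None` here):

```python
# Sketch of the fixture pattern: stub unused services with None and supply the
# two newly required arguments, `boards` and `board_images`, explicitly.
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory

mock_services = InvocationServices(
    model_manager=None,  # type: ignore
    events=None,  # type: ignore  (the tests use a TestEventService helper)
    logger=None,  # type: ignore
    images=None,  # type: ignore
    latents=None,  # type: ignore
    boards=None,  # type: ignore        # newly required argument
    board_images=None,  # type: ignore  # newly required argument
    queue=MemoryInvocationQueue(),
    graph_library=SqliteItemStorage[LibraryGraph](
        filename=sqlite_memory, table_name="graphs"
    ),
    graph_execution_manager=SqliteItemStorage[GraphExecutionState](
        filename=sqlite_memory, table_name="graph_executions"
    ),
    processor=DefaultInvocationProcessor(),
    restoration=None,  # type: ignore
    configuration=None,  # type: ignore
)
```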
lstein authored Jul 3, 2023
2 parents c314b17 + 9c83a4e commit 78857bf
Showing 7 changed files with 237 additions and 460 deletions.
57 changes: 29 additions & 28 deletions invokeai/app/services/invocation_services.py
@@ -7,7 +7,7 @@
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
@@ -22,46 +22,47 @@ class InvocationServices:
"""Services that can be used by invocations"""

# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAISettings"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"

def __init__(
self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
boards: "BoardServiceABC",
configuration: "InvokeAISettings",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings",
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration
self.configuration = configuration
self.boards = boards
30 changes: 0 additions & 30 deletions tests/conftest.py

This file was deleted.

130 changes: 95 additions & 35 deletions tests/nodes/test_graph_execution_state.py
@@ -1,41 +1,81 @@
import pytest

from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from .test_invoker import create_edge
from .test_nodes import (
TestEventService,
TextToImageTestInvocation,
PromptTestInvocation,
PromptCollectionTestInvocation,
)
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
)
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.graph import (CollectInvocation, Graph,
GraphExecutionState,
IterateInvocation)
from invokeai.app.services.invocation_services import InvocationServices

from .test_invoker import create_edge
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
PromptTestInvocation)
from invokeai.app.services.graph import (
Graph,
CollectInvocation,
IterateInvocation,
GraphExecutionState,
LibraryGraph,
)
import pytest


@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2"))
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(TextToImageTestInvocation(id="2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g

def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:

# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
boards = None, # type: ignore
board_images = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)


def invoke_next(
g: GraphExecutionState, services: InvocationServices
) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()
if n is None:
return (None, None)

print(f'invoking {n.id}: {type(n)}')
print(f"invoking {n.id}: {type(n)}")
o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o)

return (n, o)


def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
g = GraphExecutionState(graph=simple_graph)

n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
@@ -47,38 +47,42 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
assert g.results[n1[0].id].prompt == n1[0].prompt
assert n2[0].prompt == n1[0].prompt


def test_graph_is_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()

assert g.is_complete()


def test_graph_is_not_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services)
n2 = g.next()

assert not g.is_complete()


# TODO: test completion with iterators/subgraphs


def test_graph_state_expands_iterator(mock_services):
graph = Graph()
graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
graph.add_node(IterateInvocation(id = "1"))
graph.add_node(MultiplyInvocation(id = "2", b = 10))
graph.add_node(AddInvocation(id = "3", b = 1))
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
graph.add_node(IterateInvocation(id="1"))
graph.add_node(MultiplyInvocation(id="2", b=10))
graph.add_node(AddInvocation(id="3", b=1))
graph.add_edge(create_edge("0", "collection", "1", "collection"))
graph.add_edge(create_edge("1", "item", "2", "a"))
graph.add_edge(create_edge("2", "a", "3", "a"))

g = GraphExecutionState(graph = graph)
g = GraphExecutionState(graph=graph)
while not g.is_complete():
invoke_next(g, mock_services)

prepared_add_nodes = g.source_prepared_mapping['3']
prepared_add_nodes = g.source_prepared_mapping["3"]
results = set([g.results[n].a for n in prepared_add_nodes])
expected = set([1, 11, 21])
assert results == expected
@@ -87,15 +131,17 @@ def test_graph_state_expands_iterator(mock_services):
def test_graph_state_collects(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
graph.add_node(IterateInvocation(id = "2"))
graph.add_node(PromptTestInvocation(id = "3"))
graph.add_node(CollectInvocation(id = "4"))
graph.add_node(
PromptCollectionTestInvocation(id="1", collection=list(test_prompts))
)
graph.add_node(IterateInvocation(id="2"))
graph.add_node(PromptTestInvocation(id="3"))
graph.add_node(CollectInvocation(id="4"))
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
graph.add_edge(create_edge("3", "prompt", "4", "item"))

g = GraphExecutionState(graph = graph)
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
@@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services):
graph = Graph()

test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))

# separated, fully-preparable chain of nodes
@@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services):
graph = Graph()

test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
graph.add_edge(
create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")
)

g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
