From 8774699092a443b851f019ae991e8352163d9056 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 21 Dec 2023 13:57:27 +0100 Subject: [PATCH] Add default required inputs for all node types --- ai_diffusion/client.py | 4 +-- ai_diffusion/comfyworkflow.py | 33 ++++++++----------- ai_diffusion/document.py | 2 +- ai_diffusion/model.py | 2 +- ai_diffusion/workflow.py | 6 ++-- .../test_create_control_image_pose.png | 2 +- .../test_create_open_pose_vector.svg | 2 +- tests/test_workflow.py | 4 +-- 8 files changed, 24 insertions(+), 31 deletions(-) diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index cbe92927f..b8c3a0199 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -147,6 +147,7 @@ async def connect(url=default_url): ] if len(missing) > 0: raise MissingResource(ResourceKind.node, missing) + client.nodes_inputs = {name: nodes[name]["input"]["required"] for name in nodes} # Retrieve list of checkpoints client._refresh_models(nodes, await client.try_inspect_checkpoints()) @@ -164,9 +165,6 @@ async def connect(url=default_url): client.ip_adapter_model = { ver: _find_ip_adapter(ip, ver) for ver in [SDVersion.sd15, SDVersion.sdxl] } - client.nodes_inputs["IPAdapterApplyEncoded"] = nodes["IPAdapterApplyEncoded"]["input"][ - "required" - ] # Retrieve upscale models client.upscalers = nodes["UpscaleModelLoader"]["input"]["required"]["model_name"][0] diff --git a/ai_diffusion/comfyworkflow.py b/ai_diffusion/comfyworkflow.py index 266f4f352..3bd013792 100644 --- a/ai_diffusion/comfyworkflow.py +++ b/ai_diffusion/comfyworkflow.py @@ -30,24 +30,19 @@ def __init__(self, node_inputs: dict | None = None) -> None: self._cache = {} self._nodes_required_inputs = node_inputs or {} - def get_node_optional_values(self, node_name: str, args: dict): - node_inputs = self._nodes_required_inputs.get(node_name, None) - - if node_inputs is None: - return args - - values = {} - for k, v in node_inputs.items(): - if args is not None and k in args.keys(): - values[k] = args[k] - elif len(v) == 1: - if isinstance(v[0], list) and len(v[0]) > 0: - values[k] = v[0][0] - elif len(k) >= 1 and isinstance(v[1], dict): - if v[1].get("default", None) is not None: - values[k] = v[1]["default"] - - return values + def add_default_values(self, node_name: str, args: dict): + if node_inputs := self._nodes_required_inputs.get(node_name, None): + for k, v in node_inputs.items(): + if k not in args: + if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0: + # enum type, use first value in list of possible values + args[k] = v[0][0] + elif len(v) > 1 and isinstance(v[1], dict): + # other type, try to access default value + default = v[1].get("default", None) + if default is not None: + args[k] = default + return args def dump(self, filepath: str): with open(filepath, "w") as f: @@ -63,7 +58,7 @@ def add(self, class_type: str, output_count: Literal[2], **inputs) -> Output2: . def add(self, class_type: str, output_count: Literal[3], **inputs) -> Output3: ... def add(self, class_type: str, output_count: int, **inputs): - inputs = self.get_node_optional_values(class_type, inputs) + inputs = self.add_default_values(class_type, inputs) normalize = lambda x: [str(x.node), x.output] if isinstance(x, Output) else x self.node_count += 1 self.root[str(self.node_count)] = { diff --git a/ai_diffusion/document.py b/ai_diffusion/document.py index 231f7c3fc..371f7f5f0 100644 --- a/ai_diffusion/document.py +++ b/ai_diffusion/document.py @@ -39,7 +39,7 @@ def create_mask_from_layer(self, padding: float, is_inpaint: bool) -> tuple[Mask def get_image( self, bounds: Bounds | None = None, exclude_layers: list[krita.Node] | None = None - ): + ) -> Image: raise NotImplementedError def get_layer_image(self, layer: krita.Node, bounds: Bounds | None) -> Image: diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index eccbd4a16..bcab7e39e 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -213,7 +213,7 @@ def generate_control_layer(self, control: ControlLayer): async def _generate_control_layer(self, job: Job, image: Image, mode: ControlMode): client = self._connection.client - work = workflow.create_control_image(image, mode) + work = workflow.create_control_image(client, image, mode) job.id = await client.enqueue(work) def cancel(self, active=False, queued=False): diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index fc4543e14..c69f406e6 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -621,10 +621,10 @@ def refine_region( return w -def create_control_image(image: Image, mode: ControlMode): +def create_control_image(comfy: Client, image: Image, mode: ControlMode): assert mode not in [ControlMode.image, ControlMode.inpaint] - w = ComfyWorkflow() + w = ComfyWorkflow(comfy.nodes_inputs) input = w.load_image(image) result = None @@ -663,7 +663,7 @@ def create_control_image(image: Image, mode: ControlMode): def upscale_simple(comfy: Client, image: Image, model: str, factor: float): - w = ComfyWorkflow() + w = ComfyWorkflow(comfy.nodes_inputs) upscale_model = w.load_upscale_model(model) img = w.load_image(image) img = w.upscale_image(upscale_model, img) diff --git a/tests/references/test_create_control_image_pose.png b/tests/references/test_create_control_image_pose.png index c286047a5..7d360507b 100644 --- a/tests/references/test_create_control_image_pose.png +++ b/tests/references/test_create_control_image_pose.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:414a2bc7c844d08dec79eddf0a379b3d9c02a0fc272a5e4704b155c9362fc0c7 +oid sha256:36599bf877dc8d3adba2de0a42030275db57cb05bc22c6812e793249d4f194c6 size 21278 diff --git a/tests/references/test_create_open_pose_vector.svg b/tests/references/test_create_open_pose_vector.svg index d4ed1b77f..010697f32 100644 --- a/tests/references/test_create_open_pose_vector.svg +++ b/tests/references/test_create_open_pose_vector.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 7374f2693..fd232f0f2 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -398,7 +398,7 @@ async def main(): def test_create_control_image(qtapp, comfy, mode): image_name = f"test_create_control_image_{mode.name}.png" image = Image.load(image_dir / "adobe_stock.jpg") - job = workflow.create_control_image(image, mode) + job = workflow.create_control_image(comfy, image, mode) async def main(): result = await run_and_save(comfy, job, image_name) @@ -411,7 +411,7 @@ async def main(): def test_create_open_pose_vector(qtapp, comfy): image_name = f"test_create_open_pose_vector.svg" image = Image.load(image_dir / "adobe_stock.jpg") - job = workflow.create_control_image(image, ControlMode.pose) + job = workflow.create_control_image(comfy, image, ControlMode.pose) async def main(): job_id = None