diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py
index cbe92927f..b8c3a0199 100644
--- a/ai_diffusion/client.py
+++ b/ai_diffusion/client.py
@@ -147,6 +147,7 @@ async def connect(url=default_url):
]
if len(missing) > 0:
raise MissingResource(ResourceKind.node, missing)
+ client.nodes_inputs = {name: nodes[name]["input"]["required"] for name in nodes}
# Retrieve list of checkpoints
client._refresh_models(nodes, await client.try_inspect_checkpoints())
@@ -164,9 +165,6 @@ async def connect(url=default_url):
client.ip_adapter_model = {
ver: _find_ip_adapter(ip, ver) for ver in [SDVersion.sd15, SDVersion.sdxl]
}
- client.nodes_inputs["IPAdapterApplyEncoded"] = nodes["IPAdapterApplyEncoded"]["input"][
- "required"
- ]
# Retrieve upscale models
client.upscalers = nodes["UpscaleModelLoader"]["input"]["required"]["model_name"][0]
diff --git a/ai_diffusion/comfyworkflow.py b/ai_diffusion/comfyworkflow.py
index 266f4f352..3bd013792 100644
--- a/ai_diffusion/comfyworkflow.py
+++ b/ai_diffusion/comfyworkflow.py
@@ -30,24 +30,19 @@ def __init__(self, node_inputs: dict | None = None) -> None:
self._cache = {}
self._nodes_required_inputs = node_inputs or {}
- def get_node_optional_values(self, node_name: str, args: dict):
- node_inputs = self._nodes_required_inputs.get(node_name, None)
-
- if node_inputs is None:
- return args
-
- values = {}
- for k, v in node_inputs.items():
- if args is not None and k in args.keys():
- values[k] = args[k]
- elif len(v) == 1:
- if isinstance(v[0], list) and len(v[0]) > 0:
- values[k] = v[0][0]
- elif len(k) >= 1 and isinstance(v[1], dict):
- if v[1].get("default", None) is not None:
- values[k] = v[1]["default"]
-
- return values
+ def add_default_values(self, node_name: str, args: dict):
+ if node_inputs := self._nodes_required_inputs.get(node_name, None):
+ for k, v in node_inputs.items():
+ if k not in args:
+ if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0:
+ # enum type, use first value in list of possible values
+ args[k] = v[0][0]
+ elif len(v) > 1 and isinstance(v[1], dict):
+ # other type, try to access default value
+ default = v[1].get("default", None)
+ if default is not None:
+ args[k] = default
+ return args
def dump(self, filepath: str):
with open(filepath, "w") as f:
@@ -63,7 +58,7 @@ def add(self, class_type: str, output_count: Literal[2], **inputs) -> Output2: .
def add(self, class_type: str, output_count: Literal[3], **inputs) -> Output3: ...
def add(self, class_type: str, output_count: int, **inputs):
- inputs = self.get_node_optional_values(class_type, inputs)
+ inputs = self.add_default_values(class_type, inputs)
normalize = lambda x: [str(x.node), x.output] if isinstance(x, Output) else x
self.node_count += 1
self.root[str(self.node_count)] = {
diff --git a/ai_diffusion/document.py b/ai_diffusion/document.py
index 231f7c3fc..371f7f5f0 100644
--- a/ai_diffusion/document.py
+++ b/ai_diffusion/document.py
@@ -39,7 +39,7 @@ def create_mask_from_layer(self, padding: float, is_inpaint: bool) -> tuple[Mask
def get_image(
self, bounds: Bounds | None = None, exclude_layers: list[krita.Node] | None = None
- ):
+ ) -> Image:
raise NotImplementedError
def get_layer_image(self, layer: krita.Node, bounds: Bounds | None) -> Image:
diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py
index eccbd4a16..bcab7e39e 100644
--- a/ai_diffusion/model.py
+++ b/ai_diffusion/model.py
@@ -213,7 +213,7 @@ def generate_control_layer(self, control: ControlLayer):
async def _generate_control_layer(self, job: Job, image: Image, mode: ControlMode):
client = self._connection.client
- work = workflow.create_control_image(image, mode)
+ work = workflow.create_control_image(client, image, mode)
job.id = await client.enqueue(work)
def cancel(self, active=False, queued=False):
diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py
index fc4543e14..c69f406e6 100644
--- a/ai_diffusion/workflow.py
+++ b/ai_diffusion/workflow.py
@@ -621,10 +621,10 @@ def refine_region(
return w
-def create_control_image(image: Image, mode: ControlMode):
+def create_control_image(comfy: Client, image: Image, mode: ControlMode):
assert mode not in [ControlMode.image, ControlMode.inpaint]
- w = ComfyWorkflow()
+ w = ComfyWorkflow(comfy.nodes_inputs)
input = w.load_image(image)
result = None
@@ -663,7 +663,7 @@ def create_control_image(image: Image, mode: ControlMode):
def upscale_simple(comfy: Client, image: Image, model: str, factor: float):
- w = ComfyWorkflow()
+ w = ComfyWorkflow(comfy.nodes_inputs)
upscale_model = w.load_upscale_model(model)
img = w.load_image(image)
img = w.upscale_image(upscale_model, img)
diff --git a/tests/references/test_create_control_image_pose.png b/tests/references/test_create_control_image_pose.png
index c286047a5..7d360507b 100644
--- a/tests/references/test_create_control_image_pose.png
+++ b/tests/references/test_create_control_image_pose.png
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:414a2bc7c844d08dec79eddf0a379b3d9c02a0fc272a5e4704b155c9362fc0c7
+oid sha256:36599bf877dc8d3adba2de0a42030275db57cb05bc22c6812e793249d4f194c6
size 21278
diff --git a/tests/references/test_create_open_pose_vector.svg b/tests/references/test_create_open_pose_vector.svg
index d4ed1b77f..010697f32 100644
--- a/tests/references/test_create_open_pose_vector.svg
+++ b/tests/references/test_create_open_pose_vector.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index 7374f2693..fd232f0f2 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -398,7 +398,7 @@ async def main():
def test_create_control_image(qtapp, comfy, mode):
image_name = f"test_create_control_image_{mode.name}.png"
image = Image.load(image_dir / "adobe_stock.jpg")
- job = workflow.create_control_image(image, mode)
+ job = workflow.create_control_image(comfy, image, mode)
async def main():
result = await run_and_save(comfy, job, image_name)
@@ -411,7 +411,7 @@ async def main():
def test_create_open_pose_vector(qtapp, comfy):
image_name = f"test_create_open_pose_vector.svg"
image = Image.load(image_dir / "adobe_stock.jpg")
- job = workflow.create_control_image(image, ControlMode.pose)
+ job = workflow.create_control_image(comfy, image, ControlMode.pose)
async def main():
job_id = None