Skip to content

Commit

Permalink
Add default required inputs for all node types
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Dec 21, 2023
1 parent 948bfc7 commit 8774699
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 31 deletions.
4 changes: 1 addition & 3 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def connect(url=default_url):
]
if len(missing) > 0:
raise MissingResource(ResourceKind.node, missing)
client.nodes_inputs = {name: nodes[name]["input"]["required"] for name in nodes}

# Retrieve list of checkpoints
client._refresh_models(nodes, await client.try_inspect_checkpoints())
Expand All @@ -164,9 +165,6 @@ async def connect(url=default_url):
client.ip_adapter_model = {
ver: _find_ip_adapter(ip, ver) for ver in [SDVersion.sd15, SDVersion.sdxl]
}
client.nodes_inputs["IPAdapterApplyEncoded"] = nodes["IPAdapterApplyEncoded"]["input"][
"required"
]

# Retrieve upscale models
client.upscalers = nodes["UpscaleModelLoader"]["input"]["required"]["model_name"][0]
Expand Down
33 changes: 14 additions & 19 deletions ai_diffusion/comfyworkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,19 @@ def __init__(self, node_inputs: dict | None = None) -> None:
self._cache = {}
self._nodes_required_inputs = node_inputs or {}

def get_node_optional_values(self, node_name: str, args: dict):
node_inputs = self._nodes_required_inputs.get(node_name, None)

if node_inputs is None:
return args

values = {}
for k, v in node_inputs.items():
if args is not None and k in args.keys():
values[k] = args[k]
elif len(v) == 1:
if isinstance(v[0], list) and len(v[0]) > 0:
values[k] = v[0][0]
elif len(k) >= 1 and isinstance(v[1], dict):
if v[1].get("default", None) is not None:
values[k] = v[1]["default"]

return values
def add_default_values(self, node_name: str, args: dict):
if node_inputs := self._nodes_required_inputs.get(node_name, None):
for k, v in node_inputs.items():
if k not in args:
if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0:
# enum type, use first value in list of possible values
args[k] = v[0][0]
elif len(v) > 1 and isinstance(v[1], dict):
# other type, try to access default value
default = v[1].get("default", None)
if default is not None:
args[k] = default
return args

def dump(self, filepath: str):
with open(filepath, "w") as f:
Expand All @@ -63,7 +58,7 @@ def add(self, class_type: str, output_count: Literal[2], **inputs) -> Output2: .
def add(self, class_type: str, output_count: Literal[3], **inputs) -> Output3: ...

def add(self, class_type: str, output_count: int, **inputs):
inputs = self.get_node_optional_values(class_type, inputs)
inputs = self.add_default_values(class_type, inputs)
normalize = lambda x: [str(x.node), x.output] if isinstance(x, Output) else x
self.node_count += 1
self.root[str(self.node_count)] = {
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create_mask_from_layer(self, padding: float, is_inpaint: bool) -> tuple[Mask

def get_image(
self, bounds: Bounds | None = None, exclude_layers: list[krita.Node] | None = None
):
) -> Image:
raise NotImplementedError

def get_layer_image(self, layer: krita.Node, bounds: Bounds | None) -> Image:
Expand Down
2 changes: 1 addition & 1 deletion ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def generate_control_layer(self, control: ControlLayer):

async def _generate_control_layer(self, job: Job, image: Image, mode: ControlMode):
client = self._connection.client
work = workflow.create_control_image(image, mode)
work = workflow.create_control_image(client, image, mode)
job.id = await client.enqueue(work)

def cancel(self, active=False, queued=False):
Expand Down
6 changes: 3 additions & 3 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,10 @@ def refine_region(
return w


def create_control_image(image: Image, mode: ControlMode):
def create_control_image(comfy: Client, image: Image, mode: ControlMode):
assert mode not in [ControlMode.image, ControlMode.inpaint]

w = ComfyWorkflow()
w = ComfyWorkflow(comfy.nodes_inputs)
input = w.load_image(image)
result = None

Expand Down Expand Up @@ -663,7 +663,7 @@ def create_control_image(image: Image, mode: ControlMode):


def upscale_simple(comfy: Client, image: Image, model: str, factor: float):
w = ComfyWorkflow()
w = ComfyWorkflow(comfy.nodes_inputs)
upscale_model = w.load_upscale_model(model)
img = w.load_image(image)
img = w.upscale_image(upscale_model, img)
Expand Down
2 changes: 1 addition & 1 deletion tests/references/test_create_control_image_pose.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/references/test_create_open_pose_vector.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ async def main():
def test_create_control_image(qtapp, comfy, mode):
image_name = f"test_create_control_image_{mode.name}.png"
image = Image.load(image_dir / "adobe_stock.jpg")
job = workflow.create_control_image(image, mode)
job = workflow.create_control_image(comfy, image, mode)

async def main():
result = await run_and_save(comfy, job, image_name)
Expand All @@ -411,7 +411,7 @@ async def main():
def test_create_open_pose_vector(qtapp, comfy):
image_name = f"test_create_open_pose_vector.svg"
image = Image.load(image_dir / "adobe_stock.jpg")
job = workflow.create_control_image(image, ControlMode.pose)
job = workflow.create_control_image(comfy, image, ControlMode.pose)

async def main():
job_id = None
Expand Down

0 comments on commit 8774699

Please sign in to comment.