Skip to content

Commit

Permalink
Add workflow schema with zod (#22)
Browse files Browse the repository at this point in the history
* Safe parse with zod

* Fix zod issue

* Fix all validation errors

* nit

* Add tests

* Add color fields

* Passthrough
  • Loading branch information
huchenlei authored Jun 17, 2024
1 parent cc7ee23 commit b11a12d
Show file tree
Hide file tree
Showing 9 changed files with 1,383 additions and 14 deletions.
23 changes: 23 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"dev": "vite",
"build": "npm run typecheck && vite build",
"typecheck": "tsc --noEmit",
"test": "jest",
"test": "npm run build && jest",
"test:generate": "npx tsx tests-ui/setup",
"preview": "vite preview"
},
Expand All @@ -24,5 +24,9 @@
"typescript": "^5.4.5",
"vite": "^5.2.0",
"vite-plugin-static-copy": "^1.0.5"
},
"dependencies": {
"zod": "^3.23.8",
"zod-validation-error": "^3.3.0"
}
}
30 changes: 18 additions & 12 deletions src/scripts/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { DraggableList } from "./ui/draggableList";
import { applyTextReplacements, addStylesheet } from "./utils";
import type { ComfyExtension } from "/types/comfy";
import type { LGraph, LGraphCanvas, LGraphNode } from "/types/litegraph";
import { type ComfyWorkflow, parseComfyWorkflow } from "../types/comfyWorkflow";

export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"

Expand Down Expand Up @@ -973,16 +974,18 @@ export class ComfyApp {

// No image found. Look for node data
data = data.getData("text/plain");
let workflow;
let workflow: ComfyWorkflow;
try {
data = data.slice(data.indexOf("{"));
workflow = JSON.parse(data);
workflow = await parseComfyWorkflow(data);
} catch (err) {
try {
data = data.slice(data.indexOf("workflow\n"));
data = data.slice(data.indexOf("{"));
workflow = JSON.parse(data);
} catch (error) {}
workflow = await parseComfyWorkflow(data);
} catch (error) {
console.error(error);
}
}

if (workflow && workflow.version && workflow.nodes && workflow.extra) {
Expand Down Expand Up @@ -1652,8 +1655,7 @@ export class ComfyApp {
try {
const loadWorkflow = async (json) => {
if (json) {
const workflow = JSON.parse(json);
await this.loadGraphData(workflow);
await this.loadGraphData(await parseComfyWorkflow(json));
return true;
}
};
Expand Down Expand Up @@ -1895,7 +1897,7 @@ export class ComfyApp {
* @param {*} graphData A serialized graph object
* @param { boolean } clean If the graph state, e.g. images, should be cleared
*/
async loadGraphData(graphData?, clean: boolean = true, restore_view: boolean = true) {
async loadGraphData(graphData?: ComfyWorkflow, clean: boolean = true, restore_view: boolean = true) {
if (clean !== false) {
this.clean();
}
Expand Down Expand Up @@ -1932,6 +1934,9 @@ export class ComfyApp {
try {
this.graph.configure(graphData);
if (restore_view && this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
// @ts-ignore
// Need to set strict: true for zod to match the type [number, number]
// https://github.com/colinhacks/zod/issues/3056
this.canvas.ds.offset = graphData.extra.ds.offset;
this.canvas.ds.scale = graphData.extra.ds.scale;
}
Expand Down Expand Up @@ -2273,7 +2278,7 @@ export class ComfyApp {
if (file.type === "image/png") {
const pngInfo = await getPngMetadata(file);
if (pngInfo?.workflow) {
await this.loadGraphData(JSON.parse(pngInfo.workflow));
await this.loadGraphData(await parseComfyWorkflow(pngInfo.workflow));
} else if (pngInfo?.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
} else if (pngInfo?.parameters) {
Expand All @@ -2288,7 +2293,7 @@ export class ComfyApp {
const prompt = pngInfo?.prompt || pngInfo?.Prompt;

if (workflow) {
this.loadGraphData(JSON.parse(workflow));
this.loadGraphData(await parseComfyWorkflow(workflow));
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt));
} else {
Expand All @@ -2297,13 +2302,14 @@ export class ComfyApp {
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
const jsonContent = JSON.parse(reader.result as string);
const readerResult = reader.result as string;
const jsonContent = JSON.parse(readerResult);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
await this.loadGraphData(jsonContent);
await this.loadGraphData(await parseComfyWorkflow(readerResult));
}
};
reader.readAsText(file);
Expand All @@ -2313,7 +2319,7 @@ export class ComfyApp {
// @ts-ignore
if (info.workflow) {
// @ts-ignore
await this.loadGraphData(JSON.parse(info.workflow));
await this.loadGraphData(await parseComfyWorkflow(info.workflow));
// @ts-ignore
} else if (info.prompt) {
// @ts-ignore
Expand Down
4 changes: 3 additions & 1 deletion src/scripts/defaultGraph.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
export const defaultGraph = {
import type { ComfyWorkflow } from "/types/comfyWorkflow";

export const defaultGraph: ComfyWorkflow = {
last_node_id: 9,
last_link_id: 9,
nodes: [
Expand Down
118 changes: 118 additions & 0 deletions src/types/comfyWorkflow.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';

const zComfyLink = z.tuple([
z.number(), // Link id
z.number(), // Node id of source node
z.number(), // Output slot# of source node
z.number(), // Node id of destination node
z.number(), // Input slot# of destination node
z.string(), // Data type
]);

const zNodeOutput = z.object({
name: z.string(),
type: z.string(),
links: z.array(z.number()).nullable(),
slot_index: z.number().optional(),
}).passthrough();

const zNodeInput = z.object({
name: z.string(),
type: z.string(),
link: z.number().nullable(),
slot_index: z.number().optional(),
}).passthrough();

const zFlags = z.object({
collapsed: z.boolean().optional(),
pinned: z.boolean().optional(),
allow_interaction: z.boolean().optional(),
horizontal: z.boolean().optional(),
skip_repeated_outputs: z.boolean().optional(),
}).passthrough();

const zProperties = z.object({
["Node name for S&R"]: z.string().optional(),
}).passthrough();

const zVector2 = z.union([
z.object({ 0: z.number(), 1: z.number() }),
z.tuple([z.number(), z.number()]),
]);

const zComfyNode = z.object({
id: z.number(),
type: z.string(),
pos: z.tuple([z.number(), z.number()]),
size: zVector2,
flags: zFlags,
order: z.number(),
mode: z.number(),
inputs: z.array(zNodeInput).optional(),
outputs: z.array(zNodeOutput).optional(),
properties: zProperties,
widgets_values: z.array(z.any()).optional(), // This could contain mixed types
color: z.string().optional(),
bgcolor: z.string().optional(),
}).passthrough();

const zGroup = z.object({
title: z.string(),
bounding: z.tuple([z.number(), z.number(), z.number(), z.number()]),
color: z.string(),
font_size: z.number(),
locked: z.boolean(),
}).passthrough();

const zInfo = z.object({
name: z.string(),
author: z.string(),
description: z.string(),
version: z.string(),
created: z.string(),
modified: z.string(),
software: z.string(),
}).passthrough();

const zDS = z.object({
scale: z.number(),
offset: zVector2,
}).passthrough();

const zConfig = z.object({
links_ontop: z.boolean().optional(),
align_to_grid: z.boolean().optional(),
}).passthrough();

const zExtra = z.object({
ds: zDS.optional(),
info: zInfo.optional(),
}).passthrough();

const zComfyWorkflow = z.object({
last_node_id: z.number(),
last_link_id: z.number(),
nodes: z.array(zComfyNode),
links: z.array(zComfyLink),
groups: z.array(zGroup).optional(),
config: zConfig.optional().nullable(),
extra: zExtra.optional().nullable(),
version: z.number(),
}).passthrough();

export type NodeInput = z.infer<typeof zNodeInput>;
export type NodeOutput = z.infer<typeof zNodeOutput>;
export type ComfyLink = z.infer<typeof zComfyLink>;
export type ComfyNode = z.infer<typeof zComfyNode>;
export type ComfyWorkflow = z.infer<typeof zComfyWorkflow>;


export async function parseComfyWorkflow(data: string): Promise<ComfyWorkflow> {
// Validate
const result = await zComfyWorkflow.safeParseAsync(JSON.parse(data));
if (!result.success) {
throw fromZodError(result.error);
}
return result.data;
}
55 changes: 55 additions & 0 deletions tests-ui/tests/comfyWorkflow.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { parseComfyWorkflow } from "../../src/types/comfyWorkflow";
import { defaultGraph } from "../../src/scripts/defaultGraph";
import fs from "fs";

const WORKFLOW_DIR = "tests-ui/workflows";

describe("parseComfyWorkflow", () => {
it("parses valid workflow", async () => {
fs.readdirSync(WORKFLOW_DIR).forEach(async (file) => {
if (file.endsWith(".json")) {
const data = fs.readFileSync(`${WORKFLOW_DIR}/${file}`, "utf-8");
await expect(parseComfyWorkflow(data)).resolves.not.toThrow();
}
});
});

it("workflow.nodes", async () => {
const workflow = JSON.parse(JSON.stringify(defaultGraph));
workflow.nodes = undefined;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).rejects.toThrow();

workflow.nodes = null;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).rejects.toThrow();

workflow.nodes = [];
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();
});

it("workflow.version", async () => {
const workflow = JSON.parse(JSON.stringify(defaultGraph));
workflow.version = undefined;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).rejects.toThrow();

workflow.version = "1.0.1"; // Invalid format.
await expect(parseComfyWorkflow(JSON.stringify(workflow))).rejects.toThrow();

workflow.version = 1;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();
});

it("workflow.extra", async () => {
const workflow = JSON.parse(JSON.stringify(defaultGraph));
workflow.extra = undefined;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();

workflow.extra = null;
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();

workflow.extra = {};
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();

workflow.extra = { foo: "bar" }; // Should accept extra fields.
await expect(parseComfyWorkflow(JSON.stringify(workflow))).resolves.not.toThrow();
});
});
Loading

0 comments on commit b11a12d

Please sign in to comment.