From 6b8263469c1726eff3d221c649811accaaa9a279 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:09:03 +1000 Subject: [PATCH 1/5] feat(ui): revised `getImageDTO` utils - Rename util to `getImageDTOSafe` - Update API to accept the same options as RTKQ's `initiate` - Add `getImageDTO`; while `getImageDTOSafe` returns null if the image is not found, the new util throws - Update usage of `getImageDTOSafe` --- .../konva/CanvasCompositorModule.ts | 6 ++-- .../CanvasEntityObjectRenderer.ts | 4 +-- .../konva/CanvasObject/CanvasObjectImage.ts | 4 +-- .../src/features/controlLayers/store/types.ts | 4 +-- .../web/src/features/metadata/util/parsers.ts | 4 +-- .../web/src/services/api/endpoints/images.ts | 36 +++++++++++++------ .../services/events/onInvocationComplete.tsx | 8 ++--- 7 files changed, 40 insertions(+), 26 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts index cc457a4f15a..1ee23fdec50 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts @@ -14,7 +14,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; import type { UploadOptions } from 'services/api/endpoints/images'; -import { getImageDTO, uploadImage } from 'services/api/endpoints/images'; +import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import stableHash from 'stable-hash'; import { assert } from 'tsafe'; @@ -210,7 +210,7 @@ export class CanvasCompositorModule extends CanvasModuleBase { const cachedImageName = this.manager.cache.imageNameCache.get(hash); if (cachedImageName) { - imageDTO = await getImageDTO(cachedImageName); + imageDTO = await getImageDTOSafe(cachedImageName); if (imageDTO) { this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image'); return imageDTO; @@ -374,7 +374,7 @@ export class CanvasCompositorModule extends CanvasModuleBase { const cachedImageName = this.manager.cache.imageNameCache.get(hash); if (cachedImageName) { - imageDTO = await getImageDTO(cachedImageName); + imageDTO = await getImageDTOSafe(cachedImageName); if (imageDTO) { this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image'); return imageDTO; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index af1f5c35aaa..56964e5364f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -27,7 +27,7 @@ import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; import { serializeError } from 'serialize-error'; -import { getImageDTO, uploadImage } from 'services/api/endpoints/images'; +import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import { assert } from 'tsafe'; @@ -383,7 +383,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { const cachedImageName = this.manager.cache.imageNameCache.get(hash); if (cachedImageName) { - imageDTO = await getImageDTO(cachedImageName); + imageDTO = await getImageDTOSafe(cachedImageName); if (imageDTO) { this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached rasterized image'); return imageDTO; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts index f9b981bca41..eeacad2f49a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectImage.ts @@ -11,7 +11,7 @@ import type { CanvasImageState } from 'features/controlLayers/store/types'; import { t } from 'i18next'; import Konva from 'konva'; import type { Logger } from 'roarr'; -import { getImageDTO } from 'services/api/endpoints/images'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; export class CanvasObjectImage extends CanvasModuleBase { readonly type = 'object_image'; @@ -100,7 +100,7 @@ export class CanvasObjectImage extends CanvasModuleBase { this.konva.placeholder.text.text(t('common.loadingImage', 'Loading Image')); } - const imageDTO = await getImageDTO(imageName); + const imageDTO = await getImageDTOSafe(imageName); if (imageDTO === null) { this.onFailedToLoadImage(); return; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 8befe563bb0..007e3567645 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -7,7 +7,7 @@ import { zParameterNegativePrompt, zParameterPositivePrompt, } from 'features/parameters/types/parameterSchemas'; -import { getImageDTO } from 'services/api/endpoints/images'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import { z } from 'zod'; @@ -31,7 +31,7 @@ const zImageWithDims = z }) .refine(async (v) => { const { image_name } = v; - const imageDTO = await getImageDTO(image_name, true); + const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true }); return imageDTO !== null; }); export type ImageWithDims = z.infer; diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index f47ec375199..49b50225cc8 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -67,7 +67,7 @@ import { isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { get, isArray, isString } from 'lodash-es'; -import { getImageDTO } from 'services/api/endpoints/images'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; import { isControlNetModelConfig, isIPAdapterModelConfig, @@ -603,7 +603,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc => { - const options = { +export const getImageDTOSafe = async ( + image_name: string, + options?: StartQueryActionCreatorOptions +): Promise => { + const _options = { subscribe: false, - forceRefetch, + ...options, }; - const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, options)); + const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options)); try { - const imageDTO = await req.unwrap(); - req.unsubscribe(); - return imageDTO; + return await req.unwrap(); } catch { - req.unsubscribe(); return null; } }; +/** + * Imperative RTKQ helper to fetch an ImageDTO. + * @param image_name The name of the image to fetch + * @param options The options for the query. By default, the query will not subscribe to the store. + * @raises Error if the image is not found or there is an error fetching the image + */ +export const getImageDTO = (image_name: string, options?: StartQueryActionCreatorOptions): Promise => { + const _options = { + subscribe: false, + ...options, + }; + const req = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(image_name, _options)); + return req.unwrap(); +}; + export type UploadOptions = { blob: Blob; fileName: string; diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 7c7e511e48d..177d4599235 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -7,7 +7,7 @@ import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } fro import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { boardsApi } from 'services/api/endpoints/boards'; -import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; +import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO, S } from 'services/api/types'; import { getCategories, getListImagesUrl } from 'services/api/util'; import { $lastProgressEvent } from 'services/events/stores'; @@ -87,10 +87,8 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A const getResultImageDTO = (data: S['InvocationCompleteEvent']) => { const { result } = data; - if (result.type === 'image_output') { - return getImageDTO(result.image.image_name); - } else if (result.type === 'canvas_v2_mask_and_crop_output') { - return getImageDTO(result.image.image_name); + if (result.type === 'image_output' || result.type === 'canvas_v2_mask_and_crop_output') { + return getImageDTOSafe(result.image.image_name); } return null; }; From 3f96a8202bcfc009bd0e7e7fe87f01269dd3fab3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:12:28 +1000 Subject: [PATCH 2/5] feat(ui): add QueueError for queue item errors --- .../web/src/services/events/errors.ts | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 invokeai/frontend/web/src/services/events/errors.ts diff --git a/invokeai/frontend/web/src/services/events/errors.ts b/invokeai/frontend/web/src/services/events/errors.ts new file mode 100644 index 00000000000..24100939e90 --- /dev/null +++ b/invokeai/frontend/web/src/services/events/errors.ts @@ -0,0 +1,23 @@ +/** + * A custom error class for queue event errors. These errors have a type, message and traceback. + */ + +export class QueueError extends Error { + type: string; + traceback: string; + + constructor(type: string, message: string, traceback: string) { + super(message); + this.name = 'QueueError'; + this.type = type; + this.traceback = traceback; + + if (Error.captureStackTrace) { + Error.captureStackTrace(this, QueueError); + } + } + + toString() { + return `${this.name} [${this.type}]: ${this.message}\nTraceback:\n${this.traceback}`; + } +} From 5978667397a7bed0de2a1f58d5fe6c7b2f745171 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:50:58 +1000 Subject: [PATCH 3/5] feat(ui): expose id property on Graph class --- .../src/features/nodes/util/graph/generation/Graph.test.ts | 6 +++--- .../web/src/features/nodes/util/graph/generation/Graph.ts | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts index 6fd1431c8e4..8afd69daf5b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts @@ -2,7 +2,6 @@ import { deepClone } from 'common/util/deepClone'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { AnyInvocation, Invocation } from 'services/api/types'; import { assert, AssertionError, is } from 'tsafe'; -import { validate } from 'uuid'; import { describe, expect, it } from 'vitest'; import { z } from 'zod'; @@ -11,11 +10,12 @@ describe('Graph', () => { it('should create a new graph with the correct id', () => { const g = new Graph('test-id'); expect(g._graph.id).toBe('test-id'); + expect(g.id).toBe('test-id'); }); - it('should create a new graph with a uuid id if none is provided', () => { + it('should create an id if none is provided', () => { const g = new Graph(); expect(g._graph.id).not.toBeUndefined(); - expect(validate(g._graph.id)).toBeTruthy(); + expect(g.id).not.toBeUndefined(); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts index 950b25d71f6..9da719509b8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts @@ -32,10 +32,12 @@ export type GraphType = { id: string; nodes: Record; edge export class Graph { _graph: GraphType; _metadataNodeId = getPrefixedId('core_metadata'); + id: string; constructor(id?: string) { + this.id = id ?? Graph.getId('graph'); this._graph = { - id: id ?? uuidv4(), + id: this.id, nodes: {}, edges: [], }; From 794bfc63f65bd68edf67769c24e228ae64a461fc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:17:42 +1000 Subject: [PATCH 4/5] feat(ui): reworked filter execution & error handling Two main changes: - Add `runGraphAndReturnImageOutput` to `CanvasStateApiModule`. This method is a safe and convenient abstraction to execute a graph and retrieve the image output of one of its nodes. It supports cancellation (via an AbortSignal) and timeout. - Update filters to build whole graphs, as opposed to nodes. These changes allow: - Filter execution is resilient, with all error cases handled (afaik) - `CanvasEntityFilterer` class is much simpler - Stuck or long-running filters may be canceled - Filters may be arbitrarily complex - so long as there is one node that outputs an image, the filter will just work --- .../components/Filters/Filter.tsx | 1 - .../CanvasEntity/CanvasEntityFilterer.ts | 144 ++++----- .../konva/CanvasStateApiModule.ts | 200 +++++++++++- .../features/controlLayers/store/filters.ts | 289 ++++++++++++------ 4 files changed, 456 insertions(+), 178 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx index b05adbc097c..5bc449fac6f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Filters/Filter.tsx @@ -106,7 +106,6 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | variant="ghost" leftIcon={} onClick={adapter.filterer.cancel} - isLoading={isProcessing} loadingText={t('controlLayers.filter.cancel')} > {t('controlLayers.filter.cancel')} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts index 5b66a7c332e..3fd501a4682 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts @@ -1,5 +1,4 @@ -import type { SerializableObject } from 'common/types'; -import { withResultAsync } from 'common/util/result'; +import { withResult, withResultAsync } from 'common/util/result'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -13,9 +12,9 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; -import { getImageDTO } from 'services/api/endpoints/images'; +import { serializeError } from 'serialize-error'; import { buildSelectModelConfig } from 'services/api/hooks/modelsByType'; -import { type BatchConfig, type ImageDTO, isControlNetOrT2IAdapterModelConfig, type S } from 'services/api/types'; +import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; type CanvasEntityFiltererConfig = { @@ -38,6 +37,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase { subscriptions = new Set<() => void>(); config: CanvasEntityFiltererConfig = DEFAULT_CONFIG; + /** + * The AbortController used to cancel the filter processing. + */ + abortController: AbortController | null = null; + $isFiltering = atom(false); $hasProcessed = atom(false); $isProcessing = atom(false); @@ -100,63 +104,82 @@ export class CanvasEntityFilterer extends CanvasModuleBase { processImmediate = async () => { const config = this.$filterConfig.get(); - const isValid = IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true; + const filterData = IMAGE_FILTERS[config.type]; + + // Cannot get TS to be happy with `config`, thinks it should be `never`... eh... + const isValid = filterData.validateConfig?.(config as never) ?? true; if (!isValid) { + this.log.error({ config }, 'Invalid filter config'); return; } - this.log.trace({ config }, 'Previewing filter'); + this.log.trace({ config }, 'Processing filter'); const rect = this.parent.transformer.getRelativeRect(); - const imageDTO = await this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } }); - const nodeId = getPrefixedId('filter_node'); - const batch = this.buildBatchConfig(imageDTO, config, nodeId); - - // Listen for the filter processing completion event - const completedListener = async (event: S['InvocationCompleteEvent']) => { - if (event.origin !== this.id || event.invocation_source_id !== nodeId) { - return; - } - this.manager.socket.off('invocation_complete', completedListener); - this.manager.socket.off('invocation_error', errorListener); - - this.log.trace({ event } as SerializableObject, 'Handling filter processing completion'); - const { result } = event; - assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`); - - const imageDTO = await getImageDTO(result.image.image_name); - assert(imageDTO, "Failed to fetch processor output's image DTO"); + const rasterizeResult = await withResultAsync(() => + this.parent.renderer.rasterize({ rect, attrs: { filters: [], opacity: 1 } }) + ); + if (rasterizeResult.isErr()) { + this.log.error({ error: serializeError(rasterizeResult.error) }, 'Error rasterizing entity'); + this.$isProcessing.set(false); + return; + } - this.imageState = imageDTOToImageObject(imageDTO); + this.$isProcessing.set(true); - await this.parent.bufferRenderer.setBuffer(this.imageState, true); + const imageDTO = rasterizeResult.value; + // Cannot get TS to be happy with `config`, thinks it should be `never`... eh... + const buildGraphResult = withResult(() => filterData.buildGraph(imageDTO, config as never)); + if (buildGraphResult.isErr()) { + this.log.error({ error: serializeError(buildGraphResult.error) }, 'Error building filter graph'); this.$isProcessing.set(false); - this.$hasProcessed.set(true); - }; - const errorListener = (event: S['InvocationErrorEvent']) => { - if (event.origin !== this.id || event.invocation_source_id !== nodeId) { - return; - } - this.manager.socket.off('invocation_complete', completedListener); - this.manager.socket.off('invocation_error', errorListener); - - this.log.error({ event } as SerializableObject, 'Error processing filter'); + return; + } + + const controller = new AbortController(); + this.abortController = controller; + + const { graph, outputNodeId } = buildGraphResult.value; + const filterResult = await withResultAsync(() => + this.manager.stateApi.runGraphAndReturnImageOutput({ + graph, + outputNodeId, + // The filter graph should always be prepended to the queue so it's processed ASAP. + prepend: true, + /** + * The filter node may need to download a large model. Currently, the models required by the filter nodes are + * downloaded just-in-time, as required by the filter. If we use a timeout here, we might get into a catch-22 + * where the filter node is waiting for the model to download, but the download gets canceled if the filter + * node times out. + * + * (I suspect the model download will actually _not_ be canceled if the graph is canceled, but let's not chance it!) + * + * TODO(psyche): Figure out a better way to handle this. Probably need to download the models ahead of time. + */ + // timeout: 5000, + /** + * The filter node should be able to cancel the request if it's taking too long. This will cancel the graph's + * queue item and clear any event listeners on the request. + */ + signal: controller.signal, + }) + ); + if (filterResult.isErr()) { + this.log.error({ error: serializeError(filterResult.error) }, 'Error processing filter'); this.$isProcessing.set(false); - }; + this.abortController = null; + return; + } - this.manager.socket.on('invocation_complete', completedListener); - this.manager.socket.on('invocation_error', errorListener); + this.log.trace({ imageDTO: filterResult.value }, 'Filter processed'); + this.imageState = imageDTOToImageObject(filterResult.value); - this.log.trace({ batch } as SerializableObject, 'Enqueuing filter batch'); + await this.parent.bufferRenderer.setBuffer(this.imageState, true); - this.$isProcessing.set(true); - const req = this.manager.stateApi.enqueueBatch(batch); - const result = await withResultAsync(req.unwrap); - if (result.isErr()) { - this.$isProcessing.set(false); - } - req.reset(); + this.$isProcessing.set(false); + this.$hasProcessed.set(true); + this.abortController = null; }; process = debounce(this.processImmediate, this.config.processDebounceMs); @@ -188,6 +211,8 @@ export class CanvasEntityFilterer extends CanvasModuleBase { reset = () => { this.log.trace('Resetting filter'); + this.abortController?.abort(); + this.abortController = null; this.parent.bufferRenderer.clearBuffer(); this.parent.transformer.updatePosition(); this.parent.renderer.syncCache(true); @@ -205,31 +230,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase { this.manager.stateApi.$filteringAdapter.set(null); }; - buildBatchConfig = (imageDTO: ImageDTO, config: FilterConfig, id: string): BatchConfig => { - // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now - const node = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never); - node.id = id; - const batch: BatchConfig = { - prepend: true, - batch: { - graph: { - nodes: { - [node.id]: { - ...node, - // filtered images are always intermediate - do not save to gallery - is_intermediate: true, - }, - }, - edges: [], - }, - origin: this.id, - runs: 1, - }, - }; - - return batch; - }; - repr = () => { return { id: this.id, diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 5e7cbbad0f4..7bfbbe9dd46 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -2,6 +2,7 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library'; import type { Selector } from '@reduxjs/toolkit'; import { addAppListener } from 'app/store/middleware/listenerMiddleware'; import type { AppStore, RootState } from 'app/store/store'; +import { withResultAsync } from 'common/util/result'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -38,10 +39,13 @@ import type { RgbaColor, } from 'features/controlLayers/store/types'; import { RGBA_BLACK } from 'features/controlLayers/store/types'; +import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; +import { getImageDTO } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; -import type { BatchConfig } from 'services/api/types'; +import type { BatchConfig, ImageDTO, S } from 'services/api/types'; +import { QueueError } from 'services/events/errors'; import { assert } from 'tsafe'; import type { CanvasEntityAdapter } from './CanvasEntity/types'; @@ -187,14 +191,200 @@ export class CanvasStateApiModule extends CanvasModuleBase { }; /** - * Enqueues a batch, pushing state to redux. - */ - enqueueBatch = (batch: BatchConfig) => { - return this.store.dispatch( + * Run a graph and return an image output. The specified output node must return an image output, else the promise + * will reject with an error. + * + * @param arg The arguments for the function. + * @param arg.graph The graph to execute. + * @param arg.outputNodeId The id of the node whose output will be retrieved. + * @param arg.destination The destination to assign to the batch. If omitted, the destination is not set. + * @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue. + * @param arg.timeout The timeout for the batch. If omitted, there is no timeout. + * @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled! + * + * @returns A promise that resolves to the image output or rejects with an error. + * + * @example + * + * ```ts + * const graph = new Graph(); + * const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } }); + * const controller = new AbortController(); + * const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({ + * graph, + * outputNodeId: outputNode.id, + * prepend: true, + * signal: controller.signal, + * }); + * // To cancel the operation: + * controller.abort(); + * ``` + */ + runGraphAndReturnImageOutput = async (arg: { + graph: Graph; + outputNodeId: string; + destination?: string; + prepend?: boolean; + timeout?: number; + signal?: AbortSignal; + }): Promise => { + const { graph, outputNodeId, destination, prepend, timeout, signal } = arg; + + /** + * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a + * race condition: + * - The queue item id is not available until the graph is enqueued + * - The graph may complete before we can set up the listeners to handle the completion event + * + * The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued, + * so we will use that to filter events. + */ + const origin = getPrefixedId(graph.id); + + const batch: BatchConfig = { + prepend, + batch: { + graph: graph.getGraph(), + origin, + destination, + runs: 1, + }, + }; + + /** + * If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout + * if the graph completes or errors before the timeout. + */ + let timeoutId: number | null = null; + const _clearTimeout = () => { + if (timeoutId !== null) { + window.clearTimeout(timeoutId); + timeoutId = null; + } + }; + + /** + * First, enqueue the graph - we need the `batch_id` to cancel the graph. But to get the `batch_id`, we need to + * `await` the request. You might be tempted to `await` the request inside the result promise, but we should not + * `await` inside a promise executor. + * + * See: https://eslint.org/docs/latest/rules/no-async-promise-executor + */ + const enqueueRequest = this.store.dispatch( queueApi.endpoints.enqueueBatch.initiate(batch, { + // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status + // updates. fixedCacheKey: 'enqueueBatch', + // We do not need RTK to track this request in the store + track: false, }) ); + + // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect. + // TODO(psyche): Fix the OpenAPI schema. + const { batch_id } = (await enqueueRequest.unwrap()).batch; + assert(batch_id, 'Enqueue result is missing batch_id'); + + const resultPromise = new Promise((resolve, reject) => { + const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => { + // Ignore events that are not for this graph + if (event.origin !== origin) { + return; + } + // Ignore events that are not from the output node + if (event.invocation_source_id !== outputNodeId) { + return; + } + + // If we get here, the event is for the correct graph and output node. + + // Clear the timeout and socket listeners + _clearTimeout(); + clearListeners(); + + // The result must be an image output + const { result } = event; + if (result.type !== 'image_output') { + reject(new Error(`Graph output node did not return an image output, got: ${result}`)); + return; + } + + // Get the result image DTO + const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name)); + if (getImageDTOResult.isErr()) { + reject(getImageDTOResult.error); + return; + } + + // Ok! + resolve(getImageDTOResult.value); + }; + + const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => { + // Ignore events that are not for this graph + if (event.origin !== origin) { + return; + } + + // Ignore events where the status is pending or in progress - no need to do anything for these + if (event.status === 'pending' || event.status === 'in_progress') { + return; + } + + // event.status is 'failed', 'canceled' or 'completed' - something has gone awry + _clearTimeout(); + clearListeners(); + + if (event.status === 'completed') { + // If we get a queue item completed event, that means we never got a completion event for the output node! + reject(new Error('Queue item completed without output node completion event')); + } else if (event.status === 'failed') { + // We expect the event to have error details, but technically it's possible that it doesn't + const { error_type, error_message, error_traceback } = event; + if (error_type && error_message && error_traceback) { + reject(new QueueError(error_type, error_message, error_traceback)); + } else { + reject(new Error('Queue item failed, but no error details were provided')); + } + } else { + // event.status is 'canceled' + reject(new Error('Graph canceled')); + } + }; + + this.manager.socket.on('invocation_complete', invocationCompleteHandler); + this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler); + + const clearListeners = () => { + this.manager.socket.off('invocation_complete', invocationCompleteHandler); + this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler); + }; + + const cancelGraph = () => { + this.store.dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false })); + }; + + if (timeout) { + timeoutId = window.setTimeout(() => { + this.log.trace('Graph canceled by timeout'); + clearListeners(); + cancelGraph(); + reject(new Error('Graph timed out')); + }, timeout); + } + + if (signal) { + signal.addEventListener('abort', () => { + this.log.trace('Graph canceled by signal'); + _clearTimeout(); + clearListeners(); + cancelGraph(); + reject(new Error('Graph canceled')); + }); + } + }); + + return resultPromise; }; /** diff --git a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts index bd138666b27..5cfec9ee9f2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts @@ -1,7 +1,8 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { ImageWithDims } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import type { AnyInvocation, ControlNetModelConfig, Invocation, T2IAdapterModelConfig } from 'services/api/types'; +import { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import { z } from 'zod'; @@ -132,7 +133,10 @@ export const isFilterType = (v: unknown): v is FilterType => zFilterType.safePar type ImageFilterData = { type: T; buildDefaults(): Extract; - buildNode(imageDTO: ImageWithDims, config: Extract): AnyInvocation; + buildGraph( + imageDTO: ImageWithDims, + config: Extract + ): { graph: Graph; outputNodeId: string }; validateConfig?(config: Extract): boolean; }; @@ -144,13 +148,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('canny_edge_detection'), - type: 'canny_edge_detection', - image: { image_name }, - low_threshold, - high_threshold, - }), + buildGraph: ({ image_name }, { low_threshold, high_threshold }) => { + const graph = new Graph(getPrefixedId('canny_edge_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('canny_edge_detection'), + type: 'canny_edge_detection', + image: { image_name }, + low_threshold, + high_threshold, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, color_map: { type: 'color_map', @@ -158,12 +169,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('color_map'), - type: 'color_map', - image: { image_name }, - tile_size, - }), + buildGraph: ({ image_name }, { tile_size }) => { + const graph = new Graph(getPrefixedId('color_map_filter')); + const node = graph.addNode({ + id: getPrefixedId('color_map'), + type: 'color_map', + image: { image_name }, + tile_size, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, content_shuffle: { type: 'content_shuffle', @@ -171,12 +189,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('content_shuffle'), - type: 'content_shuffle', - image: { image_name }, - scale_factor, - }), + buildGraph: ({ image_name }, { scale_factor }) => { + const graph = new Graph(getPrefixedId('content_shuffle_filter')); + const node = graph.addNode({ + id: getPrefixedId('content_shuffle'), + type: 'content_shuffle', + image: { image_name }, + scale_factor, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, depth_anything_depth_estimation: { type: 'depth_anything_depth_estimation', @@ -184,12 +209,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('depth_anything_depth_estimation'), - type: 'depth_anything_depth_estimation', - image: { image_name }, - model_size, - }), + buildGraph: ({ image_name }, { model_size }) => { + const graph = new Graph(getPrefixedId('depth_anything_depth_estimation_filter')); + const node = graph.addNode({ + id: getPrefixedId('depth_anything_depth_estimation'), + type: 'depth_anything_depth_estimation', + image: { image_name }, + model_size, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, hed_edge_detection: { type: 'hed_edge_detection', @@ -197,23 +229,37 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('hed_edge_detection'), - type: 'hed_edge_detection', - image: { image_name }, - scribble, - }), + buildGraph: ({ image_name }, { scribble }) => { + const graph = new Graph(getPrefixedId('hed_edge_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('hed_edge_detection'), + type: 'hed_edge_detection', + image: { image_name }, + scribble, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, lineart_anime_edge_detection: { type: 'lineart_anime_edge_detection', buildDefaults: () => ({ type: 'lineart_anime_edge_detection', }), - buildNode: ({ image_name }): Invocation<'lineart_anime_edge_detection'> => ({ - id: getPrefixedId('lineart_anime_edge_detection'), - type: 'lineart_anime_edge_detection', - image: { image_name }, - }), + buildGraph: ({ image_name }) => { + const graph = new Graph(getPrefixedId('lineart_anime_edge_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('lineart_anime_edge_detection'), + type: 'lineart_anime_edge_detection', + image: { image_name }, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, lineart_edge_detection: { type: 'lineart_edge_detection', @@ -221,12 +267,19 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('lineart_edge_detection'), - type: 'lineart_edge_detection', - image: { image_name }, - coarse, - }), + buildGraph: ({ image_name }, { coarse }) => { + const graph = new Graph(getPrefixedId('lineart_edge_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('lineart_edge_detection'), + type: 'lineart_edge_detection', + image: { image_name }, + coarse, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, mediapipe_face_detection: { type: 'mediapipe_face_detection', @@ -235,13 +288,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('mediapipe_face_detection'), - type: 'mediapipe_face_detection', - image: { image_name }, - max_faces, - min_confidence, - }), + buildGraph: ({ image_name }, { max_faces, min_confidence }) => { + const graph = new Graph(getPrefixedId('mediapipe_face_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('mediapipe_face_detection'), + type: 'mediapipe_face_detection', + image: { image_name }, + max_faces, + min_confidence, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, mlsd_detection: { type: 'mlsd_detection', @@ -250,24 +310,38 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('mlsd_detection'), - type: 'mlsd_detection', - image: { image_name }, - score_threshold, - distance_threshold, - }), + buildGraph: ({ image_name }, { score_threshold, distance_threshold }) => { + const graph = new Graph(getPrefixedId('mlsd_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('mlsd_detection'), + type: 'mlsd_detection', + image: { image_name }, + score_threshold, + distance_threshold, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, normal_map: { type: 'normal_map', buildDefaults: () => ({ type: 'normal_map', }), - buildNode: ({ image_name }): Invocation<'normal_map'> => ({ - id: getPrefixedId('normal_map'), - type: 'normal_map', - image: { image_name }, - }), + buildGraph: ({ image_name }) => { + const graph = new Graph(getPrefixedId('normal_map_filter')); + const node = graph.addNode({ + id: getPrefixedId('normal_map'), + type: 'normal_map', + image: { image_name }, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, pidi_edge_detection: { type: 'pidi_edge_detection', @@ -276,13 +350,20 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('pidi_edge_detection'), - type: 'pidi_edge_detection', - image: { image_name }, - quantize_edges, - scribble, - }), + buildGraph: ({ image_name }, { quantize_edges, scribble }) => { + const graph = new Graph(getPrefixedId('pidi_edge_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('pidi_edge_detection'), + type: 'pidi_edge_detection', + image: { image_name }, + quantize_edges, + scribble, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, dw_openpose_detection: { type: 'dw_openpose_detection', @@ -292,14 +373,21 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => ({ - id: getPrefixedId('dw_openpose_detection'), - type: 'dw_openpose_detection', - image: { image_name }, - draw_body, - draw_face, - draw_hands, - }), + buildGraph: ({ image_name }, { draw_body, draw_face, draw_hands }) => { + const graph = new Graph(getPrefixedId('dw_openpose_detection_filter')); + const node = graph.addNode({ + id: getPrefixedId('dw_openpose_detection'), + type: 'dw_openpose_detection', + image: { image_name }, + draw_body, + draw_face, + draw_hands, + }); + return { + graph, + outputNodeId: node.id, + }; + }, }, spandrel_filter: { type: 'spandrel_filter', @@ -309,29 +397,30 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData => { + buildGraph: ({ image_name }, { model, scale, autoScale }) => { assert(model !== null); - if (autoScale) { - const node: Invocation<'spandrel_image_to_image_autoscale'> = { - id: getPrefixedId('spandrel_image_to_image_autoscale'), - type: 'spandrel_image_to_image_autoscale', - image_to_image_model: model, - image: { image_name }, - scale, - }; - return node; - } else { - const node: Invocation<'spandrel_image_to_image'> = { - id: getPrefixedId('spandrel_image_to_image'), - type: 'spandrel_image_to_image', - image_to_image_model: model, - image: { image_name }, - }; - return node; - } + const graph = new Graph(getPrefixedId('spandrel_filter')); + const node = graph.addNode( + autoScale + ? { + id: getPrefixedId('spandrel_image_to_image_autoscale'), + type: 'spandrel_image_to_image_autoscale', + image_to_image_model: model, + image: { image_name }, + scale, + } + : { + id: getPrefixedId('spandrel_image_to_image'), + type: 'spandrel_image_to_image', + image_to_image_model: model, + image: { image_name }, + } + ); + + return { + graph, + outputNodeId: node.id, + }; }, validateConfig: (config): boolean => { if (!config.model) { From f6146b1a73fdeb0f200e6cb7a6c504297d2e8a2b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:48:03 +1000 Subject: [PATCH 5/5] feat(ui): error handling for all rasterization calls --- .../CanvasEntityObjectRenderer.ts | 62 +++++++++++++------ .../CanvasEntity/CanvasEntityTransformer.ts | 9 ++- .../graph/generation/addControlAdapters.ts | 33 ++++++++-- .../nodes/util/graph/generation/addRegions.ts | 20 +++++- .../web/src/services/api/endpoints/images.ts | 4 +- 5 files changed, 97 insertions(+), 31 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index 56964e5364f..da7bc9d7999 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -1,5 +1,6 @@ import { $authToken } from 'app/store/nanostores/authToken'; import { rgbColorToString } from 'common/util/colorCodeTransformers'; +import { withResult } from 'common/util/result'; import { SyncableMap } from 'common/util/SyncableMap/SyncableMap'; import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -356,14 +357,25 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { }; /** - * Rasterizes the parent entity. If the entity has a rasterization cache for the given rect, the cached image is - * returned. Otherwise, the entity is rasterized and the image is uploaded to the server. + * Rasterizes the parent entity, returning a promise that resolves to the image DTO. + * + * If the entity has a rasterization cache for the given rect, the cached image is returned. Otherwise, the entity is + * rasterized and the image is uploaded to the server. * * The rasterization cache is reset when the entity's state changes. The buffer object is not considered part of the * entity state for this purpose as it is a temporary object. * - * @param rect The rect to rasterize. If omitted, the entity's full rect will be used. - * @returns A promise that resolves to the rasterized image DTO. + * If rasterization fails for any reason, the promise will reject. + * + * @param options The rasterization options. + * @param options.rect The region of the entity to rasterize. + * @param options.replaceObjects Whether to replace the entity's objects with the rasterized image. If you just want + * the entity's image, omit or set this to false. + * @param options.attrs The Konva node attributes to apply to the rasterized image group. For example, you might want + * to disable filters or set the opacity to the rasterized image. + * @param options.bg Draws the entity on a canvas with the given background color. If omitted, the entity is drawn on + * a transparent canvas. + * @returns A promise that resolves to the rasterized image DTO or rejects if rasterization fails. */ rasterize = async (options: { rect: Rect; @@ -423,26 +435,38 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { if (this.parent.transformer.$isPendingRectCalculation.get()) { return; } + const pixelRect = this.parent.transformer.$pixelRect.get(); if (pixelRect.width === 0 || pixelRect.height === 0) { return; } - try { - // TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public? - const canvas = this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null; - if (canvas) { - const nodeRect = this.parent.transformer.$nodeRect.get(); - const rect = { - x: pixelRect.x - nodeRect.x, - y: pixelRect.y - nodeRect.y, - width: pixelRect.width, - height: pixelRect.height, - }; - this.$canvasCache.set({ rect, canvas }); - } - } catch (error) { + + /** + * TODO(psyche): This is an internal Konva method, so it may break in the future. Can we make this API public? + * + * This method's API is unknown. It has been experimentally determined that it may throw, so we need to handle + * errors. + */ + const getCacheCanvasResult = withResult( + () => this.konva.objectGroup._getCachedSceneCanvas()._canvas as HTMLCanvasElement | undefined | null + ); + if (getCacheCanvasResult.isErr()) { // We are using an internal Konva method, so we need to catch any errors that may occur. - this.log.warn({ error: serializeError(error) }, 'Failed to update preview canvas'); + this.log.warn({ error: serializeError(getCacheCanvasResult.error) }, 'Failed to update preview canvas'); + return; + } + + const canvas = getCacheCanvasResult.value; + + if (canvas) { + const nodeRect = this.parent.transformer.$nodeRect.get(); + const rect = { + x: pixelRect.x - nodeRect.x, + y: pixelRect.y - nodeRect.y, + width: pixelRect.width, + height: pixelRect.height, + }; + this.$canvasCache.set({ rect, canvas }); } }, 300); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts index d32ea51726b..7b0208bf41c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts @@ -1,3 +1,4 @@ +import { withResultAsync } from 'common/util/result'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -15,6 +16,7 @@ import type { GroupConfig } from 'konva/lib/Group'; import { debounce, get } from 'lodash-es'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; +import { serializeError } from 'serialize-error'; import { assert } from 'tsafe'; type CanvasEntityTransformerConfig = { @@ -575,7 +577,12 @@ export class CanvasEntityTransformer extends CanvasModuleBase { this.log.debug('Applying transform'); this.$isProcessing.set(true); const rect = this.getRelativeRect(); - await this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } }); + const rasterizeResult = await withResultAsync(() => + this.parent.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } }) + ); + if (rasterizeResult.isErr()) { + this.log.error({ error: serializeError(rasterizeResult.error) }, 'Failed to rasterize entity'); + } this.requestRectCalculation(); this.stopTransform(); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index deaddab38f6..347ca4fba4b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -1,3 +1,5 @@ +import { logger } from 'app/logging/logger'; +import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasControlLayerState, @@ -6,9 +8,12 @@ import type { T2IAdapterConfig, } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { serializeError } from 'serialize-error'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; +const log = logger('system'); + type AddControlNetsResult = { addedControlNets: number; }; @@ -33,9 +38,17 @@ export const addControlNets = async ( for (const layer of validControlLayers) { result.addedControlNets++; - const adapter = manager.adapters.controlLayers.get(layer.id); - assert(adapter, 'Adapter not found'); - const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' }); + const getImageDTOResult = await withResultAsync(() => { + const adapter = manager.adapters.controlLayers.get(layer.id); + assert(adapter, 'Adapter not found'); + return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' }); + }); + if (getImageDTOResult.isErr()) { + log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer'); + continue; + } + + const imageDTO = getImageDTOResult.value; addControlNetToGraph(g, layer, imageDTO, collector); } @@ -66,9 +79,17 @@ export const addT2IAdapters = async ( for (const layer of validControlLayers) { result.addedT2IAdapters++; - const adapter = manager.adapters.controlLayers.get(layer.id); - assert(adapter, 'Adapter not found'); - const imageDTO = await adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [], bg: 'black' } }); + const getImageDTOResult = await withResultAsync(() => { + const adapter = manager.adapters.controlLayers.get(layer.id); + assert(adapter, 'Adapter not found'); + return adapter.renderer.rasterize({ rect, attrs: { opacity: 1, filters: [] }, bg: 'black' }); + }); + if (getImageDTOResult.isErr()) { + log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing control layer'); + continue; + } + + const imageDTO = getImageDTOResult.value; addT2IAdapterToGraph(g, layer, imageDTO, collector); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index d85e862d508..dcce2046daa 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -1,4 +1,6 @@ +import { logger } from 'app/logging/logger'; import { deepClone } from 'common/util/deepClone'; +import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { @@ -8,9 +10,12 @@ import type { RegionalGuidanceReferenceImageState, } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { serializeError } from 'serialize-error'; import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; +const log = logger('system'); + type AddedRegionResult = { addedPositivePrompt: boolean; addedNegativePrompt: boolean; @@ -64,9 +69,18 @@ export const addRegions = async ( addedAutoNegativePositivePrompt: false, addedIPAdapters: 0, }; - const adapter = manager.adapters.regionMasks.get(region.id); - assert(adapter, 'Adapter not found'); - const imageDTO = await adapter.renderer.rasterize({ rect: bbox }); + + const getImageDTOResult = await withResultAsync(() => { + const adapter = manager.adapters.regionMasks.get(region.id); + assert(adapter, 'Adapter not found'); + return adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } }); + }); + if (getImageDTOResult.isErr()) { + log.warn({ error: serializeError(getImageDTOResult.error) }, 'Error rasterizing region mask'); + continue; + } + + const imageDTO = getImageDTOResult.value; // The main mask-to-tensor node const maskToTensor = g.addNode({ diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index cbd6908c512..0d92302e031 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -612,7 +612,7 @@ export type UploadOptions = { board_id?: BoardId; metadata?: SerializableObject; }; -export const uploadImage = async (arg: UploadOptions): Promise => { +export const uploadImage = (arg: UploadOptions): Promise => { const { blob, fileName, image_category, is_intermediate, crop_visible = false, board_id, metadata } = arg; const { dispatch } = getStore(); @@ -628,5 +628,5 @@ export const uploadImage = async (arg: UploadOptions): Promise => { }) ); req.reset(); - return await req.unwrap(); + return req.unwrap(); };