diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 300f13fcd0c..97dbe00e28a 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -1,5 +1,4 @@ import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library'; -import { useSocketIO } from 'app/hooks/useSocketIO'; import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus'; import { useLogger } from 'app/logging/useLogger'; import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted'; @@ -31,6 +30,7 @@ import { size } from 'lodash-es'; import { memo, useCallback, useEffect } from 'react'; import { ErrorBoundary } from 'react-error-boundary'; import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; +import { useSocketIO } from 'services/events/useSocketIO'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import PreselectedImage from './PreselectedImage'; diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 47d4359e31c..61252aaf1bf 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -1,7 +1,6 @@ import 'i18n'; import type { Middleware } from '@reduxjs/toolkit'; -import { $socketOptions } from 'app/hooks/useSocketIO'; import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $customNavComponent } from 'app/store/nanostores/customNavComponent'; @@ -24,6 +23,7 @@ import type { PropsWithChildren, ReactNode } from 'react'; import React, { lazy, memo, useEffect, useMemo } from 'react'; import { Provider } from 'react-redux'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; +import { $socketOptions } from 'services/events/stores'; import type { ManagerOptions, SocketOptions } from 'socket.io-client'; const App = lazy(() => import('./App')); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index d924f3ea160..a0a6604ddf8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted'; import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected'; import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload'; -import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners'; import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear'; import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes'; import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked'; @@ -73,15 +72,6 @@ addAnyEnqueuedListener(startAppListening); addBatchEnqueuedListener(startAppListening); // Canvas actions -// addCanvasSavedToGalleryListener(startAppListening); -// addCanvasMaskSavedToGalleryListener(startAppListening); -// addCanvasImageToControlNetListener(startAppListening); -// addCanvasMaskToControlNetListener(startAppListening); -// addCanvasDownloadedAsImageListener(startAppListening); -// addCanvasCopiedToClipboardListener(startAppListening); -// addCanvasMergedListener(startAppListening); -// addStagingAreaImageSavedListener(startAppListening); -// addCommitStagingAreaImageListener(startAppListening); addStagingListeners(startAppListening); // Socket.IO @@ -121,6 +111,3 @@ addAdHocPostProcessingRequestedListener(startAppListening); addDynamicPromptsListener(startAppListening); addSetDefaultSettingsListener(startAppListening); -// addControlAdapterPreprocessor(startAppListening); - -addCancellationsListeners(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts deleted file mode 100644 index d5021fe7638..00000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts +++ /dev/null @@ -1,137 +0,0 @@ -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice'; -import { queueApi } from 'services/api/endpoints/queue'; - -/** - * To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of - * cancellations: - * - In the route handlers above, we track and update the cancellations object - * - When the user queues a, we should reset the cancellations, also handled int he route handlers above - * - When we get a progress event, we should check if the event is cancelled before setting the event - * - * We have a few ways that cancellations are effected, so we need to track them all: - * - by queue item id (in this case, we will compare the session_id and not the item_id) - * - by batch id - * - by destination - * - by clearing the queue - */ -type Cancellations = { - sessionIds: Set; - batchIds: Set; - destinations: Set; - clearQueue: boolean; -}; - -const resetCancellations = (): void => { - cancellations.clearQueue = false; - cancellations.sessionIds.clear(); - cancellations.batchIds.clear(); - cancellations.destinations.clear(); -}; - -const cancellations: Cancellations = { - sessionIds: new Set(), - batchIds: new Set(), - destinations: new Set(), - clearQueue: false, -} as Readonly; - -/** - * Checks if an item is cancelled, used to prevent race conditions with event handling. - * - * To use this, provide the session_id, batch_id and destination from the event payload. - */ -export const getIsCancelled = (item: { - session_id: string; - batch_id: string; - destination?: string | null; -}): boolean => { - if (cancellations.clearQueue) { - return true; - } - if (cancellations.sessionIds.has(item.session_id)) { - return true; - } - if (cancellations.batchIds.has(item.batch_id)) { - return true; - } - if (item.destination && cancellations.destinations.has(item.destination)) { - return true; - } - return false; -}; - -export const addCancellationsListeners = (startAppListening: AppStartListening) => { - // When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases. - // Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition... - startAppListening({ - matcher: queueApi.endpoints.enqueueBatch.matchFulfilled, - effect: () => { - resetCancellations(); - }, - }); - - startAppListening({ - matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled, - effect: (action) => { - cancellations.destinations.add(action.meta.arg.originalArgs.destination); - - const event = $lastCanvasProgressEvent.get(); - if (!event) { - return; - } - const { session_id, batch_id, destination } = event; - if (getIsCancelled({ session_id, batch_id, destination })) { - $lastCanvasProgressEvent.set(null); - } - }, - }); - - startAppListening({ - matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled, - effect: (action) => { - cancellations.sessionIds.add(action.payload.session_id); - - const event = $lastCanvasProgressEvent.get(); - if (!event) { - return; - } - const { session_id, batch_id, destination } = event; - if (getIsCancelled({ session_id, batch_id, destination })) { - $lastCanvasProgressEvent.set(null); - } - }, - }); - - startAppListening({ - matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled, - effect: (action) => { - for (const batch_id of action.meta.arg.originalArgs.batch_ids) { - cancellations.batchIds.add(batch_id); - } - const event = $lastCanvasProgressEvent.get(); - if (!event) { - return; - } - const { session_id, batch_id, destination } = event; - if (getIsCancelled({ session_id, batch_id, destination })) { - $lastCanvasProgressEvent.set(null); - } - }, - }); - - startAppListening({ - matcher: queueApi.endpoints.clearQueue.matchFulfilled, - effect: () => { - cancellations.clearQueue = true; - const event = $lastCanvasProgressEvent.get(); - if (!event) { - return; - } - const { session_id, batch_id, destination } = event; - if (getIsCancelled({ session_id, batch_id, destination })) { - $lastCanvasProgressEvent.set(null); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 760deccbb7a..7da7ac99de6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -66,7 +66,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery'; const prepareBatchResult = withResult(() => - prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination) + prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination) ); if (prepareBatchResult.isErr()) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/promptChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/promptChanged.ts index d8abe4c66c8..0be242f49d3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/promptChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/promptChanged.ts @@ -15,7 +15,8 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice'; import { stylePresetsApi } from 'services/api/endpoints/stylePresets'; import { utilitiesApi } from 'services/api/endpoints/utilities'; -import { socketConnected } from 'services/events/setEventListeners'; + +import { socketConnected } from './socketConnected'; const matcher = isAnyOf( positivePromptChanged, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketConnected.ts index b73ab7e0fa0..0398aacb6e3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketConnected.ts @@ -1,3 +1,4 @@ +import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; @@ -6,11 +7,11 @@ import { atom } from 'nanostores'; import { api } from 'services/api'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue'; -import { socketConnected } from 'services/events/setEventListeners'; const log = logger('events'); const $isFirstConnection = atom(true); +export const socketConnected = createAction('socket/connected'); export const addSocketConnectedEventListener = (startAppListening: AppStartListening) => { startAppListening({ diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 24b19bfdceb..a46dd3c2df8 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -1,5 +1,4 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { $true } from 'app/store/nanostores/util'; import { useAppSelector } from 'app/store/storeHooks'; @@ -21,6 +20,7 @@ import i18n from 'i18next'; import { forEach, upperFirst } from 'lodash-es'; import { useMemo } from 'react'; import { getConnectedEdges } from 'reactflow'; +import { $isConnected } from 'services/events/stores'; const LAYER_TYPE_TO_TKEY = { reference_image: 'controlLayers.referenceImage', diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx index b7707bff441..575e1839900 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx @@ -1,7 +1,6 @@ import { Flex, useShiftModifier } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; @@ -16,6 +15,7 @@ import { useTranslation } from 'react-i18next'; import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import type { ImageDTO, PostUploadAction } from 'services/api/types'; +import { $isConnected } from 'services/events/stores'; type Props = { image: ImageWithDims | null; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useInvokeCanvas.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useInvokeCanvas.ts index 286c32125d2..7a1beec8829 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/useInvokeCanvas.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useInvokeCanvas.ts @@ -1,5 +1,4 @@ import { useStore } from '@nanostores/react'; -import { $socket } from 'app/hooks/useSocketIO'; import { logger } from 'app/logging/logger'; import { useAppStore } from 'app/store/nanostores/store'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; @@ -7,6 +6,7 @@ import { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager } from 'features/controlLayers/store/canvasSlice'; import Konva from 'konva'; import { useLayoutEffect, useState } from 'react'; +import { $socket } from 'services/events/stores'; import { useDevicePixelRatio } from 'use-device-pixel-ratio'; const log = logger('canvas'); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts index ebac82581fe..eca6f1a5bbf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts @@ -1,4 +1,3 @@ -import type { AppSocket } from 'app/hooks/useSocketIO'; import { logger } from 'app/logging/logger'; import type { AppStore } from 'app/store/store'; import type { SerializableObject } from 'common/types'; @@ -31,6 +30,7 @@ import Konva from 'konva'; import type { Atom } from 'nanostores'; import { computed } from 'nanostores'; import type { Logger } from 'roarr'; +import type { AppSocket } from 'services/events/types'; import { assert } from 'tsafe'; import { CanvasBackgroundModule } from './CanvasBackgroundModule'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImageModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImageModule.ts index 79106853cc0..efd3069dffa 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImageModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImageModule.ts @@ -4,7 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase' import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util'; import { selectShowProgressOnCanvas } from 'features/controlLayers/store/canvasSettingsSlice'; import Konva from 'konva'; +import { atom } from 'nanostores'; import type { Logger } from 'roarr'; +import { selectCanvasQueueCounts } from 'services/api/endpoints/queue'; +import type { S } from 'services/api/types'; export class CanvasProgressImageModule extends CanvasModuleBase { readonly type = 'progress_image'; @@ -23,7 +26,8 @@ export class CanvasProgressImageModule extends CanvasModuleBase { imageElement: HTMLImageElement | null = null; subscriptions = new Set<() => void>(); - + $lastProgressEvent = atom(null); + hasActiveGeneration: boolean = false; mutex: Mutex = new Mutex(); constructor(manager: CanvasManager) { @@ -41,11 +45,50 @@ export class CanvasProgressImageModule extends CanvasModuleBase { image: null, }; - this.subscriptions.add(this.manager.stateApi.$lastCanvasProgressEvent.listen(this.render)); this.subscriptions.add(this.manager.stagingArea.$shouldShowStagedImage.listen(this.render)); this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectShowProgressOnCanvas, this.render)); + this.subscriptions.add(this.setSocketEventListeners()); + this.subscriptions.add( + this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => { + if (data && (data.in_progress > 0 || data.pending > 0)) { + this.hasActiveGeneration = true; + } else { + this.hasActiveGeneration = false; + this.$lastProgressEvent.set(null); + } + }) + ); + this.subscriptions.add(this.$lastProgressEvent.listen(this.render)); } + setSocketEventListeners = (): (() => void) => { + const progressListener = (data: S['InvocationDenoiseProgressEvent']) => { + if (data.destination !== 'canvas') { + return; + } + if (!this.hasActiveGeneration) { + return; + } + this.$lastProgressEvent.set(data); + }; + + const clearProgress = () => { + this.$lastProgressEvent.set(null); + }; + + this.manager.socket.on('invocation_denoise_progress', progressListener); + this.manager.socket.on('connect', clearProgress); + this.manager.socket.on('connect_error', clearProgress); + this.manager.socket.on('disconnect', clearProgress); + + return () => { + this.manager.socket.off('invocation_denoise_progress', progressListener); + this.manager.socket.off('connect', clearProgress); + this.manager.socket.off('connect_error', clearProgress); + this.manager.socket.off('disconnect', clearProgress); + }; + }; + getNodes = () => { return [this.konva.group]; }; @@ -53,7 +96,7 @@ export class CanvasProgressImageModule extends CanvasModuleBase { render = async () => { const release = await this.mutex.acquire(); - const event = this.manager.stateApi.$lastCanvasProgressEvent.get(); + const event = this.$lastProgressEvent.get(); const showProgressOnCanvas = this.manager.stateApi.runSelector(selectShowProgressOnCanvas); if (!event || !showProgressOnCanvas) { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingAreaModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingAreaModule.ts index cbef8a20670..703d63c9a46 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingAreaModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingAreaModule.ts @@ -96,7 +96,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase { if (!this.image.isLoading && !this.image.isError) { await this.image.update({ ...this.image.state, image: imageDTOToImageWithDims(imageDTO) }, true); - this.manager.stateApi.$lastCanvasProgressEvent.set(null); + this.manager.progressImage.$lastProgressEvent.set(null); } this.image.konva.group.visible(shouldShowStagedImage); } else { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 30b2ab5f5a6..5e7cbbad0f4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -15,7 +15,6 @@ import { settingsEraserWidthChanged, } from 'features/controlLayers/store/canvasSettingsSlice'; import { - $lastCanvasProgressEvent, bboxChangedFromCanvas, entityBrushLineAdded, entityEraserLineAdded, @@ -382,12 +381,6 @@ export class CanvasStateApiModule extends CanvasModuleBase { */ $isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter)); - /** - * The last canvas progress event. This is set in a global event listener. The staging area may set it to null when it - * consumes the event. - */ - $lastCanvasProgressEvent = $lastCanvasProgressEvent; - /** * Whether the space key is currently pressed. */ diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index eb90f8d1d72..1106dc21aa7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -30,13 +30,7 @@ import type { IRect } from 'konva/lib/types'; import { merge, omit } from 'lodash-es'; import { atom } from 'nanostores'; import type { UndoableOptions } from 'redux-undo'; -import type { - ControlNetModelConfig, - ImageDTO, - IPAdapterModelConfig, - S, - T2IAdapterModelConfig, -} from 'services/api/types'; +import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import type { @@ -1236,7 +1230,6 @@ function actionsThrottlingFilter(action: UnknownAction) { return true; } -export const $lastCanvasProgressEvent = atom(null); /** * The global canvas manager instance. */ diff --git a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx index a4408e0a5fa..ff790f82c69 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx +++ b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx @@ -1,12 +1,12 @@ import type { IconButtonProps } from '@invoke-ai/ui-library'; import { IconButton } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppSelector } from 'app/store/storeHooks'; import { selectSelectionCount } from 'features/gallery/store/gallerySelectors'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; +import { $isConnected } from 'services/events/stores'; type DeleteImageButtonProps = Omit & { onClick: () => void; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index cfa3a3a256a..82a3d041141 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,7 +1,6 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes'; @@ -30,7 +29,7 @@ import { PiRulerBold, } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { $progressImage } from 'services/events/setEventListeners'; +import { $isConnected, $progressImage } from 'services/events/stores'; const CurrentImageButtons = () => { const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index 60c4db26e86..e0875e41971 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -15,7 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiImageBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { $hasProgress, $isProgressFromCanvas } from 'services/events/setEventListeners'; +import { $hasProgress, $isProgressFromCanvas } from 'services/events/stores'; import ProgressImage from './ProgressImage'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx index 2b8013f22ea..d876ada4156 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx @@ -5,7 +5,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectSystemSlice } from 'features/system/store/systemSlice'; import { memo, useMemo } from 'react'; -import { $isProgressFromCanvas, $progressImage } from 'services/events/setEventListeners'; +import { $isProgressFromCanvas, $progressImage } from 'services/events/stores'; const selectShouldAntialiasProgressImage = createSelector( selectSystemSlice, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx index 0363726bd11..555bcbf1230 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/CurrentImage/CurrentImageNode.tsx @@ -13,7 +13,7 @@ import type { CSSProperties, PropsWithChildren } from 'react'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import type { NodeProps } from 'reactflow'; -import { $lastProgressEvent } from 'services/events/setEventListeners'; +import { $lastProgressEvent } from 'services/events/stores'; const CurrentImageNode = (props: NodeProps) => { const imageDTO = useAppSelector(selectLastSelectedImage); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx index cd77eace774..b6c264abf11 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx @@ -1,7 +1,6 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; @@ -13,6 +12,7 @@ import { useTranslation } from 'react-i18next'; import { PiArrowCounterClockwiseBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import type { PostUploadAction } from 'services/api/types'; +import { $isConnected } from 'services/events/stores'; import type { FieldComponentProps } from './types'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index a8aeed3850b..988a7492525 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -11,7 +11,7 @@ export const prepareLinearUIBatch = ( prepend: boolean, noise: Invocation<'noise' | 'flux_denoise'>, posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, - origin: 'generation' | 'workflows' | 'upscaling', + origin: 'canvas' | 'workflows' | 'upscaling', destination: 'canvas' | 'gallery' ): BatchConfig => { const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params; diff --git a/invokeai/frontend/web/src/features/queue/components/ClearQueueConfirmationAlertDialog.tsx b/invokeai/frontend/web/src/features/queue/components/ClearQueueConfirmationAlertDialog.tsx index 1fbd8053885..00dcf92ad44 100644 --- a/invokeai/frontend/web/src/features/queue/components/ClearQueueConfirmationAlertDialog.tsx +++ b/invokeai/frontend/web/src/features/queue/components/ClearQueueConfirmationAlertDialog.tsx @@ -1,6 +1,5 @@ import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch } from 'app/store/storeHooks'; import { buildUseBoolean } from 'common/hooks/useBoolean'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; @@ -8,6 +7,7 @@ import { toast } from 'features/toast/toast'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts index d9ad1a736f5..92d0cbb5a6d 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const useCancelBatch = (batch_id: string) => { const isConnected = useStore($isConnected); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts index 9ae8e2dd2e7..12173aa7aa6 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts @@ -1,10 +1,10 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const useCancelCurrentQueueItem = () => { const isConnected = useStore($isConnected); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts index bf0af41605a..4ac87a1aad7 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const useCancelQueueItem = (item_id: number) => { const isConnected = useStore($isConnected); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts index d177a72f5f4..19ab3cf45fc 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; +import { $isConnected } from 'services/events/stores'; export const useClearInvocationCache = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts index cf71e4bd4ba..703cf8d4cb3 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; +import { $isConnected } from 'services/events/stores'; export const useDisableInvocationCache = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts index 7f28bddd787..2589d507170 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; +import { $isConnected } from 'services/events/stores'; export const useEnableInvocationCache = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts index d25c8051e50..d4712ad2b87 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const usePauseProcessor = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts index f9426291bea..09e77e23d6a 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts @@ -1,11 +1,11 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const usePruneQueue = () => { const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts index 72d787103b3..058a3b2b3e3 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts @@ -1,9 +1,9 @@ import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue'; +import { $isConnected } from 'services/events/stores'; export const useResumeProcessor = () => { const isConnected = useStore($isConnected); diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 6cc834c2f89..dd09e402a44 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -1,11 +1,10 @@ import { Progress } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { useCurrentDestination } from 'features/queue/hooks/useCurrentDestination'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; -import { $lastProgressEvent } from 'services/events/setEventListeners'; +import { $isConnected, $lastProgressEvent } from 'services/events/stores'; const ProgressBar = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/system/components/StatusIndicator.tsx b/invokeai/frontend/web/src/features/system/components/StatusIndicator.tsx index c0c21331890..a02b9199b19 100644 --- a/invokeai/frontend/web/src/features/system/components/StatusIndicator.tsx +++ b/invokeai/frontend/web/src/features/system/components/StatusIndicator.tsx @@ -1,9 +1,9 @@ import { Icon, Tooltip } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { $isConnected } from 'app/hooks/useSocketIO'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiWarningBold } from 'react-icons/pi'; +import { $isConnected } from 'services/events/stores'; const StatusIndicator = () => { const isConnected = useStore($isConnected); diff --git a/invokeai/frontend/web/src/services/api/endpoints/queue.ts b/invokeai/frontend/web/src/services/api/endpoints/queue.ts index 7c0271e0437..de4a0614548 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/queue.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/queue.ts @@ -78,15 +78,17 @@ export const queueApi = api.injectEndpoints({ resetListQueryData(dispatch); /** * When a batch is enqueued, we need to update the queue status. While it might be templting to invalidate the - * `SessionQueueStatus` tag here, this can introduce a race condition: + * `SessionQueueStatus` tag here, this can introduce a race condition when the queue item executes quickly: * - * - Enqueue batch via this query + * - Enqueue via this query * - On success, we invalidate `SessionQueueStatus` tag - network request sent to server - * - Network request received, response preparing/sending - * - A queue item status changes and we receive a socket event w/ updated status - * - Update status optimistically in socket handler - * - Tag invalidation response received, but by now its payload has stale data - * - Stale data is written to the cache + * - The server gets the queue status request and responds, but this takes some time... in the meantime: + * - The new queue item starts executing, and we receive a socket queue item status changed event + * - We optimistically update the queue status in the queue item status changed socket handler + * - At this point, the queue status is correct + * - Finally, we get the queue status from the tag invalidation request - but it's reporting the queue status + * from _before_ the last queue event + * - The queue status is now incorrect! * * Ok, what if we just never did optimistic updates and invalidated the tag in the queue event handlers instead? * It's much simpler that way, but it causes a lot of network requests - 3 per queue item, as it moves from @@ -94,7 +96,18 @@ export const queueApi = api.injectEndpoints({ * * We can do a bit of extra work here, incrementing the pending and total counts in the queue status, and do * similar optimistic updates in the socket handler. Because this optimistic update runs immediately after the - * enqueue network request, it should always occur _before_ the next queue event, so no race condition. + * enqueue network request, it should always occur _before_ the next queue event, so no race condition: + * + * - Enqueue batch via this query + * - On success, optimistically update - this happens immediately on the HTTP OK - before the next queue event + * - At this point, the queue status is correct + * - A queue item status changes and we receive a socket event w/ updated status + * - Update status optimistically in socket handler + * - Queue status is still correct + * + * This problem occurs most commonly with canvas filters like Canny edge detection, which are single-node + * graphs that execute very quickly. Image generation graphs take long enough to not trigger this race + * condition - even when all nodes are cached on the server. */ dispatch( queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => { diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 32720911d5d..c1a40934c44 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -17,13 +17,9 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => { return data.invocation_source_id.split(':')[0] === 'canvas_output'; }; -export const buildOnInvocationComplete = ( - getState: () => RootState, - dispatch: AppDispatch, - nodeTypeDenylist: string[], - setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void, - setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void -) => { +const nodeTypeDenylist = ['load_image', 'image']; + +export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => { const addImageToGallery = (imageDTO: ImageDTO) => { if (imageDTO.is_intermediate) { return; @@ -113,7 +109,7 @@ export const buildOnInvocationComplete = ( } }; - const handleOriginGeneration = async (data: S['InvocationCompleteEvent']) => { + const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => { const imageDTO = await getResultImageDTO(data); if (!imageDTO) { @@ -121,6 +117,7 @@ export const buildOnInvocationComplete = ( } if (data.destination === 'canvas') { + // TODO(psyche): Can/should we let canvas handle this itself? if (isCanvasOutputNode(data)) { if (data.result.type === 'canvas_v2_mask_and_crop_output') { const { offset_x, offset_y } = data.result; @@ -131,8 +128,7 @@ export const buildOnInvocationComplete = ( addImageToGallery(imageDTO); } } else if (!imageDTO.is_intermediate) { - // session.mode === 'generate' - setLastCanvasProgressEvent(null); + // Desintaion is gallery addImageToGallery(imageDTO); } }; @@ -151,15 +147,17 @@ export const buildOnInvocationComplete = ( `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})` ); - // Update the node execution states - the image output is handled below + if (nodeTypeDenylist.includes(data.invocation.type)) { + log.trace('Skipping node type denylisted'); + return; + } + if (data.origin === 'workflows') { await handleOriginWorkflows(data); - } else if (data.origin === 'generation') { - await handleOriginGeneration(data); + } else if (data.origin === 'canvas') { + await handleOriginCanvas(data); } else { await handleOriginOther(data); } - - setLastProgressEvent(null); }; }; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index f99365da967..8a2f645bf91 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -1,48 +1,40 @@ import { ExternalLink } from '@invoke-ai/ui-library'; -import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; -import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners'; +import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; import { $queueId } from 'app/store/nanostores/queueId'; -import type { AppDispatch, RootState } from 'app/store/store'; +import type { AppStore } from 'app/store/store'; import type { SerializableObject } from 'common/types'; import { deepClone } from 'common/util/deepClone'; -import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; -import { atom, computed } from 'nanostores'; import { api, LIST_TAG } from 'services/api'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; -import type { S } from 'services/api/types'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { Socket } from 'socket.io-client'; -export const socketConnected = createAction('socket/connected'); +import { $lastProgressEvent } from './stores'; const log = logger('events'); type SetEventListenersArg = { socket: Socket; - dispatch: AppDispatch; - getState: () => RootState; + store: AppStore; setIsConnected: (isConnected: boolean) => void; }; const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); -const nodeTypeDenylist = ['load_image', 'image']; -export const $lastProgressEvent = atom(null); -export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); -export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); -export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas'); -export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => { +export const setEventListeners = ({ socket, store, setIsConnected }: SetEventListenersArg) => { + const { dispatch, getState } = store; + socket.on('connect', () => { log.debug('Connected'); setIsConnected(true); @@ -54,14 +46,12 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } socket.emit('subscribe_bulk_download', { bulk_download_id }); } $lastProgressEvent.set(null); - $lastCanvasProgressEvent.set(null); }); socket.on('connect_error', (error) => { log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); - $lastCanvasProgressEvent.set(null); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; if (data === 'ERR_UNAUTHENTICATED') { @@ -78,7 +68,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } socket.on('disconnect', () => { log.debug('Disconnected'); $lastProgressEvent.set(null); - $lastCanvasProgressEvent.set(null); setIsConnected(false); }); @@ -93,24 +82,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } }); socket.on('invocation_denoise_progress', (data) => { - const { - invocation_source_id, - invocation, - step, - total_steps, - progress_image, - origin, - destination, - percentage, - session_id, - batch_id, - } = data; - - if (getIsCancelled({ session_id, batch_id, destination })) { - // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a - // progress update after the session has been cancelled. - return; - } + const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage } = data; log.trace( { data } as SerializableObject, @@ -128,11 +100,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } upsertExecutionState(nes.nodeId, nes); } } - - // This event is only relevant for the canvas - if (destination === 'canvas') { - $lastCanvasProgressEvent.set(data); - } }); socket.on('invocation_error', (data) => { @@ -152,13 +119,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } } }); - const onInvocationComplete = buildOnInvocationComplete( - getState, - dispatch, - nodeTypeDenylist, - $lastProgressEvent.set, - $lastCanvasProgressEvent.set - ); + const onInvocationComplete = buildOnInvocationComplete(getState, dispatch); socket.on('invocation_complete', onInvocationComplete); socket.on('model_load_started', (data) => { @@ -379,7 +340,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } error_type, error_message, error_traceback, - origin, } = data; log.debug({ data }, `Queue item ${item_id} status updated: ${status}`); @@ -402,12 +362,17 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } }) ); - // Update the queue status (we do not get the processor status here) + // Optimistic update of the queue status. We prefer to do an optimistic update over tag invalidation due to the + // frequency of `queue_item_status_changed` events. dispatch( queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => { if (!draft) { return; } + /** + * Update the queue status - though the getQueueStatus query response contains the processor status (i.e. running + * or paused), that data is not provided in the event we are handling. So we can only update `draft.queue` here. + */ Object.assign(draft.queue, queue_status); }) ); @@ -442,11 +407,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } } else if (status === 'failed' && error_type) { const isLocal = getState().config.isLocal ?? true; const sessionId = session_id; - $lastProgressEvent.set(null); - - if (origin === 'canvas') { - $lastCanvasProgressEvent.set(null); - } toast({ id: `INVOCATION_ERROR_${error_type}`, @@ -463,13 +423,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } /> ), }); - } else if (status === 'canceled') { - $lastProgressEvent.set(null); - if (origin === 'canvas') { - $lastCanvasProgressEvent.set(null); - } - } else if (status === 'completed') { - $lastProgressEvent.set(null); } }); diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts new file mode 100644 index 00000000000..9144cb957ef --- /dev/null +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -0,0 +1,12 @@ +import { atom, computed, map } from 'nanostores'; +import type { S } from 'services/api/types'; +import type { AppSocket } from 'services/events/types'; +import type { ManagerOptions, SocketOptions } from 'socket.io-client'; + +export const $socket = atom(null); +export const $socketOptions = map>({}); +export const $isConnected = atom(false); +export const $lastProgressEvent = atom(null); +export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); +export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); +export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 4d11490a8a6..714d94eb77f 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,4 +1,5 @@ import type { S } from 'services/api/types'; +import type { Socket } from 'socket.io-client'; type ClientEmitSubscribeQueue = { queue_id: string }; type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue; @@ -40,3 +41,5 @@ export type ClientToServerEvents = { subscribe_bulk_download: (payload: ClientEmitSubscribeBulkDownload) => void; unsubscribe_bulk_download: (payload: ClientEmitUnsubscribeBulkDownload) => void; }; + +export type AppSocket = Socket; diff --git a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts b/invokeai/frontend/web/src/services/events/useSocketIO.ts similarity index 74% rename from invokeai/frontend/web/src/app/hooks/useSocketIO.ts rename to invokeai/frontend/web/src/services/events/useSocketIO.ts index 2b39f24cc50..9ba33695647 100644 --- a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts +++ b/invokeai/frontend/web/src/services/events/useSocketIO.ts @@ -3,14 +3,17 @@ import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $isDebugging } from 'app/store/nanostores/isDebugging'; import { useAppStore } from 'app/store/nanostores/store'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; import type { MapStore } from 'nanostores'; -import { atom, map } from 'nanostores'; import { useEffect, useMemo } from 'react'; +import { selectQueueStatus } from 'services/api/endpoints/queue'; import { setEventListeners } from 'services/events/setEventListeners'; -import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; +import type { AppSocket } from 'services/events/types'; +import type { ManagerOptions, SocketOptions } from 'socket.io-client'; import { io } from 'socket.io-client'; +import { $isConnected, $lastProgressEvent, $socket, $socketOptions } from './stores'; + // Inject socket options and url into window for debugging declare global { interface Window { @@ -18,19 +21,12 @@ declare global { } } -export type AppSocket = Socket; - -export const $socket = atom(null); -export const $socketOptions = map>({}); - -const $isSocketInitialized = atom(false); -export const $isConnected = atom(false); - /** * Initializes the socket.io connection and sets up event listeners. */ export const useSocketIO = () => { - const { dispatch, getState } = useAppStore(); + useAssertSingleton('useSocketIO'); + const store = useAppStore(); const baseUrl = useStore($baseUrl); const authToken = useStore($authToken); const addlSocketOptions = useStore($socketOptions); @@ -61,14 +57,11 @@ export const useSocketIO = () => { }, [authToken, addlSocketOptions, baseUrl]); useEffect(() => { - if ($isSocketInitialized.get()) { - // Singleton! - return; - } - const socket: AppSocket = io(socketUrl, socketOptions); $socket.set(socket); - setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set }); + + setEventListeners({ socket, store, setIsConnected: $isConnected.set }); + socket.connect(); if ($isDebugging.get() || import.meta.env.MODE === 'development') { @@ -78,7 +71,12 @@ export const useSocketIO = () => { console.log('Socket initialized', socket); } - $isSocketInitialized.set(true); + const unsubscribeQueueStatusListener = store.subscribe(() => { + const queueStatusData = selectQueueStatus(store.getState()).data; + if (!queueStatusData || queueStatusData.queue.in_progress === 0) { + $lastProgressEvent.set(null); + } + }); return () => { if ($isDebugging.get() || import.meta.env.MODE === 'development') { @@ -87,8 +85,8 @@ export const useSocketIO = () => { /* eslint-disable-next-line no-console */ console.log('Socket teardown', socket); } + unsubscribeQueueStatusListener(); socket.disconnect(); - $isSocketInitialized.set(false); }; - }, [dispatch, getState, socketOptions, socketUrl]); + }, [socketOptions, socketUrl, store]); };