Skip to content

Commit

Permalink
feat(ui): rework progress event handling
Browse files Browse the repository at this point in the history
- Canvas manages its own progress socket event listeners and progress event data.
- Remove cancellations listener jank.
- Dip into low-level redux subscription API to watch for queue status changes, clearing the last "global" progress event when the queue has nothing in progress. Could also do this in a useEffect I guess.
- Had to shuffle some things around to prevent circular imports, so there are a lot of tiny changes here.
  • Loading branch information
psychedelicious authored and hipsterusername committed Sep 18, 2024
1 parent b08a66e commit 7db4d26
Show file tree
Hide file tree
Showing 40 changed files with 161 additions and 305 deletions.
2 changes: 1 addition & 1 deletion invokeai/frontend/web/src/app/components/App.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { useSocketIO } from 'app/hooks/useSocketIO';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
Expand Down Expand Up @@ -31,6 +30,7 @@ import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { useSocketIO } from 'services/events/useSocketIO';

import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage';
Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import 'i18n';

import type { Middleware } from '@reduxjs/toolkit';
import { $socketOptions } from 'app/hooks/useSocketIO';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
Expand All @@ -24,6 +23,7 @@ import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';

const App = lazy(() => import('./App'));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
Expand Down Expand Up @@ -73,15 +72,6 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);

// Canvas actions
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);

// Socket.IO
Expand Down Expand Up @@ -121,6 +111,3 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);

addSetDefaultSettingsListener(startAppListening);
// addControlAdapterPreprocessor(startAppListening);

addCancellationsListeners(startAppListening);

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
);

if (prepareBatchResult.isErr()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/setEventListeners';

import { socketConnected } from './socketConnected';

const matcher = isAnyOf(
positivePromptChanged,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
Expand All @@ -6,11 +7,11 @@ import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/setEventListeners';

const log = logger('events');

const $isFirstConnection = atom(true);
export const socketConnected = createAction('socket/connected');

export const addSocketConnectedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $true } from 'app/store/nanostores/util';
import { useAppSelector } from 'app/store/storeHooks';
Expand All @@ -21,6 +20,7 @@ import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
import { $isConnected } from 'services/events/stores';

const LAYER_TYPE_TO_TKEY = {
reference_image: 'controlLayers.referenceImage',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
Expand All @@ -16,6 +15,7 @@ import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import { $isConnected } from 'services/events/stores';

type Props = {
image: ImageWithDims | null;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { useStore } from '@nanostores/react';
import { $socket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/nanostores/store';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
import Konva from 'konva';
import { useLayoutEffect, useState } from 'react';
import { $socket } from 'services/events/stores';
import { useDevicePixelRatio } from 'use-device-pixel-ratio';

const log = logger('canvas');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { AppSocket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import type { SerializableObject } from 'common/types';
Expand Down Expand Up @@ -31,6 +30,7 @@ import Konva from 'konva';
import type { Atom } from 'nanostores';
import { computed } from 'nanostores';
import type { Logger } from 'roarr';
import type { AppSocket } from 'services/events/types';
import { assert } from 'tsafe';

import { CanvasBackgroundModule } from './CanvasBackgroundModule';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
import { selectShowProgressOnCanvas } from 'features/controlLayers/store/canvasSettingsSlice';
import Konva from 'konva';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';

export class CanvasProgressImageModule extends CanvasModuleBase {
readonly type = 'progress_image';
Expand All @@ -23,7 +26,8 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
imageElement: HTMLImageElement | null = null;

subscriptions = new Set<() => void>();

$lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
hasActiveGeneration: boolean = false;
mutex: Mutex = new Mutex();

constructor(manager: CanvasManager) {
Expand All @@ -41,19 +45,58 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
image: null,
};

this.subscriptions.add(this.manager.stateApi.$lastCanvasProgressEvent.listen(this.render));
this.subscriptions.add(this.manager.stagingArea.$shouldShowStagedImage.listen(this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectShowProgressOnCanvas, this.render));
this.subscriptions.add(this.setSocketEventListeners());
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
if (data && (data.in_progress > 0 || data.pending > 0)) {
this.hasActiveGeneration = true;
} else {
this.hasActiveGeneration = false;
this.$lastProgressEvent.set(null);
}
})
);
this.subscriptions.add(this.$lastProgressEvent.listen(this.render));
}

setSocketEventListeners = (): (() => void) => {
const progressListener = (data: S['InvocationDenoiseProgressEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (!this.hasActiveGeneration) {
return;
}
this.$lastProgressEvent.set(data);
};

const clearProgress = () => {
this.$lastProgressEvent.set(null);
};

this.manager.socket.on('invocation_denoise_progress', progressListener);
this.manager.socket.on('connect', clearProgress);
this.manager.socket.on('connect_error', clearProgress);
this.manager.socket.on('disconnect', clearProgress);

return () => {
this.manager.socket.off('invocation_denoise_progress', progressListener);
this.manager.socket.off('connect', clearProgress);
this.manager.socket.off('connect_error', clearProgress);
this.manager.socket.off('disconnect', clearProgress);
};
};

getNodes = () => {
return [this.konva.group];
};

render = async () => {
const release = await this.mutex.acquire();

const event = this.manager.stateApi.$lastCanvasProgressEvent.get();
const event = this.$lastProgressEvent.get();
const showProgressOnCanvas = this.manager.stateApi.runSelector(selectShowProgressOnCanvas);

if (!event || !showProgressOnCanvas) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {

if (!this.image.isLoading && !this.image.isError) {
await this.image.update({ ...this.image.state, image: imageDTOToImageWithDims(imageDTO) }, true);
this.manager.stateApi.$lastCanvasProgressEvent.set(null);
this.manager.progressImage.$lastProgressEvent.set(null);
}
this.image.konva.group.visible(shouldShowStagedImage);
} else {
Expand Down
Loading

0 comments on commit 7db4d26

Please sign in to comment.