diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index a5ac2fafa97..e7311b0cba1 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -34,6 +34,7 @@ QueueClearedEvent, QueueItemStatusChangedEvent, ) +from invokeai.app.services.images.images_common import ImageDTO if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -214,9 +215,9 @@ def emit_bulk_upload_progress(self, bulk_upload_id: str, completed: int, total: """Emitted when a bulk image upload is started""" self.dispatch(BulkUploadProgressEvent.build(bulk_upload_id, completed, total)) - def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int) -> None: + def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> None: """Emitted when a bulk image upload is complete""" - self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total)) + self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total, image_DTO=image_DTO)) def emit_bulk_upload_error(self, bulk_upload_id: str, error: str) -> None: """Emitted when a bulk image upload has an error""" diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 1b4e0748f86..b02cc2625e7 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -4,6 +4,7 @@ from fastapi_events.registry.payload_schema import registry as payload_schema from pydantic import BaseModel, ConfigDict, Field +from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, @@ -656,14 +657,11 @@ class BulkUploadCompletedEvent(BulkUploadEventBase): __event_name__ = "bulk_upload_completed" total: int = Field(description="The total numberof images") + image_DTO: ImageDTO = Field(description="An image from the upload so client can refetch correctly") @classmethod - def build( - cls, - bulk_upload_id: str, - total: int, - ) -> "BulkUploadCompletedEvent": - return cls(bulk_upload_id=bulk_upload_id, total=total) + def build(cls, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> "BulkUploadCompletedEvent": + return cls(bulk_upload_id=bulk_upload_id, total=total, image_DTO=image_DTO) @payload_schema.register diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 0b122e94202..c7d278416df 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -186,7 +186,9 @@ def process_and_save_image(image_data: ImageBulkUploadData): self.__invoker.services.events.emit_bulk_upload_error(bulk_upload_id=bulk_upload_id, error=str(e)) pbar.close() - self.__invoker.services.events.emit_bulk_upload_complete(bulk_upload_id=bulk_upload_id, total=len(images_DTOs)) + self.__invoker.services.events.emit_bulk_upload_complete( + bulk_upload_id=bulk_upload_id, total=len(images_DTOs), image_DTO=images_DTOs[0] + ) def update( self, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 1dbc015d85b..95a2d01a533 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2600,6 +2600,8 @@ export type components = { * @description The total numberof images */ total: number; + /** @description An image from the upload so client can refetch correctly */ + image_DTO: components["schemas"]["ImageDTO"]; }; /** * BulkUploadErrorEvent diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 93c686d9a36..83a1e731729 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -15,8 +15,11 @@ import { t } from 'i18next'; import { forEach, isNil, round } from 'lodash-es'; import type { ApiTagDescription } from 'services/api'; import { api, LIST_TAG } from 'services/api'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; +import { getCategories, getListImagesUrl } from 'services/api/util'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { Socket } from 'socket.io-client'; @@ -529,7 +532,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('bulk_upload_completed', (data) => { log.debug({ data }, 'Bulk gallery upload ready'); - const { total } = data; + const { total, image_DTO: imageDTO } = data; toast({ id: 'BULK_UPLOAD', @@ -545,6 +548,26 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis ), duration: null, }); + + // update the total images for the board + dispatch( + boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => { + draft.total += 1; + }) + ); + + dispatch( + imagesApi.util.invalidateTags([ + { type: 'Board', id: imageDTO.board_id ?? 'none' }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: imageDTO.board_id ?? 'none', + categories: getCategories(imageDTO), + }), + }, + ]) + ); }); socket.on('bulk_upload_error', (data) => {