Skip to content

Commit

Permalink
send one imageDTO back with complete event so the UI can refresh the …
Browse files Browse the repository at this point in the history
…correct data
  • Loading branch information
Mary Hipp authored and Mary Hipp committed Oct 22, 2024
1 parent 90ad720 commit e6f80ca
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 10 deletions.
5 changes: 3 additions & 2 deletions invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.app.services.images.images_common import ImageDTO

if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
Expand Down Expand Up @@ -214,9 +215,9 @@ def emit_bulk_upload_progress(self, bulk_upload_id: str, completed: int, total:
"""Emitted when a bulk image upload is started"""
self.dispatch(BulkUploadProgressEvent.build(bulk_upload_id, completed, total))

def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int) -> None:
def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> None:
"""Emitted when a bulk image upload is complete"""
self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total))
self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total, image_DTO=image_DTO))

def emit_bulk_upload_error(self, bulk_upload_id: str, error: str) -> None:
"""Emitted when a bulk image upload has an error"""
Expand Down
10 changes: 4 additions & 6 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Expand Down Expand Up @@ -656,14 +657,11 @@ class BulkUploadCompletedEvent(BulkUploadEventBase):
__event_name__ = "bulk_upload_completed"

total: int = Field(description="The total numberof images")
image_DTO: ImageDTO = Field(description="An image from the upload so client can refetch correctly")

@classmethod
def build(
cls,
bulk_upload_id: str,
total: int,
) -> "BulkUploadCompletedEvent":
return cls(bulk_upload_id=bulk_upload_id, total=total)
def build(cls, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> "BulkUploadCompletedEvent":
return cls(bulk_upload_id=bulk_upload_id, total=total, image_DTO=image_DTO)


@payload_schema.register
Expand Down
4 changes: 3 additions & 1 deletion invokeai/app/services/images/images_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def process_and_save_image(image_data: ImageBulkUploadData):
self.__invoker.services.events.emit_bulk_upload_error(bulk_upload_id=bulk_upload_id, error=str(e))

pbar.close()
self.__invoker.services.events.emit_bulk_upload_complete(bulk_upload_id=bulk_upload_id, total=len(images_DTOs))
self.__invoker.services.events.emit_bulk_upload_complete(
bulk_upload_id=bulk_upload_id, total=len(images_DTOs), image_DTO=images_DTOs[0]
)

def update(
self,
Expand Down
2 changes: 2 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2600,6 +2600,8 @@ export type components = {
* @description The total numberof images
*/
total: number;
/** @description An image from the upload so client can refetch correctly */
image_DTO: components["schemas"]["ImageDTO"];
};
/**
* BulkUploadErrorEvent
Expand Down
25 changes: 24 additions & 1 deletion invokeai/frontend/web/src/services/events/setEventListeners.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import { t } from 'i18next';
import { forEach, isNil, round } from 'lodash-es';
import type { ApiTagDescription } from 'services/api';
import { api, LIST_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
Expand Down Expand Up @@ -529,7 +532,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis

socket.on('bulk_upload_completed', (data) => {
log.debug({ data }, 'Bulk gallery upload ready');
const { total } = data;
const { total, image_DTO: imageDTO } = data;

toast({
id: 'BULK_UPLOAD',
Expand All @@ -545,6 +548,26 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
),
duration: null,
});

// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
draft.total += 1;
})
);

dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
});

socket.on('bulk_upload_error', (data) => {
Expand Down

0 comments on commit e6f80ca

Please sign in to comment.