Skip to content

Commit

Permalink
feat(trpc/subscription): change to use generator
Browse files Browse the repository at this point in the history
tctien342 committed Jan 15, 2025
1 parent 3597cbf commit 1dc8a2b
Showing 6 changed files with 354 additions and 372 deletions.
6 changes: 4 additions & 2 deletions components/dialogs/AddWorkflowDialog/steps/Finalize.tsx
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ import { dispatchGlobalEvent, EGlobalEvent } from '@/hooks/useGlobalEvent'
const SelectionSchema = z.nativeEnum(EValueSelectionType)

export const FinalizeStep: IComponent = () => {
const idRef = useRef(Math.random().toString(36).substring(7))
const [loading, setLoading] = useState(false)
const [progressEv, setProgressEv] = useState<TWorkflowProgressMessage>()
const [previewBlob, setPreviewBlob] = useState<Blob>()
@@ -44,15 +45,14 @@ export const FinalizeStep: IComponent = () => {

const isEnd = progressEv?.key === 'finished' || progressEv?.key === 'failed'

trpc.workflow.testWorkflow.useSubscription(undefined, {
trpc.workflow.testWorkflow.useSubscription(idRef.current, {
onData: (ev) => {
if (ev.key === 'preview') {
const base64 = ev.data.blob64
const blob = base64 ? new Blob([Buffer.from(base64, 'base64')]) : undefined
setPreviewBlob(blob)
return
}

setProgressEv(ev)
if (ev.key === 'failed' || ev.key === 'finished') {
setLoading(false)
@@ -69,6 +69,7 @@ export const FinalizeStep: IComponent = () => {

const handlePressTest = async () => {
if (!workflow) return
setProgressEv(undefined)
const wfObj = cloneDeep(inputWorkflowTest.current)
// Check if there are files to upload
const inputKeys = Object.keys(workflow?.mapInput || {})
@@ -87,6 +88,7 @@ export const FinalizeStep: IComponent = () => {
}
setLoading(true)
mutateAsync({
id: idRef.current,
workflow,
input: wfObj
})
311 changes: 151 additions & 160 deletions server/routers/client.ts
Original file line number Diff line number Diff line change
@@ -46,15 +46,12 @@ export const clientRouter = router({
await cacher.set('CLIENT_STATUS', input, EClientStatus.Offline)
return true
}),
monitorSystem: adminProcedure.input(z.string()).subscription(async ({ input, ctx }) => {
return observable<TMonitorEvent>((subscriber) => {
const off = cacher.on('SYSTEM_MONITOR', input, (ev) => {
subscriber.next(ev.detail)
})
return off
})
monitorSystem: adminProcedure.input(z.string()).subscription(async function* ({ input, signal }) {
for await (const data of cacher.onGenerator('SYSTEM_MONITOR', input, signal)) {
yield data.detail
}
}),
monitorStatus: adminProcedure.input(z.string()).subscription(async ({ input, ctx }) => {
monitorStatus: adminProcedure.input(z.string()).subscription(async function* ({ input, ctx, signal }) {
const latestEvent = await ctx.em.findOne(
ClientStatusEvent,
{
@@ -66,14 +63,12 @@ export const clientRouter = router({
)
if (!latestEvent) {
throw new Error('Client not found')
} else {
yield latestEvent.status
}
for await (const data of cacher.onGenerator('CLIENT_STATUS', input, signal)) {
yield data.detail
}
return observable<EClientStatus>((subscriber) => {
subscriber.next(latestEvent.status)
const off = cacher.on('CLIENT_STATUS', input, (ev) => {
subscriber.next(ev.detail)
})
return off
})
}),
control: adminProcedure
.input(
@@ -134,7 +129,7 @@ export const clientRouter = router({
return false
}
}),
overview: adminProcedure.subscription(async ({ ctx }) => {
overview: adminProcedure.subscription(async function* ({ ctx, signal }) {
const cacher = CachingService.getInstance()

const getStatues = async () => {
@@ -146,14 +141,10 @@ export const clientRouter = router({
error: data.filter((e) => e === EClientStatus.Error).length
}
}

return observable<Awaited<ReturnType<typeof getStatues>>>((subscriber) => {
getStatues().then((data) => subscriber.next(data))
const off = cacher.onCategory('CLIENT_STATUS', async (ev) => {
getStatues().then((data) => subscriber.next(data))
})
return off
})
yield await getStatues()
for await (const _ of cacher.onCategoryGenerator('CLIENT_STATUS', signal)) {
yield await getStatues()
}
}),
testNewClient: adminProcedure.input(ClientSchema).mutation(async ({ input, ctx }) => {
if (input.auth && (!input.username || !input.password)) {
@@ -214,153 +205,153 @@ export const clientRouter = router({
}),
addNewClient: adminProcedure
.input(z.intersection(ClientSchema, z.object({ displayName: z.string().optional() })))
.subscription(async ({ input, ctx }) => {
.subscription(async function* ({ input, ctx, signal }) {
if (input.auth && (!input.username || !input.password)) {
throw new Error('Username or password is required')
}
const api = new ComfyApi(input.host, 'test', {
credentials: input.auth ? { type: 'basic', username: input.username!, password: input.password! } : undefined
})
return observable<EImportingClient>((subscriber) => {
const run = async () => {
const test = await api.ping()
if (test.status) {
subscriber.next(EImportingClient.PING_OK)
} else {
subscriber.next(EImportingClient.FAILED)
return
}
let client = await ctx.em.findOne(Client, { host: input.host })
if (!client) {
client = ctx.em.create(Client, {
host: input.host,
auth: input.auth ? EAuthMode.Basic : EAuthMode.None,
username: input.username,
password: input.password,
name: input.displayName ?? input.host
})
await ctx.em.persistAndFlush(client)
}
subscriber.next(EImportingClient.CLIENT_CREATED)

const importCkpt = async () => {
const ckpts = await api.getCheckpoints()
for (const ckpt of ckpts) {
let resource = await ctx.em.findOne(Resource, { name: ckpt, type: EResourceType.Checkpoint })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: ckpt,
type: EResourceType.Checkpoint
},
{ partial: true }
)
}
client.resources.add(resource)
}
await ctx.em.persistAndFlush(client)
subscriber.next(EImportingClient.IMPORTED_CHECKPOINT)
const test = await api.ping()
if (test.status) {
yield EImportingClient.PING_OK
} else {
yield EImportingClient.FAILED
return
}
let client = await ctx.em.findOne(Client, { host: input.host })
if (!client) {
client = ctx.em.create(Client, {
host: input.host,
auth: input.auth ? EAuthMode.Basic : EAuthMode.None,
username: input.username,
password: input.password,
name: input.displayName ?? input.host
})
await ctx.em.persistAndFlush(client)
}
yield EImportingClient.CLIENT_CREATED

const importCkpt = async () => {
const ckpts = await api.getCheckpoints()
for (const ckpt of ckpts) {
let resource = await ctx.em.findOne(Resource, { name: ckpt, type: EResourceType.Checkpoint })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: ckpt,
type: EResourceType.Checkpoint
},
{ partial: true }
)
}
const importLora = async () => {
const loras = await api.getLoras()
for (const lora of loras) {
let resource = await ctx.em.findOne(Resource, { name: lora, type: EResourceType.Lora })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: lora,
type: EResourceType.Lora
},
{ partial: true }
)
}
client.resources.add(resource)
}
await ctx.em.persistAndFlush(client)
subscriber.next(EImportingClient.IMPORTED_LORA)
client.resources.add(resource)
}
await ctx.em.persistAndFlush(client)
return EImportingClient.IMPORTED_CHECKPOINT
}
const importLora = async () => {
const loras = await api.getLoras()
for (const lora of loras) {
let resource = await ctx.em.findOne(Resource, { name: lora, type: EResourceType.Lora })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: lora,
type: EResourceType.Lora
},
{ partial: true }
)
}
const importSamplerScheduler = async () => {
const samplerInfo = await api.getSamplerInfo()
const samplers = samplerInfo.sampler?.[0] as string[]
const schedulers = samplerInfo.scheduler?.[0] as string[]
for (const sampler of samplers) {
let resource = await ctx.em.findOne(Resource, { name: sampler, type: EResourceType.Sampler })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: sampler,
type: EResourceType.Sampler
},
{ partial: true }
)
}
client.resources.add(resource)
}
for (const scheduler of schedulers) {
let resource = await ctx.em.findOne(Resource, { name: scheduler, type: EResourceType.Scheduler })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: scheduler,
type: EResourceType.Scheduler
},
{ partial: true }
)
}
client.resources.add(resource)
}
await ctx.em.persistAndFlush(client)
subscriber.next(EImportingClient.IMPORTED_SAMPLER_SCHEDULER)
client.resources.add(resource)
}
await ctx.em.persistAndFlush(client)
return EImportingClient.IMPORTED_LORA
}
const importSamplerScheduler = async () => {
const samplerInfo = await api.getSamplerInfo()
const samplers = samplerInfo.sampler?.[0] as string[]
const schedulers = samplerInfo.scheduler?.[0] as string[]
for (const sampler of samplers) {
let resource = await ctx.em.findOne(Resource, { name: sampler, type: EResourceType.Sampler })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: sampler,
type: EResourceType.Sampler
},
{ partial: true }
)
}
const importExtension = async () => {
const extensions = (await api.getNodeDefs()) ?? []
const promises = Object.values(extensions).map(async (ext) => {
let resource = await ctx.em.findOne(Extension, { pythonModule: ext.python_module, name: ext.name })
if (!resource) {
resource = ctx.em.create(
Extension,
{
name: ext.name,
displayName: ext.display_name,
pythonModule: ext.python_module,
category: ext.category,
outputNode: ext.output_node,
inputConf: ext.input.required,
description: ext.description,
outputConf: ext.output?.map((o, idx) => ({
name: ext.output_name?.[idx] ?? '',
isList: ext.output_is_list?.[idx] ?? false,
type: o,
tooltip: ext.output_tooltips?.[idx] ?? ''
}))
},
{ partial: true }
)
}
client.extensions.add(resource)
})
await Promise.all(promises)
await ctx.em.persistAndFlush(client)
subscriber.next(EImportingClient.IMPORTED_EXTENSION)
client.resources.add(resource)
}
for (const scheduler of schedulers) {
let resource = await ctx.em.findOne(Resource, { name: scheduler, type: EResourceType.Scheduler })
if (!resource) {
resource = ctx.em.create(
Resource,
{
name: scheduler,
type: EResourceType.Scheduler
},
{ partial: true }
)
}
await Promise.all([importCkpt(), importLora(), importSamplerScheduler(), importExtension()]).catch((e) => {
console.error(e)
subscriber.next(EImportingClient.FAILED)
})
subscriber.next(EImportingClient.DONE)
ComfyPoolInstance.getInstance().pool.addClient(
new ComfyApi(input.host, client.id, {
credentials: input.auth
? { type: 'basic', username: input.username!, password: input.password! }
: undefined
})
)
client.resources.add(resource)
}
run()
})
await ctx.em.persistAndFlush(client)
return EImportingClient.IMPORTED_SAMPLER_SCHEDULER
}
const importExtension = async () => {
const extensions = (await api.getNodeDefs()) ?? []
const promises = Object.values(extensions).map(async (ext) => {
let resource = await ctx.em.findOne(Extension, { pythonModule: ext.python_module, name: ext.name })
if (!resource) {
resource = ctx.em.create(
Extension,
{
name: ext.name,
displayName: ext.display_name,
pythonModule: ext.python_module,
category: ext.category,
outputNode: ext.output_node,
inputConf: ext.input.required,
description: ext.description,
outputConf: ext.output?.map((o, idx) => ({
name: ext.output_name?.[idx] ?? '',
isList: ext.output_is_list?.[idx] ?? false,
type: o,
tooltip: ext.output_tooltips?.[idx] ?? ''
}))
},
{ partial: true }
)
}
client.extensions.add(resource)
})
await Promise.all(promises)
await ctx.em.persistAndFlush(client)
return EImportingClient.IMPORTED_EXTENSION
}
try {
for await (const status of [importCkpt(), importLora(), importSamplerScheduler(), importExtension()]) {
yield status
}
yield EImportingClient.DONE
ComfyPoolInstance.getInstance().pool.addClient(
new ComfyApi(input.host, client.id, {
credentials: input.auth
? { type: 'basic', username: input.username!, password: input.password! }
: undefined
})
)
} catch (e) {
console.error(e)
yield EImportingClient.FAILED
}
})
})
41 changes: 18 additions & 23 deletions server/routers/task.ts
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ export const taskRouter = router({
clientId: z.string().optional()
})
)
.subscription(async ({ ctx, input }) => {
.subscription(async function* ({ ctx, input, signal }) {
let trigger: any = {}
const fn = async () => {
if (ctx.session.user!.role !== EUserRole.Admin) {
@@ -49,19 +49,18 @@ export const taskRouter = router({
{ limit: input.limit, orderBy: { createdAt: 'DESC' }, populate: ['trigger', 'trigger.user'] }
)
}
return observable<Awaited<ReturnType<typeof fn>>>((subscriber) => {
fn().then((data) => subscriber.next(data))
const off = !input.clientId
? cacher.onCategory('LAST_TASK_CLIENT', (ev) => {
fn().then((data) => subscriber.next(data))
})
: cacher.on('LAST_TASK_CLIENT', input.clientId, (ev) => {
fn().then((data) => subscriber.next(data))
})
return off
})
yield await fn()
if (!input.clientId) {
for await (const _ of cacher.onCategoryGenerator('LAST_TASK_CLIENT', signal)) {
yield await fn()
}
} else {
for await (const _ of cacher.onGenerator('LAST_TASK_CLIENT', input.clientId, signal)) {
yield await fn()
}
}
}),
countStats: privateProcedure.subscription(async ({ ctx }) => {
countStats: privateProcedure.subscription(async function* ({ ctx }) {
if (!ctx.session?.user) {
throw new Error('Unauthorized')
}
@@ -98,15 +97,11 @@ export const taskRouter = router({
executed
}
}

return observable<Awaited<ReturnType<typeof getStats>>>((subscriber) => {
getStats().then((data) => subscriber.next(data))
const off = cacher.onCategory('CLIENT_STATUS', (ev) => {
if ([EClientStatus.Executing, EClientStatus.Online].includes(ev.detail.value)) {
getStats().then((data) => subscriber.next(data))
}
})
return off
})
yield await getStats()
for await (const data of cacher.onCategoryGenerator('CLIENT_STATUS')) {
if ([EClientStatus.Executing, EClientStatus.Online].includes(data.detail.value)) {
yield await getStats()
}
}
})
})
125 changes: 57 additions & 68 deletions server/routers/watch.ts
Original file line number Diff line number Diff line change
@@ -8,21 +8,19 @@ import CachingService from '@/services/caching.service'
import { WorkflowTask } from '@/entities/workflow_task'

export const watchRouter = router({
historyList: privateProcedure.subscription(async ({ ctx }) => {
historyList: privateProcedure.subscription(async function* ({ ctx, signal }) {
const cacher = CachingService.getInstance()
return observable<number>((subscriber) => {
if (ctx.session.user!.role === EUserRole.Admin) {
return cacher.onCategory('HISTORY_LIST', (ev) => {
subscriber.next(ev.detail.value)
})
} else {
return cacher.on('HISTORY_LIST', ctx.session.user!.id, (ev) => {
subscriber.next(ev.detail)
})
if (ctx.session.user!.role === EUserRole.Admin) {
for await (const data of cacher.onCategoryGenerator('HISTORY_LIST', signal)) {
yield data.detail.value
}
})
} else {
for await (const data of cacher.onGenerator('HISTORY_LIST', ctx.session.user!.id, signal)) {
yield data.detail
}
}
}),
historyItem: privateProcedure.input(z.string()).subscription(async ({ input, ctx }) => {
historyItem: privateProcedure.input(z.string()).subscription(async function* ({ input, ctx, signal }) {
const cacher = CachingService.getInstance()
if (ctx.session.user!.role !== EUserRole.Admin) {
const taskInfo = await ctx.em.findOneOrFail(
@@ -40,73 +38,64 @@ export const watchRouter = router({
throw new Error('Unauthorized')
}
}
return observable<number>((subscriber) => {
return cacher.onCategory('HISTORY_ITEM', (ev) => {
if (ev.detail.id === input) subscriber.next(ev.detail.value)
})
})
for await (const data of cacher.onGenerator('HISTORY_ITEM', input, signal)) {
yield data.detail
}
}),
workflow: privateProcedure.input(z.string()).subscription(async ({ input, ctx }) => {
workflow: privateProcedure.input(z.string()).subscription(async function* ({ input, signal }) {
const cacher = CachingService.getInstance()
return observable<number>((subscriber) => {
return cacher.on('WORKFLOW', input, (ev) => {
subscriber.next(ev.detail)
})
})
for await (const data of cacher.onGenerator('WORKFLOW', input, signal)) {
yield data.detail
}
}),
balance: privateProcedure.subscription(async ({ ctx }) => {
balance: privateProcedure.subscription(async function* ({ ctx, signal }) {
const cacher = CachingService.getInstance()
return observable<number>((subscriber) => {
subscriber.next(ctx.session.user!.balance)
return cacher.on('USER_BALANCE', ctx.session.user!.id, (ev) => {
subscriber.next(ev.detail)
})
})
yield ctx.session.user!.balance
for await (const data of cacher.onGenerator('USER_BALANCE', ctx.session.user!.id, signal)) {
yield data.detail
}
}),
notification: privateProcedure.subscription(async ({ ctx }) => {
notification: privateProcedure.subscription(async function* ({ ctx, signal }) {
const cacher = CachingService.getInstance()
return observable<number>((subscriber) => {
return cacher.on('USER_NOTIFICATION', ctx.session.user!.id, (ev) => {
subscriber.next(ev.detail)
})
})

for await (const data of cacher.onGenerator('USER_NOTIFICATION', ctx.session.user!.id, signal)) {
yield data.detail
}
}),
executing: privateProcedure.subscription(async ({ ctx }) => {
executing: privateProcedure.subscription(async function* ({ ctx, signal }) {
const cacher = CachingService.getInstance()
return observable<boolean>((subscriber) => {
return cacher.on('USER_EXECUTING_TASK', ctx.session.user!.id, async (ev) => {
const task = await ctx.em.findOne(WorkflowTask, {
trigger: {
$or: [
{
user: {
id: ctx.session.user?.id
}
},
{
token: {
createdBy: ctx.session.user?.id
}

for await (const _ of cacher.onGenerator('USER_EXECUTING_TASK', ctx.session.user!.id, signal)) {
const task = await ctx.em.findOne(WorkflowTask, {
trigger: {
$or: [
{
user: {
id: ctx.session.user?.id
}
},
{
token: {
createdBy: ctx.session.user?.id
}
]
},
status: {
$nin: [ETaskStatus.Failed, ETaskStatus.Parent]
},
outputValues: null,
attachments: null,
executionTime: null
})
subscriber.next(!!task)
}
]
},
status: {
$nin: [ETaskStatus.Failed, ETaskStatus.Parent]
},
outputValues: null,
attachments: null,
executionTime: null
})
})
yield !!task
}
}),
preview: privateProcedure.input(z.object({ taskId: z.string() })).subscription(async ({ input, ctx }) => {
preview: privateProcedure.input(z.object({ taskId: z.string() })).subscription(async function* ({ input, signal }) {
const cacher = CachingService.getInstance()
return observable<string>((subscriber) => {
return cacher.on('PREVIEW', input.taskId, (ev) => {
subscriber.next(ev.detail.blob64)
})
})

for await (const data of cacher.onGenerator('PREVIEW', input.taskId, signal)) {
yield data.detail.blob64
}
})
})
241 changes: 123 additions & 118 deletions server/routers/workflow.ts
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ import { adminProcedure, editorProcedure, privateProcedure } from '../procedure'
import { router } from '../trpc'
import { z } from 'zod'
import { IMapperOutput, Workflow } from '@/entities/workflow'
import { EventEmitter } from 'node:events'
import { EventEmitter, on } from 'node:events'
import { observable } from '@trpc/server/observable'
import { ComfyPoolInstance } from '@/services/comfyui.service'
import { CallWrapper } from '@saintno/comfyui-sdk'
@@ -17,6 +17,10 @@ import { getBuilder, parseOutput } from '@/utils/workflow'

const ee = new EventEmitter()

const emitAction = (id: string, data: TWorkflowProgressMessage) => {
ee.emit(`workflow:${id}`, data)
}

const BaseSchema = z.object({
key: z.string(),
type: z.union([z.nativeEnum(EValueType), z.nativeEnum(EValueSelectionType), z.nativeEnum(EValueUtilityType)]),
@@ -236,132 +240,133 @@ export const workflowRouter = router({
await ctx.em.flush()
return true
}),
testWorkflow: editorProcedure.subscription(async ({ input, ctx }) => {
return observable<TWorkflowProgressMessage>((subscriber) => {
const handle = (data: { input: Record<string, any>; workflow: Workflow }) => {
subscriber.next({ key: 'init' })
const builder = getBuilder(data.workflow)
const pool = ComfyPoolInstance.getInstance().pool
pool.run(async (api) => {
for (const key in data.input) {
const inputData = data.input[key] || data.workflow.mapInput?.[key].default
if (!inputData) {
continue
}
switch (data.workflow.mapInput?.[key].type) {
case EValueType.Number:
case EValueUtilityType.Seed:
builder.input(key, Number(inputData))
break
case EValueType.String:
builder.input(key, String(inputData))
break
case EValueType.File:
case EValueType.Video:
case EValueType.Image:
const file = inputData as Attachment
const fileBlob = await AttachmentService.getInstance().getFileBlob(file.fileName)
if (!fileBlob) {
return subscriber.next({ key: 'failed', detail: 'missing file' })
}
const uploadedImg = await api.uploadImage(fileBlob, file.fileName)
if (!uploadedImg) {
subscriber.next({ key: 'failed', detail: 'failed to upload file' })
return
}
builder.input(key, uploadedImg.info.filename)
break
default:
builder.input(key, inputData)
break
}
}
return new CallWrapper(api, builder)
.onPending(() => {
subscriber.next({ key: 'loading' })
})
.onProgress((e) => {
subscriber.next({
key: 'progress',
data: { node: Number(e.node), max: Number(e.max), value: Number(e.value) }
})
})
.onPreview(async (e) => {
const arrayBuffer = await e.arrayBuffer()
const base64String = Buffer.from(arrayBuffer).toString('base64')
subscriber.next({ key: 'preview', data: { blob64: base64String } })
})
.onStart(() => {
subscriber.next({ key: 'start' })
})
.onFinished(async (outData) => {
subscriber.next({ key: 'downloading_output' })
const attachment = AttachmentService.getInstance()
const output = await parseOutput(api, data.workflow, outData)
subscriber.next({ key: 'uploading_output' })
const tmpOutput = cloneDeep(output) as Record<string, any>
// If key is array of Blob, convert it to base64
for (const key in tmpOutput) {
if (Array.isArray(tmpOutput[key])) {
tmpOutput[key] = (await Promise.all(
tmpOutput[key].map(async (v, idx) => {
if (v instanceof Blob) {
const imgUtil = new ImageUtil(Buffer.from(await v.arrayBuffer()))
const jpg = await imgUtil.intoJPG()
const tmpName = `${uniqueId()}_${key}_${idx}.jpg`
const uploaded = await attachment.uploadFile(jpg, `${tmpName}`)
if (uploaded) {
return await attachment.getFileURL(tmpName)
}
}
return v
})
)) as string[]
}
}
const outputConfig = data.workflow.mapOutput
const outputData = Object.keys(outputConfig || {}).reduce(
(acc, val) => {
if (tmpOutput[val] && outputConfig?.[val]) {
acc[val] = {
info: outputConfig[val],
data: tmpOutput[val]
}
}
return acc
},
{} as Record<
string,
{
info: IMapperOutput
data: number | boolean | string | Array<{ type: EAttachmentType; url: string }>
}
>
)
subscriber.next({ key: 'finished', data: { output: outputData } })
})
.onFailed((e) => {
console.warn(e)
subscriber.next({ key: 'failed', detail: (e.cause as any)?.error?.message || e.message })
})
.run()
})
}
ee.on('start', handle)
return () => {
ee.off('start', handle)
}
})
testWorkflow: editorProcedure.input(z.string()).subscription(async function* ({ ctx, input, signal }) {
for await (const [ev] of on(ee, `workflow:${input}`, {
// Passing the AbortSignal from the request automatically cancels the event emitter when the request is aborted
signal: signal
})) {
const data = ev as TWorkflowProgressMessage
yield data
}
}),
startTestWorkflow: editorProcedure
.input(
z.object({
id: z.string(),
input: z.record(z.string(), z.any()),
workflow: z.any()
})
)
.mutation(async ({ input, ctx }) => {
ee.emit('start', input)
const data = input
const builder = getBuilder(data.workflow)
const pool = ComfyPoolInstance.getInstance().pool
emitAction(input.id, { key: 'init' })
pool.run(async (api) => {
for (const key in data.input) {
const inputData = data.input[key] || data.workflow.mapInput?.[key].default
if (!inputData) {
continue
}
switch (data.workflow.mapInput?.[key].type) {
case EValueType.Number:
case EValueUtilityType.Seed:
builder.input(key, Number(inputData))
break
case EValueType.String:
builder.input(key, String(inputData))
break
case EValueType.File:
case EValueType.Video:
case EValueType.Image:
const file = inputData as Attachment
const fileBlob = await AttachmentService.getInstance().getFileBlob(file.fileName)
if (!fileBlob) {
emitAction(input.id, { key: 'failed', detail: 'missing file' })
return
}
const uploadedImg = await api.uploadImage(fileBlob, file.fileName)
if (!uploadedImg) {
emitAction(input.id, { key: 'failed', detail: 'failed to upload file' })
return
}
builder.input(key, uploadedImg.info.filename)
break
default:
builder.input(key, inputData)
break
}
}
return new CallWrapper(api, builder)
.onPending(() => {
emitAction(input.id, { key: 'loading' })
})
.onProgress((e) => {
emitAction(input.id, {
key: 'progress',
data: { node: Number(e.node), max: Number(e.max), value: Number(e.value) }
})
})
.onPreview(async (e) => {
const arrayBuffer = await e.arrayBuffer()
const base64String = Buffer.from(arrayBuffer).toString('base64')
emitAction(input.id, { key: 'preview', data: { blob64: base64String } })
})
.onStart(() => {
emitAction(input.id, { key: 'start' })
})
.onFinished(async (outData) => {
emitAction(input.id, { key: 'downloading_output' })
const attachment = AttachmentService.getInstance()
const output = await parseOutput(api, data.workflow, outData)
emitAction(input.id, { key: 'uploading_output' })
const tmpOutput = cloneDeep(output) as Record<string, any>
// If key is array of Blob, convert it to base64
for (const key in tmpOutput) {
if (Array.isArray(tmpOutput[key])) {
tmpOutput[key] = (await Promise.all(
tmpOutput[key].map(async (v, idx) => {
if (v instanceof Blob) {
const imgUtil = new ImageUtil(Buffer.from(await v.arrayBuffer()))
const jpg = await imgUtil.intoJPG()
const tmpName = `${uniqueId()}_${key}_${idx}.jpg`
const uploaded = await attachment.uploadFile(jpg, `${tmpName}`)
if (uploaded) {
return await attachment.getFileURL(tmpName, undefined, ctx.baseUrl)
}
}
return v
})
)) as string[]
}
}
const outputConfig = data.workflow.mapOutput
const outputData = Object.keys(outputConfig || {}).reduce(
(acc, val) => {
if (tmpOutput[val] && outputConfig?.[val]) {
acc[val] = {
info: outputConfig[val],
data: tmpOutput[val]
}
}
return acc
},
{} as Record<
string,
{
info: IMapperOutput
data: number | boolean | string | Array<{ type: EAttachmentType; url: string }>
}
>
)
emitAction(input.id, { key: 'finished', data: { output: outputData } })
})
.onFailed((e) => {
console.warn(e)
emitAction(input.id, { key: 'failed', detail: (e.cause as any)?.error?.message || e.message })
})
.run()
})
return true
}),
importWorkflow: editorProcedure
2 changes: 1 addition & 1 deletion services/caching.service.ts
Original file line number Diff line number Diff line change
@@ -212,7 +212,7 @@ class CachingService extends EventTarget {
}
}

onCategoryGenerator = async function* x<T extends keyof TCachingKeyMap>(
onCategoryGenerator = async function* <T extends keyof TCachingKeyMap>(
key: T,
signal?: AbortSignal
): AsyncGenerator<

0 comments on commit 1dc8a2b

Please sign in to comment.