diff --git a/src/constants.ts b/src/constants.ts index 05bddd7f..231220aa 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -16,6 +16,13 @@ export const IPC_CHANNELS = { GET_PRELOAD_SCRIPT: 'get-preload-script', OPEN_DEVTOOLS: 'open-devtools', OPEN_LOGS_FOLDER: 'open-logs-folder', + DOWNLOAD_PROGRESS: 'download-progress', + START_DOWNLOAD: 'start-download', + PAUSE_DOWNLOAD: 'pause-download', + RESUME_DOWNLOAD: 'resume-download', + CANCEL_DOWNLOAD: 'cancel-download', + DELETE_MODEL: 'delete-model', + GET_ALL_DOWNLOADS: 'get-all-downloads', } as const; export const COMFY_ERROR_MESSAGE = diff --git a/src/main.ts b/src/main.ts index 745048e1..08f723d4 100644 --- a/src/main.ts +++ b/src/main.ts @@ -23,6 +23,8 @@ import { StoreType } from './store'; import { createReadStream, watchFile } from 'node:fs'; import todesktop from '@todesktop/runtime'; import { PythonEnvironment } from './pythonEnvironment'; +import { DownloadManager } from './models/DownloadManager'; +import { getModelsDirectory } from './utils'; let comfyServerProcess: ChildProcess | null = null; const host = '127.0.0.1'; @@ -31,6 +33,7 @@ let mainWindow: BrowserWindow | null; let wss: WebSocketServer | null; let store: Store | null; const messageQueue: Array = []; // Stores mesaages before renderer is ready. +let downloadManager: DownloadManager; log.initialize(); @@ -182,7 +185,7 @@ if (!gotTheLock) { }); await handleFirstTimeSetup(); const { appResourcesPath, pythonInstallPath, modelConfigPath, basePath } = await determineResourcesPaths(); - + downloadManager = DownloadManager.getInstance(mainWindow, getModelsDirectory(basePath)); port = await findAvailablePort(8000, 9999).catch((err) => { log.error(`ERROR: Failed to find available port: ${err}`); throw err; diff --git a/src/models/DownloadManager.ts b/src/models/DownloadManager.ts new file mode 100644 index 00000000..85b35568 --- /dev/null +++ b/src/models/DownloadManager.ts @@ -0,0 +1,275 @@ +import { BrowserWindow, session, DownloadItem, ipcMain } from 'electron'; +import path from 'path'; +import fs from 'fs'; +import { IPC_CHANNELS } from '../constants'; +import log from 'electron-log/main'; + +interface Download { + url: string; + filename: string; + tempPath: string; // Temporary filename until the download is complete. + savePath: string; + item: DownloadItem | null; +} + +export enum DownloadStatus { + PENDING = 'pending', + IN_PROGRESS = 'in_progress', + COMPLETED = 'completed', + PAUSED = 'paused', + ERROR = 'error', + CANCELLED = 'cancelled', +} +interface DownloadState { + url: string; + filename: string; + state: DownloadStatus; + receivedBytes: number; + totalBytes: number; + isPaused: boolean; +} + +/** + * Singleton class that manages downloading model checkpoints for ComfyUI. + */ +export class DownloadManager { + private static instance: DownloadManager; + private downloads: Map; + private mainWindow: BrowserWindow; + private modelsDirectory: string; + private constructor(mainWindow: BrowserWindow, modelsDirectory: string) { + this.downloads = new Map(); + this.mainWindow = mainWindow; + this.modelsDirectory = modelsDirectory; + + session.defaultSession.on('will-download', (event, item, webContents) => { + const url = item.getURLChain()[0]; // Get the original URL in case of redirects. + log.info('Will-download event ', url); + const download = this.downloads.get(url); + + if (download) { + this.reportProgress(url, 0, DownloadStatus.PENDING); + item.setSavePath(download.tempPath); + download.item = item; + log.info(`Setting save path to ${item.getSavePath()}`); + + item.on('updated', (event, state) => { + if (state === 'interrupted') { + log.info('Download is interrupted but can be resumed'); + } else if (state === 'progressing') { + const progress = item.getReceivedBytes() / item.getTotalBytes(); + if (item.isPaused()) { + log.info('Download is paused'); + this.reportProgress(url, progress, DownloadStatus.PAUSED); + } else { + this.reportProgress(url, progress, DownloadStatus.IN_PROGRESS); + } + } + }); + + item.once('done', (event, state) => { + if (state === 'completed') { + try { + fs.renameSync(download.tempPath, download.savePath); + log.info(`Successfully renamed ${download.tempPath} to ${download.savePath}`); + } catch (error) { + log.error(`Failed to rename downloaded file: ${error}. Deleting temp file.`); + fs.unlinkSync(download.tempPath); + } + this.reportProgress(url, 1, DownloadStatus.COMPLETED); + this.downloads.delete(url); + } else { + log.info(`Download failed: ${state}`); + const progress = item.getReceivedBytes() / item.getTotalBytes(); + this.reportProgress(url, progress, DownloadStatus.ERROR); + } + }); + } + }); + } + + startDownload(url: string, savePath: string, filename: string): boolean { + const localSavePath = this.getLocalSavePath(filename, savePath); + if (!this.isPathInModelsDirectory(localSavePath)) { + log.error(`Save path ${localSavePath} is not in models directory ${this.modelsDirectory}`); + this.reportProgress(url, 0, DownloadStatus.ERROR, 'Save path is not in models directory'); + return false; + } + + const validationResult = this.validateSafetensorsFile(url, filename); + if (!validationResult.isValid) { + log.error(validationResult.error); + this.reportProgress(url, 0, DownloadStatus.ERROR, validationResult.error); + return false; + } + + if (fs.existsSync(localSavePath)) { + log.info(`File ${filename} already exists, skipping download`); + return true; + } + + const existingDownload = this.downloads.get(url); + if (existingDownload) { + log.info('Download already exists'); + if (existingDownload.item && existingDownload.item.isPaused()) { + this.resumeDownload(url); + } + return true; + } + + log.info(`Starting download ${url} to ${localSavePath}`); + const tempPath = this.getTempPath(filename, savePath); + this.downloads.set(url, { url, savePath: localSavePath, tempPath, filename, item: null }); + + // TODO(robinhuang): Add offset support for resuming downloads. + // Can use https://www.electronjs.org/docs/latest/api/session#sescreateinterrupteddownloadoptions + session.defaultSession.downloadURL(url); + return true; + } + + cancelDownload(url: string): void { + const download = this.downloads.get(url); + if (download && download.item) { + log.info('Cancelling download'); + download.item.cancel(); + } + } + + pauseDownload(url: string): void { + const download = this.downloads.get(url); + if (download && download.item) { + log.info('Pausing download'); + download.item.pause(); + } + } + + resumeDownload(url: string): void { + const download = this.downloads.get(url); + if (download) { + if (download.item && download.item.canResume()) { + log.info('Resuming download'); + download.item.resume(); + } else { + this.startDownload(download.url, download.savePath, download.filename); + } + } + } + + deleteModel(filename: string, savePath: string): boolean { + const localSavePath = this.getLocalSavePath(filename, savePath); + if (!this.isPathInModelsDirectory(localSavePath)) { + log.error(`Save path ${localSavePath} is not in models directory ${this.modelsDirectory}`); + return false; + } + const tempPath = this.getTempPath(filename, savePath); + try { + if (fs.existsSync(localSavePath)) { + log.info(`Deleting local file ${localSavePath}`); + fs.unlinkSync(localSavePath); + } + } catch (error) { + log.error(`Failed to delete file ${localSavePath}: ${error}`); + } + + try { + if (fs.existsSync(tempPath)) { + log.info(`Deleting temp file ${tempPath}`); + fs.unlinkSync(tempPath); + } + } catch (error) { + log.error(`Failed to delete file ${tempPath}: ${error}`); + } + } + + getAllDownloads(): DownloadState[] { + return Array.from(this.downloads.values()) + .filter((download) => download.item !== null) + .map((download) => ({ + url: download.url, + filename: download.filename, + tempPath: download.tempPath, + state: this.convertDownloadState(download.item?.getState()), + receivedBytes: download.item?.getReceivedBytes() || 0, + totalBytes: download.item?.getTotalBytes() || 0, + isPaused: download.item?.isPaused() || false, + })); + } + + private convertDownloadState(state: 'progressing' | 'completed' | 'cancelled' | 'interrupted'): DownloadStatus { + switch (state) { + case 'progressing': + return DownloadStatus.IN_PROGRESS; + case 'completed': + return DownloadStatus.COMPLETED; + case 'cancelled': + return DownloadStatus.CANCELLED; + case 'interrupted': + return DownloadStatus.ERROR; + default: + return DownloadStatus.ERROR; + } + } + + private getTempPath(filename: string, savePath: string): string { + return path.join(this.modelsDirectory, savePath, `Unconfirmed ${filename}.tmp`); + } + + // Only allow .safetensors files to be downloaded. + private validateSafetensorsFile(url: string, filename: string): { isValid: boolean; error?: string } { + try { + const urlObj = new URL(url); + const pathname = urlObj.pathname.toLowerCase(); + if (!pathname.endsWith('.safetensors') && !filename.toLowerCase().endsWith('.safetensors')) { + return { + isValid: false, + error: 'Invalid file type: must be a .safetensors file', + }; + } + return { isValid: true }; + } catch (error) { + return { + isValid: false, + error: `Invalid URL format: ${error}`, + }; + } + } + + private getLocalSavePath(filename: string, savePath: string): string { + return path.join(this.modelsDirectory, savePath, filename); + } + + private isPathInModelsDirectory(filePath: string): boolean { + const absoluteFilePath = path.resolve(filePath); + const absoluteModelsDir = path.resolve(this.modelsDirectory); + return absoluteFilePath.startsWith(absoluteModelsDir); + } + + private reportProgress(url: string, progress: number, status: DownloadStatus, message: string = ''): void { + log.info(`Download progress: ${progress}, status: ${status}, message: ${message}`); + this.mainWindow.webContents.send(IPC_CHANNELS.DOWNLOAD_PROGRESS, { + url, + progress, + status, + message, + }); + } + + public static getInstance(mainWindow: BrowserWindow, modelsDirectory: string): DownloadManager { + if (!DownloadManager.instance) { + DownloadManager.instance = new DownloadManager(mainWindow, modelsDirectory); + DownloadManager.instance.registerIpcHandlers(); + } + return DownloadManager.instance; + } + + private registerIpcHandlers() { + ipcMain.handle(IPC_CHANNELS.START_DOWNLOAD, (event, { url, path, filename }) => + this.startDownload(url, path, filename) + ); + ipcMain.handle(IPC_CHANNELS.PAUSE_DOWNLOAD, (event, url: string) => this.pauseDownload(url)); + ipcMain.handle(IPC_CHANNELS.RESUME_DOWNLOAD, (event, url: string) => this.resumeDownload(url)); + ipcMain.handle(IPC_CHANNELS.CANCEL_DOWNLOAD, (event, url: string) => this.cancelDownload(url)); + ipcMain.handle(IPC_CHANNELS.GET_ALL_DOWNLOADS, (event) => this.getAllDownloads()); + ipcMain.handle(IPC_CHANNELS.DELETE_MODEL, (event, { filename, path }) => this.deleteModel(filename, path)); + } +} diff --git a/src/preload.ts b/src/preload.ts index 6c194bf8..1ad071ad 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -1,7 +1,8 @@ // preload.ts -import { contextBridge, ipcRenderer } from 'electron'; +import { contextBridge, DownloadItem, ipcRenderer } from 'electron'; import { IPC_CHANNELS, ELECTRON_BRIDGE_API } from './constants'; +import { DownloadStatus } from './models/DownloadManager'; export interface ElectronAPI { /** @@ -33,6 +34,22 @@ export interface ElectronAPI { * Open the logs folder in the system's default file explorer. */ openLogsFolder: () => void; + DownloadManager: { + onDownloadProgress: ( + callback: (progress: { + url: string; + progress_percentage: number; + status: DownloadStatus; + message?: string; + }) => void + ) => void; + startDownload: (url: string, path: string, filename: string) => Promise; + cancelDownload: (url: string) => Promise; + pauseDownload: (url: string) => Promise; + resumeDownload: (url: string) => Promise; + deleteModel: (filename: string, path: string) => Promise; + getAllDownloads: () => Promise; + }; } const electronAPI: ElectronAPI = { @@ -98,6 +115,37 @@ const electronAPI: ElectronAPI = { openLogsFolder: () => { ipcRenderer.send(IPC_CHANNELS.OPEN_LOGS_FOLDER); }, + DownloadManager: { + onDownloadProgress: ( + callback: (progress: { + url: string; + progress_percentage: number; + status: DownloadStatus; + message?: string; + }) => void + ) => { + ipcRenderer.on(IPC_CHANNELS.DOWNLOAD_PROGRESS, (_event, progress) => callback(progress)); + }, + startDownload: (url: string, path: string, filename: string): Promise => { + console.log(`Sending start download message to main process`, { url, path, filename }); + return ipcRenderer.invoke(IPC_CHANNELS.START_DOWNLOAD, { url, path, filename }); + }, + cancelDownload: (url: string): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.CANCEL_DOWNLOAD, url); + }, + pauseDownload: (url: string): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.PAUSE_DOWNLOAD, url); + }, + resumeDownload: (url: string): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.RESUME_DOWNLOAD, url); + }, + deleteModel: (filename: string, path: string): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.DELETE_DOWNLOAD, { filename, path }); + }, + getAllDownloads: (): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.GET_ALL_DOWNLOADS); + }, + }, }; contextBridge.exposeInMainWorld(ELECTRON_BRIDGE_API, electronAPI); diff --git a/src/tray.ts b/src/tray.ts index ff36228a..3266b90e 100644 --- a/src/tray.ts +++ b/src/tray.ts @@ -4,6 +4,7 @@ import { IPC_CHANNELS } from './constants'; import { exec } from 'child_process'; import log from 'electron-log/main'; import { PythonEnvironment } from './pythonEnvironment'; +import { getModelsDirectory } from './utils'; export function SetupTray( mainView: BrowserWindow, @@ -65,7 +66,7 @@ export function SetupTray( { type: 'separator' }, { label: 'Open Models Folder', - click: () => shell.openPath(path.join(basePath, 'models')), + click: () => shell.openPath(getModelsDirectory(basePath)), }, { label: 'Open Outputs Folder', diff --git a/src/utils.ts b/src/utils.ts index 38809143..525f92cc 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,5 @@ import * as fsPromises from 'node:fs/promises'; +import path from 'node:path'; export async function pathAccessible(path: string): Promise { try { @@ -8,3 +9,7 @@ export async function pathAccessible(path: string): Promise { return false; } } + +export function getModelsDirectory(comfyUIBasePath: string): string { + return path.join(comfyUIBasePath, 'models'); +}