diff --git a/src/constants.ts b/src/constants.ts index 09f068f9..0e58976c 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -21,6 +21,7 @@ export const IPC_CHANNELS = { PAUSE_DOWNLOAD: 'pause-download', RESUME_DOWNLOAD: 'resume-download', CANCEL_DOWNLOAD: 'cancel-download', + DELETE_DOWNLOAD: 'delete-download', GET_ALL_DOWNLOADS: 'get-all-downloads', } as const; diff --git a/src/main.ts b/src/main.ts index e27b3561..c5f0e281 100644 --- a/src/main.ts +++ b/src/main.ts @@ -24,6 +24,7 @@ import { createReadStream, watchFile } from 'node:fs'; import todesktop from '@todesktop/runtime'; import { PythonEnvironment } from './pythonEnvironment'; import { DownloadManager } from './models/DownloadManager'; +import { getModelsDirectory } from './utils'; let comfyServerProcess: ChildProcess | null = null; const host = '127.0.0.1'; @@ -156,8 +157,6 @@ if (!gotTheLock) { try { await createWindow(); - downloadManager = DownloadManager.getInstance(mainWindow); - startWebSocketServer(); mainWindow.on('close', () => { mainWindow = null; @@ -186,7 +185,7 @@ if (!gotTheLock) { }); await handleFirstTimeSetup(); const { appResourcesPath, pythonInstallPath, modelConfigPath, basePath } = await determineResourcesPaths(); - + downloadManager = DownloadManager.getInstance(mainWindow, getModelsDirectory(basePath)); port = await findAvailablePort(8000, 9999).catch((err) => { log.error(`ERROR: Failed to find available port: ${err}`); throw err; diff --git a/src/models/DownloadManager.ts b/src/models/DownloadManager.ts index 185545ba..d916096f 100644 --- a/src/models/DownloadManager.ts +++ b/src/models/DownloadManager.ts @@ -1,30 +1,45 @@ import { BrowserWindow, session, DownloadItem, ipcMain } from 'electron'; import path from 'path'; +import fs from 'fs'; import { IPC_CHANNELS } from '../constants'; import log from 'electron-log/main'; interface Download { url: string; filename: string; + tempPath: string; // Temporary filename until the download is complete. savePath: string; item: DownloadItem | null; } +interface DownloadStatus { + url: string; + filename: string; + state: string; + receivedBytes: number; + totalBytes: number; + isPaused: boolean; +} + export class DownloadManager { private static instance: DownloadManager; private downloads: Map; private mainWindow: BrowserWindow; - constructor(mainWindow: BrowserWindow) { + private modelsDirectory: string; + constructor(mainWindow: BrowserWindow, modelsDirectory: string) { this.downloads = new Map(); this.mainWindow = mainWindow; + this.modelsDirectory = modelsDirectory; session.defaultSession.on('will-download', (event, item, webContents) => { - const url = item.getURL(); + const url = item.getURLChain()[0]; // Get the original URL in case of redirects. + log.info('Will-download event ', url); const download = this.downloads.get(url); if (download) { - item.setSavePath(download.savePath); + item.setSavePath(download.tempPath); download.item = item; + log.info(`Setting save path to ${item.getSavePath()}`); item.on('updated', (event, state) => { if (state === 'interrupted') { @@ -34,6 +49,7 @@ export class DownloadManager { log.info('Download is paused'); } else { const progress = item.getReceivedBytes() / item.getTotalBytes(); + log.info(`Download progress: ${progress}`); this.reportProgress(url, progress); } } @@ -41,29 +57,50 @@ export class DownloadManager { item.once('done', (event, state) => { if (state === 'completed') { + try { + fs.renameSync(download.tempPath, download.savePath); + log.info(`Successfully renamed ${download.tempPath} to ${download.savePath}`); + } catch (error) { + fs.unlinkSync(download.tempPath); + log.error(`Failed to rename downloaded file: ${error}`); + } this.reportProgress(url, 1, true); + this.downloads.delete(url); } else { log.info(`Download failed: ${state}`); - this.reportProgress(url, 0, false, true); + const progress = item.getReceivedBytes() / item.getTotalBytes(); + this.reportProgress(url, progress, false, true); } - this.downloads.delete(url); }); } }); } - startDownload(url: string, savePath: string): void { + startDownload(url: string, savePath: string, filename: string): boolean { + const localSavePath = this.getLocalSavePath(filename, savePath); + + if (fs.existsSync(localSavePath)) { + log.info(`File ${filename} already exists, skipping download`); + return true; + } + const existingDownload = this.downloads.get(url); if (existingDownload) { log.info('Download already exists'); if (existingDownload.item && existingDownload.item.isPaused()) { this.resumeDownload(url); } - return; + return true; } - const filename = path.basename(savePath); - this.downloads.set(url, { url, filename, savePath, item: null }); + + log.info(`Starting download ${url} to ${localSavePath}`); + const tempPath = this.getTempPath(filename, savePath); + this.downloads.set(url, { url, savePath: localSavePath, tempPath, filename, item: null }); + + // TODO(robinhuang): Add offset support for resuming downloads. + // Can use https://www.electronjs.org/docs/latest/api/session#sescreateinterrupteddownloadoptions session.defaultSession.downloadURL(url); + return true; } cancelDownload(url: string): void { @@ -89,15 +126,52 @@ export class DownloadManager { log.info('Resuming download'); download.item.resume(); } else { - this.startDownload(download.url, download.savePath); + this.startDownload(download.url, download.savePath, download.filename); } } } - getAllDownloads(): DownloadItem[] { + deleteDownload(url: string, filename: string, savePath: string): void { + this.downloads.delete(url); + const localSavePath = this.getLocalSavePath(filename, savePath); + const tempPath = this.getTempPath(filename, savePath); + try { + if (fs.existsSync(localSavePath)) { + fs.unlinkSync(localSavePath); + } + } catch (error) { + log.error(`Failed to delete file ${localSavePath}: ${error}`); + } + + try { + if (fs.existsSync(tempPath)) { + fs.unlinkSync(tempPath); + } + } catch (error) { + log.error(`Failed to delete file ${tempPath}: ${error}`); + } + } + + getAllDownloads(): DownloadStatus[] { return Array.from(this.downloads.values()) - .map((download) => download.item) - .filter((item): item is DownloadItem => item !== null); + .filter((download) => download.item !== null) + .map((download) => ({ + url: download.url, + filename: download.filename, + tempPath: download.tempPath, + state: download.item?.getState() || 'interrupted', + receivedBytes: download.item?.getReceivedBytes() || 0, + totalBytes: download.item?.getTotalBytes() || 0, + isPaused: download.item?.isPaused() || false, + })); + } + + private getTempPath(filename: string, savePath: string): string { + return path.join(this.modelsDirectory, savePath, `Unconfirmed ${filename}.tmp`); + } + + private getLocalSavePath(filename: string, savePath: string): string { + return path.join(this.modelsDirectory, savePath, filename); } private reportProgress( @@ -114,19 +188,24 @@ export class DownloadManager { }); } - public static getInstance(mainWindow: BrowserWindow): DownloadManager { + public static getInstance(mainWindow: BrowserWindow, modelsDirectory: string): DownloadManager { if (!DownloadManager.instance) { - DownloadManager.instance = new DownloadManager(mainWindow); + DownloadManager.instance = new DownloadManager(mainWindow, modelsDirectory); DownloadManager.instance.registerIpcHandlers(); } return DownloadManager.instance; } private registerIpcHandlers() { - ipcMain.handle(IPC_CHANNELS.START_DOWNLOAD, (event, { url, path }) => this.startDownload(url, path)); + ipcMain.handle(IPC_CHANNELS.START_DOWNLOAD, (event, { url, path, filename }) => + this.startDownload(url, path, filename) + ); ipcMain.handle(IPC_CHANNELS.PAUSE_DOWNLOAD, (event, url: string) => this.pauseDownload(url)); ipcMain.handle(IPC_CHANNELS.RESUME_DOWNLOAD, (event, url: string) => this.resumeDownload(url)); ipcMain.handle(IPC_CHANNELS.CANCEL_DOWNLOAD, (event, url: string) => this.cancelDownload(url)); ipcMain.handle(IPC_CHANNELS.GET_ALL_DOWNLOADS, (event) => this.getAllDownloads()); + ipcMain.handle(IPC_CHANNELS.DELETE_DOWNLOAD, (event, { url, filename, path }) => + this.deleteDownload(url, filename, path) + ); } } diff --git a/src/preload.ts b/src/preload.ts index 2da0027e..b38b3bd7 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -37,10 +37,11 @@ export interface ElectronAPI { onDownloadProgress: ( callback: (progress: { url: string; progress: number; isComplete: boolean; isCancelled: boolean }) => void ) => void; - startDownload: (url: string, path: string) => Promise; + startDownload: (url: string, path: string, filename: string) => Promise; cancelDownload: (url: string) => Promise; pauseDownload: (url: string) => Promise; resumeDownload: (url: string) => Promise; + deleteDownload: (url: string, filename: string, path: string) => Promise; getAllDownloads: () => Promise; }; } @@ -114,8 +115,9 @@ const electronAPI: ElectronAPI = { ) => { ipcRenderer.on(IPC_CHANNELS.DOWNLOAD_PROGRESS, (_event, progress) => callback(progress)); }, - startDownload: (url: string, path: string): Promise => { - return ipcRenderer.invoke(IPC_CHANNELS.START_DOWNLOAD, { url, path }); + startDownload: (url: string, path: string, filename: string): Promise => { + console.log(`Sending start download message to main process`, { url, path, filename }); + return ipcRenderer.invoke(IPC_CHANNELS.START_DOWNLOAD, { url, path, filename }); }, cancelDownload: (url: string): Promise => { return ipcRenderer.invoke(IPC_CHANNELS.CANCEL_DOWNLOAD, url); @@ -126,6 +128,9 @@ const electronAPI: ElectronAPI = { resumeDownload: (url: string): Promise => { return ipcRenderer.invoke(IPC_CHANNELS.RESUME_DOWNLOAD, url); }, + deleteDownload: (url: string, filename: string, path: string): Promise => { + return ipcRenderer.invoke(IPC_CHANNELS.DELETE_DOWNLOAD, { url, filename, path }); + }, getAllDownloads: (): Promise => { return ipcRenderer.invoke(IPC_CHANNELS.GET_ALL_DOWNLOADS); }, diff --git a/src/tray.ts b/src/tray.ts index ff36228a..3266b90e 100644 --- a/src/tray.ts +++ b/src/tray.ts @@ -4,6 +4,7 @@ import { IPC_CHANNELS } from './constants'; import { exec } from 'child_process'; import log from 'electron-log/main'; import { PythonEnvironment } from './pythonEnvironment'; +import { getModelsDirectory } from './utils'; export function SetupTray( mainView: BrowserWindow, @@ -65,7 +66,7 @@ export function SetupTray( { type: 'separator' }, { label: 'Open Models Folder', - click: () => shell.openPath(path.join(basePath, 'models')), + click: () => shell.openPath(getModelsDirectory(basePath)), }, { label: 'Open Outputs Folder', diff --git a/src/utils.ts b/src/utils.ts index 38809143..525f92cc 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,5 @@ import * as fsPromises from 'node:fs/promises'; +import path from 'node:path'; export async function pathAccessible(path: string): Promise { try { @@ -8,3 +9,7 @@ export async function pathAccessible(path: string): Promise { return false; } } + +export function getModelsDirectory(comfyUIBasePath: string): string { + return path.join(comfyUIBasePath, 'models'); +}