Skip to content

Commit

Permalink
Download Models API. (#132)
Browse files Browse the repository at this point in the history
* Download Manager API.

* Add delete and cancel downloads.

* Improve logs.

* Track progress.

* Validate model filepath.

* Validate safetensors.
  • Loading branch information
robinjhuang authored Oct 29, 2024
1 parent 5332668 commit 60e24d9
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ export const IPC_CHANNELS = {
GET_PRELOAD_SCRIPT: 'get-preload-script',
OPEN_DEVTOOLS: 'open-devtools',
OPEN_LOGS_FOLDER: 'open-logs-folder',
DOWNLOAD_PROGRESS: 'download-progress',
START_DOWNLOAD: 'start-download',
PAUSE_DOWNLOAD: 'pause-download',
RESUME_DOWNLOAD: 'resume-download',
CANCEL_DOWNLOAD: 'cancel-download',
DELETE_MODEL: 'delete-model',
GET_ALL_DOWNLOADS: 'get-all-downloads',
} as const;

export const COMFY_ERROR_MESSAGE =
Expand Down
5 changes: 4 additions & 1 deletion src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import { StoreType } from './store';
import { createReadStream, watchFile } from 'node:fs';
import todesktop from '@todesktop/runtime';
import { PythonEnvironment } from './pythonEnvironment';
import { DownloadManager } from './models/DownloadManager';
import { getModelsDirectory } from './utils';

let comfyServerProcess: ChildProcess | null = null;
const host = '127.0.0.1';
Expand All @@ -31,6 +33,7 @@ let mainWindow: BrowserWindow | null;
let wss: WebSocketServer | null;
let store: Store<StoreType> | null;
const messageQueue: Array<any> = []; // Stores mesaages before renderer is ready.
let downloadManager: DownloadManager;

log.initialize();

Expand Down Expand Up @@ -182,7 +185,7 @@ if (!gotTheLock) {
});
await handleFirstTimeSetup();
const { appResourcesPath, pythonInstallPath, modelConfigPath, basePath } = await determineResourcesPaths();

downloadManager = DownloadManager.getInstance(mainWindow, getModelsDirectory(basePath));
port = await findAvailablePort(8000, 9999).catch((err) => {
log.error(`ERROR: Failed to find available port: ${err}`);
throw err;
Expand Down
275 changes: 275 additions & 0 deletions src/models/DownloadManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import { BrowserWindow, session, DownloadItem, ipcMain } from 'electron';
import path from 'path';
import fs from 'fs';
import { IPC_CHANNELS } from '../constants';
import log from 'electron-log/main';

interface Download {
url: string;
filename: string;
tempPath: string; // Temporary filename until the download is complete.
savePath: string;
item: DownloadItem | null;
}

export enum DownloadStatus {
PENDING = 'pending',
IN_PROGRESS = 'in_progress',
COMPLETED = 'completed',
PAUSED = 'paused',
ERROR = 'error',
CANCELLED = 'cancelled',
}
interface DownloadState {
url: string;
filename: string;
state: DownloadStatus;
receivedBytes: number;
totalBytes: number;
isPaused: boolean;
}

/**
* Singleton class that manages downloading model checkpoints for ComfyUI.
*/
export class DownloadManager {
private static instance: DownloadManager;
private downloads: Map<string, Download>;
private mainWindow: BrowserWindow;
private modelsDirectory: string;
private constructor(mainWindow: BrowserWindow, modelsDirectory: string) {
this.downloads = new Map();
this.mainWindow = mainWindow;
this.modelsDirectory = modelsDirectory;

session.defaultSession.on('will-download', (event, item, webContents) => {
const url = item.getURLChain()[0]; // Get the original URL in case of redirects.
log.info('Will-download event ', url);
const download = this.downloads.get(url);

if (download) {
this.reportProgress(url, 0, DownloadStatus.PENDING);
item.setSavePath(download.tempPath);
download.item = item;
log.info(`Setting save path to ${item.getSavePath()}`);

item.on('updated', (event, state) => {
if (state === 'interrupted') {
log.info('Download is interrupted but can be resumed');
} else if (state === 'progressing') {
const progress = item.getReceivedBytes() / item.getTotalBytes();
if (item.isPaused()) {
log.info('Download is paused');
this.reportProgress(url, progress, DownloadStatus.PAUSED);
} else {
this.reportProgress(url, progress, DownloadStatus.IN_PROGRESS);
}
}
});

item.once('done', (event, state) => {
if (state === 'completed') {
try {
fs.renameSync(download.tempPath, download.savePath);
log.info(`Successfully renamed ${download.tempPath} to ${download.savePath}`);
} catch (error) {
log.error(`Failed to rename downloaded file: ${error}. Deleting temp file.`);
fs.unlinkSync(download.tempPath);
}
this.reportProgress(url, 1, DownloadStatus.COMPLETED);
this.downloads.delete(url);
} else {
log.info(`Download failed: ${state}`);
const progress = item.getReceivedBytes() / item.getTotalBytes();
this.reportProgress(url, progress, DownloadStatus.ERROR);
}
});
}
});
}

startDownload(url: string, savePath: string, filename: string): boolean {
const localSavePath = this.getLocalSavePath(filename, savePath);
if (!this.isPathInModelsDirectory(localSavePath)) {
log.error(`Save path ${localSavePath} is not in models directory ${this.modelsDirectory}`);
this.reportProgress(url, 0, DownloadStatus.ERROR, 'Save path is not in models directory');
return false;
}

const validationResult = this.validateSafetensorsFile(url, filename);
if (!validationResult.isValid) {
log.error(validationResult.error);
this.reportProgress(url, 0, DownloadStatus.ERROR, validationResult.error);
return false;
}

if (fs.existsSync(localSavePath)) {
log.info(`File ${filename} already exists, skipping download`);
return true;
}

const existingDownload = this.downloads.get(url);
if (existingDownload) {
log.info('Download already exists');
if (existingDownload.item && existingDownload.item.isPaused()) {
this.resumeDownload(url);
}
return true;
}

log.info(`Starting download ${url} to ${localSavePath}`);
const tempPath = this.getTempPath(filename, savePath);
this.downloads.set(url, { url, savePath: localSavePath, tempPath, filename, item: null });

// TODO(robinhuang): Add offset support for resuming downloads.
// Can use https://www.electronjs.org/docs/latest/api/session#sescreateinterrupteddownloadoptions
session.defaultSession.downloadURL(url);
return true;
}

cancelDownload(url: string): void {
const download = this.downloads.get(url);
if (download && download.item) {
log.info('Cancelling download');
download.item.cancel();
}
}

pauseDownload(url: string): void {
const download = this.downloads.get(url);
if (download && download.item) {
log.info('Pausing download');
download.item.pause();
}
}

resumeDownload(url: string): void {
const download = this.downloads.get(url);
if (download) {
if (download.item && download.item.canResume()) {
log.info('Resuming download');
download.item.resume();
} else {
this.startDownload(download.url, download.savePath, download.filename);
}
}
}

deleteModel(filename: string, savePath: string): boolean {
const localSavePath = this.getLocalSavePath(filename, savePath);
if (!this.isPathInModelsDirectory(localSavePath)) {
log.error(`Save path ${localSavePath} is not in models directory ${this.modelsDirectory}`);
return false;
}
const tempPath = this.getTempPath(filename, savePath);
try {
if (fs.existsSync(localSavePath)) {
log.info(`Deleting local file ${localSavePath}`);
fs.unlinkSync(localSavePath);
}
} catch (error) {
log.error(`Failed to delete file ${localSavePath}: ${error}`);
}

try {
if (fs.existsSync(tempPath)) {
log.info(`Deleting temp file ${tempPath}`);
fs.unlinkSync(tempPath);
}
} catch (error) {
log.error(`Failed to delete file ${tempPath}: ${error}`);
}
}

getAllDownloads(): DownloadState[] {
return Array.from(this.downloads.values())
.filter((download) => download.item !== null)
.map((download) => ({
url: download.url,
filename: download.filename,
tempPath: download.tempPath,
state: this.convertDownloadState(download.item?.getState()),
receivedBytes: download.item?.getReceivedBytes() || 0,
totalBytes: download.item?.getTotalBytes() || 0,
isPaused: download.item?.isPaused() || false,
}));
}

private convertDownloadState(state: 'progressing' | 'completed' | 'cancelled' | 'interrupted'): DownloadStatus {
switch (state) {
case 'progressing':
return DownloadStatus.IN_PROGRESS;
case 'completed':
return DownloadStatus.COMPLETED;
case 'cancelled':
return DownloadStatus.CANCELLED;
case 'interrupted':
return DownloadStatus.ERROR;
default:
return DownloadStatus.ERROR;
}
}

private getTempPath(filename: string, savePath: string): string {
return path.join(this.modelsDirectory, savePath, `Unconfirmed ${filename}.tmp`);
}

// Only allow .safetensors files to be downloaded.
private validateSafetensorsFile(url: string, filename: string): { isValid: boolean; error?: string } {
try {
const urlObj = new URL(url);
const pathname = urlObj.pathname.toLowerCase();
if (!pathname.endsWith('.safetensors') && !filename.toLowerCase().endsWith('.safetensors')) {
return {
isValid: false,
error: 'Invalid file type: must be a .safetensors file',
};
}
return { isValid: true };
} catch (error) {
return {
isValid: false,
error: `Invalid URL format: ${error}`,
};
}
}

private getLocalSavePath(filename: string, savePath: string): string {
return path.join(this.modelsDirectory, savePath, filename);
}

private isPathInModelsDirectory(filePath: string): boolean {
const absoluteFilePath = path.resolve(filePath);
const absoluteModelsDir = path.resolve(this.modelsDirectory);
return absoluteFilePath.startsWith(absoluteModelsDir);
}

private reportProgress(url: string, progress: number, status: DownloadStatus, message: string = ''): void {
log.info(`Download progress: ${progress}, status: ${status}, message: ${message}`);
this.mainWindow.webContents.send(IPC_CHANNELS.DOWNLOAD_PROGRESS, {
url,
progress,
status,
message,
});
}

public static getInstance(mainWindow: BrowserWindow, modelsDirectory: string): DownloadManager {
if (!DownloadManager.instance) {
DownloadManager.instance = new DownloadManager(mainWindow, modelsDirectory);
DownloadManager.instance.registerIpcHandlers();
}
return DownloadManager.instance;
}

private registerIpcHandlers() {
ipcMain.handle(IPC_CHANNELS.START_DOWNLOAD, (event, { url, path, filename }) =>
this.startDownload(url, path, filename)
);
ipcMain.handle(IPC_CHANNELS.PAUSE_DOWNLOAD, (event, url: string) => this.pauseDownload(url));
ipcMain.handle(IPC_CHANNELS.RESUME_DOWNLOAD, (event, url: string) => this.resumeDownload(url));
ipcMain.handle(IPC_CHANNELS.CANCEL_DOWNLOAD, (event, url: string) => this.cancelDownload(url));
ipcMain.handle(IPC_CHANNELS.GET_ALL_DOWNLOADS, (event) => this.getAllDownloads());
ipcMain.handle(IPC_CHANNELS.DELETE_MODEL, (event, { filename, path }) => this.deleteModel(filename, path));
}
}
50 changes: 49 additions & 1 deletion src/preload.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// preload.ts

import { contextBridge, ipcRenderer } from 'electron';
import { contextBridge, DownloadItem, ipcRenderer } from 'electron';
import { IPC_CHANNELS, ELECTRON_BRIDGE_API } from './constants';
import { DownloadStatus } from './models/DownloadManager';

export interface ElectronAPI {
/**
Expand Down Expand Up @@ -33,6 +34,22 @@ export interface ElectronAPI {
* Open the logs folder in the system's default file explorer.
*/
openLogsFolder: () => void;
DownloadManager: {
onDownloadProgress: (
callback: (progress: {
url: string;
progress_percentage: number;
status: DownloadStatus;
message?: string;
}) => void
) => void;
startDownload: (url: string, path: string, filename: string) => Promise<boolean>;
cancelDownload: (url: string) => Promise<boolean>;
pauseDownload: (url: string) => Promise<boolean>;
resumeDownload: (url: string) => Promise<boolean>;
deleteModel: (filename: string, path: string) => Promise<boolean>;
getAllDownloads: () => Promise<DownloadItem[]>;
};
}

const electronAPI: ElectronAPI = {
Expand Down Expand Up @@ -98,6 +115,37 @@ const electronAPI: ElectronAPI = {
openLogsFolder: () => {
ipcRenderer.send(IPC_CHANNELS.OPEN_LOGS_FOLDER);
},
DownloadManager: {
onDownloadProgress: (
callback: (progress: {
url: string;
progress_percentage: number;
status: DownloadStatus;
message?: string;
}) => void
) => {
ipcRenderer.on(IPC_CHANNELS.DOWNLOAD_PROGRESS, (_event, progress) => callback(progress));
},
startDownload: (url: string, path: string, filename: string): Promise<boolean> => {
console.log(`Sending start download message to main process`, { url, path, filename });
return ipcRenderer.invoke(IPC_CHANNELS.START_DOWNLOAD, { url, path, filename });
},
cancelDownload: (url: string): Promise<boolean> => {
return ipcRenderer.invoke(IPC_CHANNELS.CANCEL_DOWNLOAD, url);
},
pauseDownload: (url: string): Promise<boolean> => {
return ipcRenderer.invoke(IPC_CHANNELS.PAUSE_DOWNLOAD, url);
},
resumeDownload: (url: string): Promise<boolean> => {
return ipcRenderer.invoke(IPC_CHANNELS.RESUME_DOWNLOAD, url);
},
deleteModel: (filename: string, path: string): Promise<boolean> => {
return ipcRenderer.invoke(IPC_CHANNELS.DELETE_DOWNLOAD, { filename, path });
},
getAllDownloads: (): Promise<DownloadItem[]> => {
return ipcRenderer.invoke(IPC_CHANNELS.GET_ALL_DOWNLOADS);
},
},
};

contextBridge.exposeInMainWorld(ELECTRON_BRIDGE_API, electronAPI);
Loading

0 comments on commit 60e24d9

Please sign in to comment.