diff --git a/packages/upscalerjs/src/args.browser.ts b/packages/upscalerjs/src/args.browser.ts index 0231f949a..1f988bfc3 100644 --- a/packages/upscalerjs/src/args.browser.ts +++ b/packages/upscalerjs/src/args.browser.ts @@ -1,4 +1,4 @@ -import { BASE64, UpscaleArgs, TENSOR, PrivateUpscaleArgs, } from "./types"; +import { BASE64, TENSOR, GetUpscaleOptions, } from "./types"; const getOutputOption = (output?: unknown): TENSOR | BASE64 => { if (output === 'tensor') { @@ -7,14 +7,14 @@ const getOutputOption = (output?: unknown): TENSOR | BASE64 => { return 'base64'; }; -export function getUpscaleOptions({ +export const getUpscaleOptions: GetUpscaleOptions = ({ output, progressOutput, ...options -}: Omit & { output?: unknown; progressOutput?: unknown } = {}): PrivateUpscaleArgs { +} = {}) => { return { ...options, output: getOutputOption(output), progressOutput: getOutputOption(progressOutput || output), }; -} +}; diff --git a/packages/upscalerjs/src/args.node.ts b/packages/upscalerjs/src/args.node.ts index 02fb7d83d..75fb95226 100644 --- a/packages/upscalerjs/src/args.node.ts +++ b/packages/upscalerjs/src/args.node.ts @@ -1,4 +1,4 @@ -import { BASE64, UpscaleArgs, TENSOR, PrivateUpscaleArgs, } from "./types"; +import { BASE64, TENSOR, GetUpscaleOptions, } from "./types"; const getOutputOption = (output?: unknown): TENSOR | BASE64 => { if (output === 'base64') { @@ -7,14 +7,14 @@ const getOutputOption = (output?: unknown): TENSOR | BASE64 => { return 'tensor'; }; -export function getUpscaleOptions({ +export const getUpscaleOptions: GetUpscaleOptions = ({ output, progressOutput, ...options -}: Omit & { output?: unknown; progressOutput?: unknown } = {}): PrivateUpscaleArgs { +} = {}) => { return { ...options, output: getOutputOption(output), progressOutput: getOutputOption(progressOutput || output), }; -} +}; diff --git a/packages/upscalerjs/src/image.browser.ts b/packages/upscalerjs/src/image.browser.ts index b9d55487d..f2e152ebe 100644 --- a/packages/upscalerjs/src/image.browser.ts +++ b/packages/upscalerjs/src/image.browser.ts @@ -88,7 +88,7 @@ export const isHTMLImageElement = (pixels: Input): pixels is HTMLImageElement => } }; -export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => { +export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => { const arr = tensorAsClampedArray(tf, tensor); const [height, width, ] = tensor.shape; const imageData = new ImageData(width, height); diff --git a/packages/upscalerjs/src/image.node.ts b/packages/upscalerjs/src/image.node.ts index d77303c92..cc4872cd1 100644 --- a/packages/upscalerjs/src/image.node.ts +++ b/packages/upscalerjs/src/image.node.ts @@ -85,7 +85,7 @@ export const getImageAsTensor: GetImageAsTensor = async ( // skipcq: throw getInvalidTensorError(tensor); }; -export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => { +export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => { const arr = tensorAsClampedArray(tf, tensor); return Buffer.from(arr).toString('base64'); }; diff --git a/packages/upscalerjs/src/loadModel.browser.ts b/packages/upscalerjs/src/loadModel.browser.ts index 191002a25..fc727df32 100644 --- a/packages/upscalerjs/src/loadModel.browser.ts +++ b/packages/upscalerjs/src/loadModel.browser.ts @@ -1,6 +1,6 @@ import type { LayersModel, } from '@tensorflow/tfjs-layers'; -import type { ModelDefinition, ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core'; -import type { ParsedModelDefinition, ModelPackage, } from './types'; +import type { ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core'; +import type { ParsedModelDefinition, LoadModel, } from './types'; import { loadTfModel, parseModelDefinition, @@ -64,10 +64,7 @@ export async function fetchModel, -): Promise => { +export const loadModel: LoadModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; try { diff --git a/packages/upscalerjs/src/loadModel.node.ts b/packages/upscalerjs/src/loadModel.node.ts index d96030a90..b68ee0a48 100644 --- a/packages/upscalerjs/src/loadModel.node.ts +++ b/packages/upscalerjs/src/loadModel.node.ts @@ -1,8 +1,7 @@ import path from 'path'; -import type { ModelDefinition, } from "@upscalerjs/core"; import { loadTfModel, parseModelDefinition, } from './model-utils'; import { resolver, } from './resolver'; -import { ParsedModelDefinition, ModelPackage, } from './types'; +import { ParsedModelDefinition, LoadModel, } from './types'; import { isValidModelDefinition, TF, @@ -42,10 +41,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string return `file://${path.resolve(moduleFolder, _internals.path)}`; }; -export const loadModel = async ( - tf: TF, - _modelDefinition: Promise, -): Promise => { +export const loadModel: LoadModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; try { isValidModelDefinition(modelDefinition); diff --git a/packages/upscalerjs/src/types.ts b/packages/upscalerjs/src/types.ts index e0dc700f6..9e2ffb792 100644 --- a/packages/upscalerjs/src/types.ts +++ b/packages/upscalerjs/src/types.ts @@ -90,7 +90,12 @@ export type CheckValidEnvironment = (input: T, opts: { progressOutput?: ResultFormat; }) => void; export type GetImageAsTensor = (tf: T, input: I) => Promise; -export type TensorAsBase64 = (tf: TF, tensor: Tensor3D) => string; +export type TensorAsBase64 = (tf: T, tensor: Tensor3D) => string; +export type LoadModel = (tf: T, _modelDefinition: Promise) => Promise; +export type GetUpscaleOptions = (args?: Omit & { + output?: unknown; + progressOutput?: unknown +}) => PrivateUpscaleArgs; export type Coordinate = [number, number]; diff --git a/packages/upscalerjs/src/upscale.ts b/packages/upscalerjs/src/upscale.ts index e63fb9f60..c0c768de4 100644 --- a/packages/upscalerjs/src/upscale.ts +++ b/packages/upscalerjs/src/upscale.ts @@ -300,7 +300,7 @@ export async function* upscale( interface InternalConfig { checkValidEnvironment: CheckValidEnvironment; getImageAsTensor: GetImageAsTensor, - tensorAsBase64: TensorAsBase64, + tensorAsBase64: TensorAsBase64; } export function cancellableUpscale( diff --git a/packages/upscalerjs/src/upscaler.ts b/packages/upscalerjs/src/upscaler.ts index 5a89dbb9a..ef441f7bc 100644 --- a/packages/upscalerjs/src/upscaler.ts +++ b/packages/upscalerjs/src/upscaler.ts @@ -91,10 +91,10 @@ export class Upscaler { this._opts = { ...opts, }; - this._model = loadModel(tf, getModel(tf, this._opts.model || DEFAULT_MODEL)); + this._model = loadModel(this.tf, getModel(this.tf, this._opts.model || DEFAULT_MODEL)); this.ready = new Promise((resolve, reject) => { this._model.then(() => cancellableWarmup( - tf, + this.tf, this._model, (this._opts.warmupSizes || []), undefined, @@ -161,7 +161,7 @@ export class Upscaler { ) { await this.ready; const modelPackage = await this._model; - return cancellableUpscale(tf, image, getUpscaleOptions(options), { + return cancellableUpscale(this.tf, image, getUpscaleOptions(options), { ...modelPackage, signal: this._abortController.signal, }, { @@ -195,7 +195,7 @@ export class Upscaler { warmup = async (warmupSizes: WarmupSizes = [], options?: WarmupArgs): Promise => { await this.ready; return cancellableWarmup( - tf, + this.tf, this._model, warmupSizes, options, {