Skip to content

Commit

Permalink
Enforce generic types with function definitions (#1172)
Browse files Browse the repository at this point in the history
* Enforce generic types with function definitions
  • Loading branch information
thekevinscott authored Oct 2, 2023
1 parent 6eb0436 commit 228f5d8
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 28 deletions.
8 changes: 4 additions & 4 deletions packages/upscalerjs/src/args.browser.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BASE64, UpscaleArgs, TENSOR, PrivateUpscaleArgs, } from "./types";
import { BASE64, TENSOR, GetUpscaleOptions, } from "./types";

const getOutputOption = (output?: unknown): TENSOR | BASE64 => {
if (output === 'tensor') {
Expand All @@ -7,14 +7,14 @@ const getOutputOption = (output?: unknown): TENSOR | BASE64 => {
return 'base64';
};

export function getUpscaleOptions({
export const getUpscaleOptions: GetUpscaleOptions = ({
output,
progressOutput,
...options
}: Omit<UpscaleArgs, 'output' | 'progressOutput'> & { output?: unknown; progressOutput?: unknown } = {}): PrivateUpscaleArgs {
} = {}) => {
return {
...options,
output: getOutputOption(output),
progressOutput: getOutputOption(progressOutput || output),
};
}
};
8 changes: 4 additions & 4 deletions packages/upscalerjs/src/args.node.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BASE64, UpscaleArgs, TENSOR, PrivateUpscaleArgs, } from "./types";
import { BASE64, TENSOR, GetUpscaleOptions, } from "./types";

const getOutputOption = (output?: unknown): TENSOR | BASE64 => {
if (output === 'base64') {
Expand All @@ -7,14 +7,14 @@ const getOutputOption = (output?: unknown): TENSOR | BASE64 => {
return 'tensor';
};

export function getUpscaleOptions({
export const getUpscaleOptions: GetUpscaleOptions = ({
output,
progressOutput,
...options
}: Omit<UpscaleArgs, 'output' | 'progressOutput'> & { output?: unknown; progressOutput?: unknown } = {}): PrivateUpscaleArgs {
} = {}) => {
return {
...options,
output: getOutputOption(output),
progressOutput: getOutputOption(progressOutput || output),
};
}
};
2 changes: 1 addition & 1 deletion packages/upscalerjs/src/image.browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export const isHTMLImageElement = (pixels: Input): pixels is HTMLImageElement =>
}
};

export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => {
export const tensorAsBase64: TensorAsBase64<typeof tf> = (tf, tensor) => {
const arr = tensorAsClampedArray(tf, tensor);
const [height, width, ] = tensor.shape;
const imageData = new ImageData(width, height);
Expand Down
2 changes: 1 addition & 1 deletion packages/upscalerjs/src/image.node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export const getImageAsTensor: GetImageAsTensor<TFN, Input> = async ( // skipcq:
throw getInvalidTensorError(tensor);
};

export const tensorAsBase64: TensorAsBase64 = (tf, tensor) => {
export const tensorAsBase64: TensorAsBase64<TFN> = (tf, tensor) => {
const arr = tensorAsClampedArray(tf, tensor);
return Buffer.from(arr).toString('base64');
};
Expand Down
9 changes: 3 additions & 6 deletions packages/upscalerjs/src/loadModel.browser.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { LayersModel, } from '@tensorflow/tfjs-layers';
import type { ModelDefinition, ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core';
import type { ParsedModelDefinition, ModelPackage, } from './types';
import type { ModelType, ModelConfigurationInternals, GraphModel, } from '@upscalerjs/core';
import type { ParsedModelDefinition, LoadModel, } from './types';
import {
loadTfModel,
parseModelDefinition,
Expand Down Expand Up @@ -64,10 +64,7 @@ export async function fetchModel<M extends ModelType, R = M extends 'graph' ? Gr
throw getLoadModelErrorMessage(errs, modelPath || _internals.path, _internals);
}

export const loadModel = async (
tf: TF,
_modelDefinition: Promise<ModelDefinition>,
): Promise<ModelPackage> => {
export const loadModel: LoadModel<TF> = async (tf, _modelDefinition) => {
const modelDefinition = await _modelDefinition;

try {
Expand Down
8 changes: 2 additions & 6 deletions packages/upscalerjs/src/loadModel.node.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import path from 'path';
import type { ModelDefinition, } from "@upscalerjs/core";
import { loadTfModel, parseModelDefinition, } from './model-utils';
import { resolver, } from './resolver';
import { ParsedModelDefinition, ModelPackage, } from './types';
import { ParsedModelDefinition, LoadModel, } from './types';
import {
isValidModelDefinition,
TF,
Expand Down Expand Up @@ -42,10 +41,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string
return `file://${path.resolve(moduleFolder, _internals.path)}`;
};

export const loadModel = async (
tf: TF,
_modelDefinition: Promise<ModelDefinition>,
): Promise<ModelPackage> => {
export const loadModel: LoadModel<TF> = async (tf, _modelDefinition) => {
const modelDefinition = await _modelDefinition;
try {
isValidModelDefinition(modelDefinition);
Expand Down
7 changes: 6 additions & 1 deletion packages/upscalerjs/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ export type CheckValidEnvironment<T> = (input: T, opts: {
progressOutput?: ResultFormat;
}) => void;
export type GetImageAsTensor<T extends TF, I> = (tf: T, input: I) => Promise<Tensor4D>;
export type TensorAsBase64 = (tf: TF, tensor: Tensor3D) => string;
export type TensorAsBase64<T extends TF> = (tf: T, tensor: Tensor3D) => string;
export type LoadModel<T extends TF> = (tf: T, _modelDefinition: Promise<ModelDefinition>) => Promise<ModelPackage>;
export type GetUpscaleOptions = (args?: Omit<UpscaleArgs, 'output' | 'progressOutput'> & {
output?: unknown;
progressOutput?: unknown
}) => PrivateUpscaleArgs;

export type Coordinate = [number, number];

Expand Down
2 changes: 1 addition & 1 deletion packages/upscalerjs/src/upscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ export async function* upscale<T extends TF, I>(
interface InternalConfig<T extends TF, I> {
checkValidEnvironment: CheckValidEnvironment<I>;
getImageAsTensor: GetImageAsTensor<T, I>,
tensorAsBase64: TensorAsBase64,
tensorAsBase64: TensorAsBase64<T>;
}

export function cancellableUpscale<T extends TF, I>(
Expand Down
8 changes: 4 additions & 4 deletions packages/upscalerjs/src/upscaler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ export class Upscaler {
this._opts = {
...opts,
};
this._model = loadModel(tf, getModel(tf, this._opts.model || DEFAULT_MODEL));
this._model = loadModel(this.tf, getModel(this.tf, this._opts.model || DEFAULT_MODEL));
this.ready = new Promise((resolve, reject) => {
this._model.then(() => cancellableWarmup(
tf,
this.tf,
this._model,
(this._opts.warmupSizes || []),
undefined,
Expand Down Expand Up @@ -161,7 +161,7 @@ export class Upscaler {
) {
await this.ready;
const modelPackage = await this._model;
return cancellableUpscale(tf, image, getUpscaleOptions(options), {
return cancellableUpscale(this.tf, image, getUpscaleOptions(options), {
...modelPackage,
signal: this._abortController.signal,
}, {
Expand Down Expand Up @@ -195,7 +195,7 @@ export class Upscaler {
warmup = async (warmupSizes: WarmupSizes = [], options?: WarmupArgs): Promise<void> => {
await this.ready;
return cancellableWarmup(
tf,
this.tf,
this._model,
warmupSizes,
options, {
Expand Down

0 comments on commit 228f5d8

Please sign in to comment.