diff --git a/tfjs-node/src/kernels/TopK.ts b/tfjs-node/src/kernels/TopK.ts index 9b581ffbd40..9f77169e1f0 100644 --- a/tfjs-node/src/kernels/TopK.ts +++ b/tfjs-node/src/kernels/TopK.ts @@ -16,7 +16,6 @@ */ import {KernelConfig, scalar, TopK, TopKAttrs, TopKInputs} from '@tensorflow/tfjs'; -import {isNullOrUndefined} from 'util'; import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; @@ -28,8 +27,8 @@ export const topKConfig: KernelConfig = { const backend = args.backend as NodeJSKernelBackend; const {k, sorted} = args.attrs as unknown as TopKAttrs; - const kCount = isNullOrUndefined(k) ? 1 : k; - const isSorted = isNullOrUndefined(sorted) ? true : sorted; + const kCount = k ?? 1; + const isSorted = sorted ?? true; const opAttrs = [ {name: 'sorted', type: backend.binding.TF_ATTR_BOOL, value: isSorted}, createTensorsTypeOpAttr('T', x.dtype), diff --git a/tfjs-node/src/nodejs_kernel_backend.ts b/tfjs-node/src/nodejs_kernel_backend.ts index a399d31bfa0..a192f313854 100644 --- a/tfjs-node/src/nodejs_kernel_backend.ts +++ b/tfjs-node/src/nodejs_kernel_backend.ts @@ -17,7 +17,6 @@ import * as tf from '@tensorflow/tfjs'; import {backend_util, BackendTimingInfo, DataId, DataType, KernelBackend, ModelTensorInfo, Rank, Scalar, scalar, ScalarLike, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorInfo, tidy, util} from '@tensorflow/tfjs'; -import {isArray, isNullOrUndefined} from 'util'; import {encodeInt32ArrayAsInt64, Int64Scalar} from './int64_tensors'; import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding'; @@ -740,7 +739,7 @@ export function getTFDType(dataType: tf.DataType): number { export function createTensorsTypeOpAttr( attrName: string, tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType): TFEOpAttr { - if (isNullOrUndefined(tensorsOrDtype)) { + if (tensorsOrDtype === null || tensorsOrDtype === undefined) { throw new Error('Invalid input tensors value.'); } return { @@ -757,7 +756,7 @@ export function createTensorsTypeOpAttr( export function createOpAttr( attrName: string, tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType, value: ScalarLike): TFEOpAttr { - if (isNullOrUndefined(tensorsOrDtype)) { + if (tensorsOrDtype === null || tensorsOrDtype === undefined) { throw new Error('Invalid input tensors value.'); } return {name: attrName, type: nodeBackend().binding.TF_BOOL, value}; @@ -765,10 +764,10 @@ export function createOpAttr( /** Returns the dtype number for a single or list of input Tensors. */ function getTFDTypeForInputs(tensors: tf.Tensor|tf.Tensor[]): number { - if (isNullOrUndefined(tensors)) { + if (tensors === null || tensors === undefined) { throw new Error('Invalid input tensors value.'); } - if (isArray(tensors)) { + if (Array.isArray(tensors)) { for (let i = 0; i < tensors.length; i++) { return getTFDType(tensors[i].dtype); }