From 87176227224e5c0f9bfcb484b714b6ff02645364 Mon Sep 17 00:00:00 2001 From: Ard Oerlemans Date: Fri, 2 Apr 2021 19:35:49 -0700 Subject: [PATCH] Adding MoveNet to Pose Detection API (#627) * Adding MoveNet skeleton and support files * Adding MoveNet implementation and supporting files * Lint MoveNet files (except for line length) Line lengths will be updated in an upcoming commit. * Use correct formatting, which includes line lengths * Remove test URL * Addressing code review comments * Fix MoveNet visualization in live demo * Updates after review comments MoveNet determineCropRegion refactored and moved OneEuroFilter to common filters directory * Remove configurable keypoint threshold This threshold was mostly used for internal cropping logic. * Merge Model and KeypointModel, plus additional updates after review * Move one euro filter back to MoveNet with TODO * Add test for MoveNet * Use MoveNet as default model in pose demo * Add MoveNet model type selection to pose demo * Run models at highest possible speed in pose demo * Fix error in cropping code * Remove comment * Updates to resolve review comments * Fix CI test * Simplify pose array creation --- pose-detection/demo/src/camera.js | 4 - pose-detection/demo/src/index.js | 37 +- pose-detection/demo/src/option_panel.js | 31 +- pose-detection/demo/src/params.js | 6 +- pose-detection/src/constants.ts | 19 + pose-detection/src/create_detector.ts | 8 +- pose-detection/src/index.ts | 12 + pose-detection/src/movenet/constants.ts | 42 ++ pose-detection/src/movenet/detector.ts | 414 ++++++++++++++++++ pose-detection/src/movenet/detector_utils.ts | 51 +++ pose-detection/src/movenet/index.ts | 18 + pose-detection/src/movenet/movenet_test.ts | 44 ++ .../src/movenet/robust_one_euro_filter.ts | 123 ++++++ pose-detection/src/movenet/types.ts | 41 ++ pose-detection/src/types.ts | 5 +- pose-detection/src/util.ts | 2 + 16 files changed, 826 insertions(+), 31 deletions(-) create mode 100644 pose-detection/src/movenet/constants.ts create mode 100644 pose-detection/src/movenet/detector.ts create mode 100644 pose-detection/src/movenet/detector_utils.ts create mode 100644 pose-detection/src/movenet/index.ts create mode 100644 pose-detection/src/movenet/movenet_test.ts create mode 100644 pose-detection/src/movenet/robust_one_euro_filter.ts create mode 100644 pose-detection/src/movenet/types.ts diff --git a/pose-detection/demo/src/camera.js b/pose-detection/demo/src/camera.js index 47805f8beb..5fbc16fdc0 100644 --- a/pose-detection/demo/src/camera.js +++ b/pose-detection/demo/src/camera.js @@ -24,10 +24,6 @@ export class Camera { this.video = document.getElementById('video'); this.canvas = document.getElementById('output'); this.ctx = this.canvas.getContext('2d'); - - // The video frame rate may be lower than the browser animate frame - // rate. We use this to avoid processing the same frame twice. - this.lastVideoTime = 0; } /** diff --git a/pose-detection/demo/src/index.js b/pose-detection/demo/src/index.js index ef81766912..d60269724d 100644 --- a/pose-detection/demo/src/index.js +++ b/pose-detection/demo/src/index.js @@ -40,6 +40,13 @@ async function createDetector(model) { case posedetection.SupportedModels.MediapipeBlazepose: return posedetection.createDetector( STATE.model.model, {quantBytes: 4, upperBodyOnly: false}); + case posedetection.SupportedModels.MoveNet: + const modelType = + STATE.model[STATE.model.model].modelType == 'Lightning' ? 
+ posedetection.movenet.modelType.SINGLEPOSE_LIGHTNING : + posedetection.movenet.modelType.SINGLEPOSE_THUNDER; + return posedetection.createDetector( + STATE.model.model, {modelType: modelType}); } } @@ -58,28 +65,24 @@ async function checkGuiUpdate() { } async function renderResult() { - if (camera.video.currentTime !== camera.lastVideoTime) { - camera.lastVideoTime = camera.video.currentTime; - - // FPS only counts the time it takes to finish estimatePoses. - stats.begin(); + // FPS only counts the time it takes to finish estimatePoses. + stats.begin(); - const poses = await detector.estimatePoses( - camera.video, {maxPoses: 1, flipHorizontal: false}); + const poses = await detector.estimatePoses( + camera.video, {maxPoses: 1, flipHorizontal: false}); - stats.end(); + stats.end(); - camera.drawCtx(); + camera.drawCtx(); - // The null check makes sure the UI is not in the middle of changing to a - // different model. If changeToModel is non-null, the result is from an - // old model, which shouldn't be rendered. - if (poses.length > 0 && STATE.changeToModel == null) { - const shouldScale = STATE.model.model === - posedetection.SupportedModels.MediapipeBlazepose; + // The null check makes sure the UI is not in the middle of changing to a + // different model. If changeToModel is non-null, the result is from an + // old model, which shouldn't be rendered. + if (poses.length > 0 && STATE.changeToModel == null) { + const shouldScale = + STATE.model.model === posedetection.SupportedModels.MediapipeBlazepose; - camera.drawResult(poses[0], shouldScale); - } + camera.drawResult(poses[0], shouldScale); } } diff --git a/pose-detection/demo/src/option_panel.js b/pose-detection/demo/src/option_panel.js index 7f1310521f..b81139dada 100644 --- a/pose-detection/demo/src/option_panel.js +++ b/pose-detection/demo/src/option_panel.js @@ -44,10 +44,17 @@ export function setupDatGui() { case posedetection.SupportedModels.PoseNet: poseNetFolder.open(); blazePoseFolder.close(); + moveNetFolder.close(); break; case posedetection.SupportedModels.MediapipeBlazepose: blazePoseFolder.open(); poseNetFolder.close(); + moveNetFolder.close(); + break; + case posedetection.SupportedModels.MoveNet: + blazePoseFolder.close(); + poseNetFolder.close(); + moveNetFolder.open(); break; default: throw new Error(`${model} is not supported.`); @@ -55,13 +62,18 @@ export function setupDatGui() { }); modelFolder.open(); - // The PoseNet model config folder contains options for PoseNet config + // The MoveNet model config folder contains options for MoveNet config // settings. - const poseNetFolder = gui.addFolder('PoseNet Config'); - poseNetFolder.add( - STATE.model[posedetection.SupportedModels.PoseNet], 'scoreThreshold', 0, + const moveNetFolder = gui.addFolder('MoveNet Config'); + const moveNetTypeController = moveNetFolder.add( + STATE.model[posedetection.SupportedModels.MoveNet], 'modelType', + ['Thunder', 'Lightning']); + moveNetTypeController.onChange(type => { + STATE.changeToModel = type; + }); + moveNetFolder.add( + STATE.model[posedetection.SupportedModels.MoveNet], 'scoreThreshold', 0, 1); - poseNetFolder.open(); // The Blazepose model config folder contains options for Blazepose config // settings. @@ -70,5 +82,14 @@ export function setupDatGui() { STATE.model[posedetection.SupportedModels.MediapipeBlazepose], 'scoreThreshold', 0, 1); + // The PoseNet model config folder contains options for PoseNet config + // settings. 
+ const poseNetFolder = gui.addFolder('PoseNet Config'); + poseNetFolder.add( + STATE.model[posedetection.SupportedModels.PoseNet], 'scoreThreshold', 0, + 1); + + moveNetFolder.open(); + return gui; } diff --git a/pose-detection/demo/src/params.js b/pose-detection/demo/src/params.js index dc0f1f40db..c14cb09b16 100644 --- a/pose-detection/demo/src/params.js +++ b/pose-detection/demo/src/params.js @@ -26,7 +26,7 @@ export const VIDEO_SIZE = { export const STATE = { camera: {targetFPS: 60, sizeOption: '640 X 480'}, model: { - model: posedetection.SupportedModels.PoseNet, + model: posedetection.SupportedModels.MoveNet, } }; STATE.model[posedetection.SupportedModels.MediapipeBlazepose] = { @@ -35,3 +35,7 @@ STATE.model[posedetection.SupportedModels.MediapipeBlazepose] = { STATE.model[posedetection.SupportedModels.PoseNet] = { scoreThreshold: 0.5 }; +STATE.model[posedetection.SupportedModels.MoveNet] = { + modelType: 'Thunder', + scoreThreshold: 0.3 +}; diff --git a/pose-detection/src/constants.ts b/pose-detection/src/constants.ts index c406a4c783..cd3368226e 100644 --- a/pose-detection/src/constants.ts +++ b/pose-detection/src/constants.ts @@ -53,3 +53,22 @@ export const BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS = [ [18, 20], [23, 25], [23, 24], [24, 26], [25, 27], [26, 28], [27, 29], [28, 30], [27, 31], [28, 32], [29, 31], [30, 32] ]; +export const COCO_KEYPOINTS_NAMED_MAP: {[index: string]: number} = { + nose: 0, + left_eye: 1, + right_eye: 2, + left_ear: 3, + right_ear: 4, + left_shoulder: 5, + right_shoulder: 6, + left_elbow: 7, + right_elbow: 8, + left_wrist: 9, + right_wrist: 10, + left_hip: 11, + right_hip: 12, + left_knee: 13, + right_knee: 14, + left_ankle: 15, + right_ankle: 16 +}; diff --git a/pose-detection/src/create_detector.ts b/pose-detection/src/create_detector.ts index 3832efcb0a..0e42b45334 100644 --- a/pose-detection/src/create_detector.ts +++ b/pose-detection/src/create_detector.ts @@ -17,6 +17,8 @@ import {BlazeposeDetector} from './blazepose/detector'; import {BlazeposeModelConfig} from './blazepose/types'; +import {MoveNetDetector} from './movenet/detector'; +import {MoveNetModelConfig} from './movenet/types'; import {PoseDetector} from './pose_detector'; import {PosenetDetector} from './posenet/detector'; import {PosenetModelConfig} from './posenet/types'; @@ -29,13 +31,15 @@ import {SupportedModels} from './types'; */ export async function createDetector( model: SupportedModels, - modelConfig: PosenetModelConfig| - BlazeposeModelConfig): Promise { + modelConfig: PosenetModelConfig|BlazeposeModelConfig| + MoveNetModelConfig): Promise { switch (model) { case SupportedModels.PoseNet: return PosenetDetector.load(modelConfig as PosenetModelConfig); case SupportedModels.MediapipeBlazepose: return BlazeposeDetector.load(modelConfig as BlazeposeModelConfig); + case SupportedModels.MoveNet: + return MoveNetDetector.load(modelConfig as MoveNetModelConfig); default: throw new Error(`${model} is not a supported model name.`); } diff --git a/pose-detection/src/index.ts b/pose-detection/src/index.ts index 44f820a821..17d456a185 100644 --- a/pose-detection/src/index.ts +++ b/pose-detection/src/index.ts @@ -18,9 +18,11 @@ // Entry point to create a new detector instance. export {BlazeposeEstimationConfig, BlazeposeModelConfig} from './blazepose/types'; export {createDetector} from './create_detector'; +export {MoveNetEstimationConfig, MoveNetModelConfig} from './movenet/types'; // PoseDetector class. 
export {PoseDetector} from './pose_detector'; export {PoseNetEstimationConfig, PosenetModelConfig} from './posenet/types'; + // Supported models enum. export * from './types'; @@ -28,3 +30,13 @@ export * from './types'; // Utils for rendering. import * as util from './util'; export {util}; + +// MoveNet model types. +import {SINGLEPOSE_LIGHTNING, SINGLEPOSE_THUNDER} from './movenet/constants'; +const movenet = { + modelType: { + 'SINGLEPOSE_LIGHTNING': SINGLEPOSE_LIGHTNING, + 'SINGLEPOSE_THUNDER': SINGLEPOSE_THUNDER + } +}; +export {movenet}; diff --git a/pose-detection/src/movenet/constants.ts b/pose-detection/src/movenet/constants.ts new file mode 100644 index 0000000000..9d61971b8a --- /dev/null +++ b/pose-detection/src/movenet/constants.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {MoveNetEstimationConfig, MoveNetModelConfig} from './types'; + +export const SINGLEPOSE_LIGHTNING = 'SinglePose.Lightning'; +export const SINGLEPOSE_THUNDER = 'SinglePose.Thunder'; + +export const VALID_MODELS = [SINGLEPOSE_LIGHTNING, SINGLEPOSE_THUNDER]; + +export const MOVENET_SINGLEPOSE_LIGHTNING_URL = + 'https://tfhub.dev/google/tfjs-model/movenet/singlepose/lightning/1'; +export const MOVENET_SINGLEPOSE_THUNDER_URL = + 'https://tfhub.dev/google/tfjs-model/movenet/singlepose/thunder/1'; + +export const MOVENET_SINGLEPOSE_LIGHTNING_RESOLUTION = 192; +export const MOVENET_SINGLEPOSE_THUNDER_RESOLUTION = 256; + +// The default configuration for loading MoveNet. +export const MOVENET_CONFIG: MoveNetModelConfig = { + modelType: SINGLEPOSE_LIGHTNING +}; + +export const MOVENET_SINGLE_POSE_ESTIMATION_CONFIG: MoveNetEstimationConfig = { + maxPoses: 1 +}; + +export const MIN_CROP_KEYPOINT_SCORE = 0.3; diff --git a/pose-detection/src/movenet/detector.ts b/pose-detection/src/movenet/detector.ts new file mode 100644 index 0000000000..bf81db2941 --- /dev/null +++ b/pose-detection/src/movenet/detector.ts @@ -0,0 +1,414 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tfc from '@tensorflow/tfjs-converter'; +import * as tf from '@tensorflow/tfjs-core'; + +import {getImageSize, toImageTensor} from '../calculators/image_utils'; +import {COCO_KEYPOINTS_NAMED_MAP} from '../constants'; +import {BasePoseDetector, PoseDetector} from '../pose_detector'; +import {InputResolution, Keypoint, Pose, PoseDetectorInput} from '../types'; + +import {MIN_CROP_KEYPOINT_SCORE, MOVENET_CONFIG, MOVENET_SINGLE_POSE_ESTIMATION_CONFIG, MOVENET_SINGLEPOSE_LIGHTNING_RESOLUTION, MOVENET_SINGLEPOSE_LIGHTNING_URL, MOVENET_SINGLEPOSE_THUNDER_RESOLUTION, MOVENET_SINGLEPOSE_THUNDER_URL, SINGLEPOSE_LIGHTNING, SINGLEPOSE_THUNDER} from './constants'; +import {validateEstimationConfig, validateModelConfig} from './detector_utils'; +import {RobustOneEuroFilter} from './robust_one_euro_filter'; +import {MoveNetEstimationConfig, MoveNetModelConfig} from './types'; + +/** + * MoveNet detector class. + */ +export class MoveNetDetector extends BasePoseDetector { + private modelInputResolution: InputResolution = {height: 0, width: 0}; + private cropRegion: number[]; + private filter: RobustOneEuroFilter; + // This will be used to calculate the actual camera fps. Starts with 30 fps + // as an assumption. + private previousFrameTime = 0; + private frameTimeDiff = 0.0333; + + // Should not be called outside. + private constructor( + private readonly moveNetModel: tfc.GraphModel, + config: MoveNetModelConfig) { + super(); + + if (config.modelType === SINGLEPOSE_LIGHTNING) { + this.modelInputResolution.width = MOVENET_SINGLEPOSE_LIGHTNING_RESOLUTION; + this.modelInputResolution.height = + MOVENET_SINGLEPOSE_LIGHTNING_RESOLUTION; + } else if (config.modelType === SINGLEPOSE_THUNDER) { + this.modelInputResolution.width = MOVENET_SINGLEPOSE_THUNDER_RESOLUTION; + this.modelInputResolution.height = MOVENET_SINGLEPOSE_THUNDER_RESOLUTION; + } + + this.filter = new RobustOneEuroFilter(); + } + + /** + * Loads the MoveNet model instance from a checkpoint. The model to be loaded + * is configurable using the config dictionary `ModelConfig`. Please find more + * details in the documentation of the `ModelConfig`. + * + * @param config `ModelConfig` dictionary that contains parameters for + * the MoveNet loading process. Please find more details of each parameter + * in the documentation of the `ModelConfig` interface. + */ + static async load(modelConfig: MoveNetModelConfig = MOVENET_CONFIG): + Promise { + const config = validateModelConfig(modelConfig); + let model: tfc.GraphModel; + if (config.modelUrl) { + model = await tfc.loadGraphModel(config.modelUrl); + } else { + let modelUrl; + if (config.modelType === SINGLEPOSE_LIGHTNING) { + modelUrl = MOVENET_SINGLEPOSE_LIGHTNING_URL; + } else if (config.modelType === SINGLEPOSE_THUNDER) { + modelUrl = MOVENET_SINGLEPOSE_THUNDER_URL; + } + model = await tfc.loadGraphModel(modelUrl, {fromTFHub: true}); + } + return new MoveNetDetector(model, config); + } + + /** + * Runs inference on an image using a model that is assumed to be a person + * keypoint model that outputs 17 keypoints. + * @param inputImage 4D tensor containing the input image. Should be of size + * [1, modelHeight, modelWidth, 3]. + * @param executeSync Whether to execute the model synchronously. 
+ * @return An InferenceResult with keypoints and scores, or null if the + * inference call could not be executed (for example when the model was + * not initialized yet) or if it produced an unexpected tensor size. + */ + async detectKeypoints(inputImage: tf.Tensor4D, executeSync = true): + Promise { + if (!this.moveNetModel) { + return null; + } + + const numKeypoints = 17; + + let outputTensor; + if (executeSync) { + outputTensor = this.moveNetModel.execute(inputImage) as tf.Tensor; + } else { + outputTensor = + await this.moveNetModel.executeAsync(inputImage) as tf.Tensor; + } + + // We expect an output array of shape [1, 1, 17, 3] (batch, person, + // keypoint, coordinate + score). + if (!outputTensor || outputTensor.shape.length !== 4 || + outputTensor.shape[0] !== 1 || outputTensor.shape[1] !== 1 || + outputTensor.shape[2] !== numKeypoints || outputTensor.shape[3] !== 3) { + outputTensor.dispose(); + return null; + } + + const inferenceResult = outputTensor.dataSync(); + outputTensor.dispose(); + + const keypoints: Keypoint[] = []; + + for (let i = 0; i < numKeypoints; ++i) { + keypoints[i] = { + y: inferenceResult[i * 3], + x: inferenceResult[i * 3 + 1], + score: inferenceResult[i * 3 + 2] + }; + } + + return keypoints; + } + + /** + * Estimates poses for an image or video frame. + * + * This does standard ImageNet pre-processing before inferring through the + * model. The image should pixels should have values [0-255]. It returns a + * single pose. + * + * @param image + * ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement The input + * image to feed through the network. + * + * @param config + * + * @return An array of `Pose`s. + */ + async estimatePoses( + image: PoseDetectorInput, + estimationConfig: + MoveNetEstimationConfig = MOVENET_SINGLE_POSE_ESTIMATION_CONFIG): + Promise { + // We only validate that maxPoses is 1. + validateEstimationConfig(estimationConfig); + + if (image == null) { + return []; + } + + const now = tf.util.now(); + if (this.previousFrameTime !== 0) { + const newSampleWeight = 0.02; + this.frameTimeDiff = (1.0 - newSampleWeight) * this.frameTimeDiff + + newSampleWeight * (now - this.previousFrameTime); + } + this.previousFrameTime = now; + + const imageTensor3D = toImageTensor(image); + const imageSize = getImageSize(imageTensor3D); + const imageTensor4D: tf.Tensor4D = tf.expandDims(imageTensor3D, 0); + + // Make sure we don't dispose the input image if it's already a tensor. + if (!(image instanceof tf.Tensor)) { + imageTensor3D.dispose(); + } + + let keypoints: Keypoint[] = null; + + // If we have a cropRegion from a previous run, try to run the model on an + // image crop first. + if (this.cropRegion != null) { + const croppedImage = tf.tidy(() => { + // Crop region is a [batch, 4] size tensor. + const cropRegionTensor = tf.tensor2d([this.cropRegion]); + // The batch index that the crop should operate on. A [batch] size + // tensor. + const boxInd: tf.Tensor1D = tf.zeros([1], 'int32'); + // Target size of each crop. + const cropSize: [number, number] = + [this.modelInputResolution.height, this.modelInputResolution.width]; + return tf.cast( + tf.image.cropAndResize( + imageTensor4D, cropRegionTensor, boxInd, cropSize, 'bilinear', + 0), + 'int32'); + }); + + // Run cropModel. Model will dispose croppedImage. + keypoints = await this.detectKeypoints(croppedImage); + croppedImage.dispose(); + + // Convert keypoints to image coordinates. cropRegion is stored as + // top-left and bottom-right coordinates: [y1, x1, y2, x2]. 
+ const cropHeight = this.cropRegion[2] - this.cropRegion[0]; + const cropWidth = this.cropRegion[3] - this.cropRegion[1]; + for (let i = 0; i < keypoints.length; ++i) { + keypoints[i].y = this.cropRegion[0] + keypoints[i].y * cropHeight; + keypoints[i].x = this.cropRegion[1] + keypoints[i].x * cropWidth; + } + + // Apply the sequential filter before estimating the cropping area + // to make it more stable. + this.arrayToKeypoints( + this.filter.insert( + this.keypointsToArray(keypoints), 1.0 / this.frameTimeDiff), + keypoints); + + // Determine next crop region based on detected keypoints and if a crop + // region is not detected, this will trigger the model to run on the full + // image. + let newCropRegion = this.determineCropRegion( + keypoints, imageTensor4D.shape[1], imageTensor4D.shape[2]); + + // Use exponential filter on the cropping region to make it less jittery. + if (newCropRegion != null) { + // TODO(ardoerlemans): Use existing low pass filter from shared + // calculators. + const oldCropRegionWeight = 0.1; + newCropRegion = newCropRegion.map(x => x * (1 - oldCropRegionWeight)); + this.cropRegion = this.cropRegion.map(x => x * oldCropRegionWeight); + this.cropRegion = this.cropRegion.map((e, i) => e + newCropRegion[i]); + } else { + this.cropRegion = null; + } + } else { + // No cropRegion was available from a previous run, so run the model on + // the full image. + const resizedImage: tf.Tensor = tf.image.resizeBilinear( + imageTensor4D, + [this.modelInputResolution.height, this.modelInputResolution.width]); + const resizedImageInt = tf.cast(resizedImage, 'int32') as tf.Tensor4D; + resizedImage.dispose(); + + // Model will dispose resizedImageInt. + keypoints = await this.detectKeypoints(resizedImageInt, true); + resizedImageInt.dispose(); + + this.arrayToKeypoints( + this.filter.insert( + this.keypointsToArray(keypoints), 1.0 / this.frameTimeDiff), + keypoints); + + // Determine crop region based on detected keypoints. + this.cropRegion = this.determineCropRegion( + keypoints, imageSize.height, imageSize.width); + } + + imageTensor4D.dispose(); + + // Convert keypoint coordinates from normalized coordinates to image space. + for (let i = 0; i < keypoints.length; ++i) { + keypoints[i].y *= imageSize.height; + keypoints[i].x *= imageSize.width; + } + + const poses: Pose[] = []; + poses[0] = {keypoints}; + + return poses; + } + + torsoVisible(keypoints: Keypoint[]): boolean { + return ( + keypoints[COCO_KEYPOINTS_NAMED_MAP['left_hip']].score > + MIN_CROP_KEYPOINT_SCORE && + keypoints[COCO_KEYPOINTS_NAMED_MAP['right_hip']].score > + MIN_CROP_KEYPOINT_SCORE && + keypoints[COCO_KEYPOINTS_NAMED_MAP['left_shoulder']].score > + MIN_CROP_KEYPOINT_SCORE && + keypoints[COCO_KEYPOINTS_NAMED_MAP['right_shoulder']].score > + MIN_CROP_KEYPOINT_SCORE); + } + + /** + * Calculates the maximum distance from each keypoints to the center location. + * The function returns the maximum distances from the two sets of keypoints: + * full 17 keypoints and 4 torso keypoints. The returned information will be + * used to determine the crop size. See determineCropRegion for more detail. + * + * @param targetKeypoints Maps from joint names to coordinates. 
+ */ + determineTorsoAndBodyRange( + keypoints: Keypoint[], targetKeypoints: {[index: string]: number[]}, + centerY: number, centerX: number): number[] { + const torsoJoints = + ['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip']; + let maxTorsoYrange = 0.0; + let maxTorsoXrange = 0.0; + for (let i = 0; i < torsoJoints.length; i++) { + const distY = Math.abs(centerY - targetKeypoints[torsoJoints[i]][0]); + const distX = Math.abs(centerX - targetKeypoints[torsoJoints[i]][1]); + if (distY > maxTorsoYrange) { + maxTorsoYrange = distY; + } + if (distX > maxTorsoXrange) { + maxTorsoXrange = distX; + } + } + let maxBodyYrange = 0.0; + let maxBodyXrange = 0.0; + for (const key of Object.keys(targetKeypoints)) { + if (keypoints[COCO_KEYPOINTS_NAMED_MAP[key]].score < + MIN_CROP_KEYPOINT_SCORE) { + continue; + } + const distY = Math.abs(centerY - targetKeypoints[key][0]); + const distX = Math.abs(centerX - targetKeypoints[key][1]); + if (distY > maxBodyYrange) { + maxBodyYrange = distY; + } + if (distX > maxBodyXrange) { + maxBodyXrange = distX; + } + } + + return [maxTorsoYrange, maxTorsoXrange, maxBodyYrange, maxBodyXrange]; + } + + /** + * Determines the region to crop the image for the model to run inference on. + * The algorithm uses the detected joints from the previous frame to estimate + * the square region that encloses the full body of the target person and + * centers at the midpoint of two hip joints. The crop size is determined by + * the distances between each joints and the center point. + * When the model is not confident with the four torso joint predictions, the + * function returns a default crop which is the full image padded to square. + */ + determineCropRegion( + keypoints: Keypoint[], imageHeight: number, imageWidth: number) { + const targetKeypoints: {[index: string]: number[]} = {}; + + for (const key of Object.keys(COCO_KEYPOINTS_NAMED_MAP)) { + targetKeypoints[key] = [ + keypoints[COCO_KEYPOINTS_NAMED_MAP[key]].y * imageHeight, + keypoints[COCO_KEYPOINTS_NAMED_MAP[key]].x * imageWidth + ]; + } + + if (this.torsoVisible(keypoints)) { + const centerY = + (targetKeypoints['left_hip'][0] + targetKeypoints['right_hip'][0]) / + 2; + const centerX = + (targetKeypoints['left_hip'][1] + targetKeypoints['right_hip'][1]) / + 2; + + const [maxTorsoYrange, maxTorsoXrange, maxBodyYrange, maxBodyXrange] = + this.determineTorsoAndBodyRange( + keypoints, targetKeypoints, centerY, centerX); + + let cropLengthHalf = Math.max( + maxTorsoXrange * 2.0, maxTorsoYrange * 2.0, maxBodyYrange * 1.2, + maxBodyXrange * 1.2); + + cropLengthHalf = Math.min( + cropLengthHalf, + Math.max( + centerX, imageWidth - centerX, centerY, imageHeight - centerY)); + + let cropCorner = [centerY - cropLengthHalf, centerX - cropLengthHalf]; + + if (cropLengthHalf > Math.max(imageWidth, imageHeight) / 2) { + cropLengthHalf = Math.max(imageWidth, imageHeight) / 2; + cropCorner = [0.0, 0.0]; + } + + const cropLength = cropLengthHalf * 2; + const cropRegion = [ + cropCorner[0] / imageHeight, cropCorner[1] / imageWidth, + (cropCorner[0] + cropLength) / imageHeight, + (cropCorner[1] + cropLength) / imageWidth + ]; + return cropRegion; + } else { + return null; + } + } + + keypointsToArray(keypoints: Keypoint[]) { + const values: number[] = []; + for (let i = 0; i < 17; ++i) { + values[i * 2] = keypoints[i].y; + values[i * 2 + 1] = keypoints[i].x; + } + return values; + } + + arrayToKeypoints(values: number[], keypoints: Keypoint[]) { + for (let i = 0; i < 17; ++i) { + keypoints[i].y = values[i * 2]; + 
      keypoints[i].x = values[i * 2 + 1];
+    }
+  }
+
+  dispose() {
+    this.moveNetModel.dispose();
+  }
+}
diff --git a/pose-detection/src/movenet/detector_utils.ts b/pose-detection/src/movenet/detector_utils.ts
new file mode 100644
index 0000000000..9cf47c7a69
--- /dev/null
+++ b/pose-detection/src/movenet/detector_utils.ts
@@ -0,0 +1,51 @@
+/**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {MOVENET_CONFIG, MOVENET_SINGLE_POSE_ESTIMATION_CONFIG, SINGLEPOSE_LIGHTNING, VALID_MODELS} from './constants';
+import {MoveNetEstimationConfig, MoveNetModelConfig} from './types';
+
+export function validateModelConfig(modelConfig: MoveNetModelConfig):
+    MoveNetModelConfig {
+  const config = modelConfig == null ? MOVENET_CONFIG : {...modelConfig};
+
+  if (!config.modelType) {
+    config.modelType = SINGLEPOSE_LIGHTNING;
+  } else if (VALID_MODELS.indexOf(config.modelType) < 0) {
+    throw new Error(
+        `Invalid architecture ${config.modelType}. ` +
+        `Should be one of ${VALID_MODELS}`);
+  }
+
+  return config;
+}
+
+export function validateEstimationConfig(
+    estimationConfig: MoveNetEstimationConfig): MoveNetEstimationConfig {
+  const config = estimationConfig == null ?
+      MOVENET_SINGLE_POSE_ESTIMATION_CONFIG :
+      {...estimationConfig};
+
+  if (!config.maxPoses) {
+    config.maxPoses = 1;
+  }
+
+  if (config.maxPoses <= 0 || config.maxPoses > 1) {
+    throw new Error(`Invalid maxPoses ${config.maxPoses}. Should be 1.`);
+  }
+
+  return config;
+}
diff --git a/pose-detection/src/movenet/index.ts b/pose-detection/src/movenet/index.ts
new file mode 100644
index 0000000000..b83db64e54
--- /dev/null
+++ b/pose-detection/src/movenet/index.ts
@@ -0,0 +1,18 @@
+/**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+export {MoveNetDetector} from './detector';
diff --git a/pose-detection/src/movenet/movenet_test.ts b/pose-detection/src/movenet/movenet_test.ts
new file mode 100644
index 0000000000..aa72da3e5d
--- /dev/null
+++ b/pose-detection/src/movenet/movenet_test.ts
@@ -0,0 +1,44 @@
+/**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '@tensorflow/tfjs-core'; +// tslint:disable-next-line: no-imports-from-dist +import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; + +import * as poseDetection from '../index'; + +import {SINGLEPOSE_LIGHTNING} from './constants'; + +describeWithFlags('MoveNet', ALL_ENVS, () => { + let detector: poseDetection.PoseDetector; + beforeEach(async () => { + // Note: this makes a network request for model assets. + detector = await poseDetection.createDetector( + poseDetection.SupportedModels.MoveNet, + {modelType: SINGLEPOSE_LIGHTNING}); + }); + + it('estimatePoses does not leak memory', async () => { + const input: tf.Tensor3D = tf.zeros([128, 128, 3]); + + const beforeTensors = tf.memory().numTensors; + + await detector.estimatePoses(input); + + expect(tf.memory().numTensors).toEqual(beforeTensors); + }); +}); diff --git a/pose-detection/src/movenet/robust_one_euro_filter.ts b/pose-detection/src/movenet/robust_one_euro_filter.ts new file mode 100644 index 0000000000..46d73c426e --- /dev/null +++ b/pose-detection/src/movenet/robust_one_euro_filter.ts @@ -0,0 +1,123 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// TODO(ardoerlemans): This filter should be merged with the existing +// one euro filter in pose_detection/src/calculators once PR #632 has been +// merged. + +/** + * Exponentially weighted moving average filter. + * https://en.wikipedia.org/wiki/Moving_average + */ +class EWMA { + private updateRate: number; + private state: number[]; + + constructor(updateRate: number) { + this.updateRate = updateRate; + this.state = []; + } + + insert(observations: number[]) { + if (!this.state.length) { + this.state = observations.slice(); + } else { + for (let obsIdx = 0; obsIdx < observations.length; obsIdx += 1) { + this.state[obsIdx] = (1 - this.updateRate) * this.state[obsIdx] + + this.updateRate * observations[obsIdx]; + } + } + return this.state; + } +} + +/** + * One-euro filter that works on arrays of numbers. 
+ * https://hal.inria.fr/hal-00670496/document + */ +export class RobustOneEuroFilter { + private updateRateOffset: number; + private updateRateSlope: number; + private fps: number; + private thresholdOffset: number; + private thresholdSlope: number; + private prevObs: number[]; + private speed: number[]; + private speedEwma: EWMA; + private threshold: number[]; + private state: number[]; + + constructor( + updateRateOffset = 40.0, updateRateSlope = 4e3, speedUpdateRate = 0.5, + fps = 30, thresholdOffset = 0.5, thresholdSlope = 5.0) { + this.updateRateOffset = updateRateOffset; + this.updateRateSlope = updateRateSlope; + this.fps = fps; + this.thresholdOffset = thresholdOffset; + this.thresholdSlope = thresholdSlope; + + this.prevObs = []; + this.speed = []; + this.speedEwma = new EWMA(speedUpdateRate); + this.threshold = []; + + this.state = []; + } + + insert(observations: number[], fps = 30) { + if (fps > 0) { + this.fps = fps; + } + this.threshold = observations.slice(); + for (let obsIdx = 0; obsIdx < observations.length; obsIdx += 1) { + this.threshold[obsIdx] = this.thresholdOffset; + } + + if (this.prevObs.length) { + const rawDiff = observations.slice(); + for (let obsIdx = 0; obsIdx < observations.length; obsIdx += 1) { + rawDiff[obsIdx] -= this.prevObs[obsIdx]; + } + if (this.speed.length) { + for (let obsIdx = 0; obsIdx < observations.length; obsIdx += 1) { + this.threshold[obsIdx] = this.thresholdOffset + + this.thresholdSlope * Math.abs(this.speed[obsIdx]); + } + } + this.speed = this.speedEwma.insert(rawDiff); + } + this.prevObs = observations.slice(); + + if (!this.state.length) { + this.state = observations.slice(); + } else { + for (let obsIdx = 0; obsIdx < observations.length; obsIdx += 1) { + const updateRate = 1.0 / + (1.0 + + this.fps / + (this.updateRateOffset + + this.updateRateSlope * Math.abs(this.speed[obsIdx]))); + this.state[obsIdx] = this.state[obsIdx] + + updateRate * this.threshold[obsIdx] * + Math.asinh( + (observations[obsIdx] - this.state[obsIdx]) / + this.threshold[obsIdx]); + } + } + return this.state; + } +} diff --git a/pose-detection/src/movenet/types.ts b/pose-detection/src/movenet/types.ts new file mode 100644 index 0000000000..972e567787 --- /dev/null +++ b/pose-detection/src/movenet/types.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {EstimationConfig, ModelConfig} from '../types'; + +/** + * Additional MoveNet model loading config. + * + * 'modelType': Optional. The type of MoveNet model to load, Lighting or + * Thunder. Defaults to Lightning. Lightning is a lower capacity model that can + * run >50FPS on most modern laptops while achieving good performance. Thunder + * is A higher capacity model that performs better prediction quality while + * still achieving real-time (>30FPS) speed. Thunder will lag behind the + * lightning, but it will pack a punch. 
+ *
+ * `modelUrl`: Optional. A string that specifies a custom URL for the model.
+ * This is useful for areas/countries that don't have access to the model
+ * hosted on TF Hub.
+ */
+export interface MoveNetModelConfig extends ModelConfig {
+  modelType?: string;
+  modelUrl?: string;
+}
+
+/**
+ * MoveNet specific inference config.
+ */
+export interface MoveNetEstimationConfig extends EstimationConfig {}
diff --git a/pose-detection/src/types.ts b/pose-detection/src/types.ts
index 4494f3db6a..bd65f9c859 100644
--- a/pose-detection/src/types.ts
+++ b/pose-detection/src/types.ts
@@ -17,8 +17,9 @@ import * as tf from '@tensorflow/tfjs-core';
 
 export enum SupportedModels {
-  PoseNet = 'PoseNet',
-  MediapipeBlazepose = 'MediapipeBlazepose'
+  MoveNet = 'MoveNet',
+  MediapipeBlazepose = 'MediapipeBlazepose',
+  PoseNet = 'PoseNet'
 }
 
 export type QuantBytes = 1|2|4;
diff --git a/pose-detection/src/util.ts b/pose-detection/src/util.ts
index 67d789738d..c2653e4508 100644
--- a/pose-detection/src/util.ts
+++ b/pose-detection/src/util.ts
@@ -23,6 +23,7 @@ export function getKeypointIndexBySide(model: SupportedModels):
     case SupportedModels.MediapipeBlazepose:
       return constants.BLAZEPOSE_KEYPOINTS_BY_SIDE;
     case SupportedModels.PoseNet:
+    case SupportedModels.MoveNet:
       return constants.COCO_KEYPOINTS_BY_SIDE;
     default:
       throw new Error(`Model ${model} is not supported.`);
@@ -33,6 +34,7 @@ export function getAdjacentPairs(model: SupportedModels): number[][] {
     case SupportedModels.MediapipeBlazepose:
       return constants.BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS;
     case SupportedModels.PoseNet:
+    case SupportedModels.MoveNet:
      return constants.COCO_CONNECTED_KEYPOINTS_PAIRS;
     default:
       throw new Error(`Model ${model} is not supported.`);
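
For context, a minimal sketch of how the MoveNet API introduced by this patch is expected to be consumed, mirroring the demo's createDetector switch. The package import path '@tensorflow-models/pose-detection' and the pre-existing HTMLVideoElement are assumptions for illustration, not part of the patch:

    import * as posedetection from '@tensorflow-models/pose-detection';

    async function run(video: HTMLVideoElement) {
      // Create a MoveNet detector. When no modelType is passed, the model
      // defaults to SinglePose.Lightning.
      const detector = await posedetection.createDetector(
          posedetection.SupportedModels.MoveNet,
          {modelType: posedetection.movenet.modelType.SINGLEPOSE_THUNDER});

      // MoveNet is a single-pose model, so maxPoses must be 1.
      const poses = await detector.estimatePoses(video, {maxPoses: 1});

      if (poses.length > 0) {
        // Keypoint coordinates are returned in image space (pixels).
        console.log(poses[0].keypoints);
      }

      detector.dispose();
    }

Lightning trades some accuracy for speed (>50 FPS on most modern laptops), while Thunder is the higher-capacity variant that predicts more accurately while still running in real time (>30 FPS).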