Skip to content

Commit

Permalink
Adding MoveNet to Pose Detection API (#627)
Browse files Browse the repository at this point in the history
* Adding MoveNet skeleton and support files

* Adding MoveNet implementation and supporting files

* Lint MoveNet files (except for line length)

Line lengths will be updated in an upcoming commit.

* Use correct formatting, which includes line lengths

* Remove test URL

* Addressing code review comments

* Fix MoveNet visualization in live demo

* Updates after review comments

MoveNet determineCropRegion refactored and moved OneEuroFilter to common
filters directory

* Remove configurable keypoint threshold

This threshold was mostly used for internal cropping logic.

* Merge Model and KeypointModel, plus additional updates after review

* Move one euro filter back to MoveNet with TODO

* Add test for MoveNet

* Use MoveNet as default model in pose demo

* Add MoveNet model type selection to pose demo

* Run models at highest possible speed in pose demo

* Fix error in cropping code

* Remove comment

* Updates to resolve review comments

* Fix CI test

* Simplify pose array creation
  • Loading branch information
ardoerlemans authored Apr 3, 2021
1 parent 2b3484c commit 8717622
Show file tree
Hide file tree
Showing 16 changed files with 826 additions and 31 deletions.
4 changes: 0 additions & 4 deletions pose-detection/demo/src/camera.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ export class Camera {
this.video = document.getElementById('video');
this.canvas = document.getElementById('output');
this.ctx = this.canvas.getContext('2d');

// The video frame rate may be lower than the browser animate frame
// rate. We use this to avoid processing the same frame twice.
this.lastVideoTime = 0;
}

/**
Expand Down
37 changes: 20 additions & 17 deletions pose-detection/demo/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ async function createDetector(model) {
case posedetection.SupportedModels.MediapipeBlazepose:
return posedetection.createDetector(
STATE.model.model, {quantBytes: 4, upperBodyOnly: false});
case posedetection.SupportedModels.MoveNet:
const modelType =
STATE.model[STATE.model.model].modelType == 'Lightning' ?
posedetection.movenet.modelType.SINGLEPOSE_LIGHTNING :
posedetection.movenet.modelType.SINGLEPOSE_THUNDER;
return posedetection.createDetector(
STATE.model.model, {modelType: modelType});
}
}

Expand All @@ -58,28 +65,24 @@ async function checkGuiUpdate() {
}

async function renderResult() {
if (camera.video.currentTime !== camera.lastVideoTime) {
camera.lastVideoTime = camera.video.currentTime;

// FPS only counts the time it takes to finish estimatePoses.
stats.begin();
// FPS only counts the time it takes to finish estimatePoses.
stats.begin();

const poses = await detector.estimatePoses(
camera.video, {maxPoses: 1, flipHorizontal: false});
const poses = await detector.estimatePoses(
camera.video, {maxPoses: 1, flipHorizontal: false});

stats.end();
stats.end();

camera.drawCtx();
camera.drawCtx();

// The null check makes sure the UI is not in the middle of changing to a
// different model. If changeToModel is non-null, the result is from an
// old model, which shouldn't be rendered.
if (poses.length > 0 && STATE.changeToModel == null) {
const shouldScale = STATE.model.model ===
posedetection.SupportedModels.MediapipeBlazepose;
// The null check makes sure the UI is not in the middle of changing to a
// different model. If changeToModel is non-null, the result is from an
// old model, which shouldn't be rendered.
if (poses.length > 0 && STATE.changeToModel == null) {
const shouldScale =
STATE.model.model === posedetection.SupportedModels.MediapipeBlazepose;

camera.drawResult(poses[0], shouldScale);
}
camera.drawResult(poses[0], shouldScale);
}
}

Expand Down
31 changes: 26 additions & 5 deletions pose-detection/demo/src/option_panel.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,36 @@ export function setupDatGui() {
case posedetection.SupportedModels.PoseNet:
poseNetFolder.open();
blazePoseFolder.close();
moveNetFolder.close();
break;
case posedetection.SupportedModels.MediapipeBlazepose:
blazePoseFolder.open();
poseNetFolder.close();
moveNetFolder.close();
break;
case posedetection.SupportedModels.MoveNet:
blazePoseFolder.close();
poseNetFolder.close();
moveNetFolder.open();
break;
default:
throw new Error(`${model} is not supported.`);
}
});
modelFolder.open();

// The PoseNet model config folder contains options for PoseNet config
// The MoveNet model config folder contains options for MoveNet config
// settings.
const poseNetFolder = gui.addFolder('PoseNet Config');
poseNetFolder.add(
STATE.model[posedetection.SupportedModels.PoseNet], 'scoreThreshold', 0,
const moveNetFolder = gui.addFolder('MoveNet Config');
const moveNetTypeController = moveNetFolder.add(
STATE.model[posedetection.SupportedModels.MoveNet], 'modelType',
['Thunder', 'Lightning']);
moveNetTypeController.onChange(type => {
STATE.changeToModel = type;
});
moveNetFolder.add(
STATE.model[posedetection.SupportedModels.MoveNet], 'scoreThreshold', 0,
1);
poseNetFolder.open();

// The Blazepose model config folder contains options for Blazepose config
// settings.
Expand All @@ -70,5 +82,14 @@ export function setupDatGui() {
STATE.model[posedetection.SupportedModels.MediapipeBlazepose],
'scoreThreshold', 0, 1);

// The PoseNet model config folder contains options for PoseNet config
// settings.
const poseNetFolder = gui.addFolder('PoseNet Config');
poseNetFolder.add(
STATE.model[posedetection.SupportedModels.PoseNet], 'scoreThreshold', 0,
1);

moveNetFolder.open();

return gui;
}
6 changes: 5 additions & 1 deletion pose-detection/demo/src/params.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export const VIDEO_SIZE = {
export const STATE = {
camera: {targetFPS: 60, sizeOption: '640 X 480'},
model: {
model: posedetection.SupportedModels.PoseNet,
model: posedetection.SupportedModels.MoveNet,
}
};
STATE.model[posedetection.SupportedModels.MediapipeBlazepose] = {
Expand All @@ -35,3 +35,7 @@ STATE.model[posedetection.SupportedModels.MediapipeBlazepose] = {
STATE.model[posedetection.SupportedModels.PoseNet] = {
scoreThreshold: 0.5
};
STATE.model[posedetection.SupportedModels.MoveNet] = {
modelType: 'Thunder',
scoreThreshold: 0.3
};
19 changes: 19 additions & 0 deletions pose-detection/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,22 @@ export const BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS = [
[18, 20], [23, 25], [23, 24], [24, 26], [25, 27], [26, 28], [27, 29],
[28, 30], [27, 31], [28, 32], [29, 31], [30, 32]
];
export const COCO_KEYPOINTS_NAMED_MAP: {[index: string]: number} = {
nose: 0,
left_eye: 1,
right_eye: 2,
left_ear: 3,
right_ear: 4,
left_shoulder: 5,
right_shoulder: 6,
left_elbow: 7,
right_elbow: 8,
left_wrist: 9,
right_wrist: 10,
left_hip: 11,
right_hip: 12,
left_knee: 13,
right_knee: 14,
left_ankle: 15,
right_ankle: 16
};
8 changes: 6 additions & 2 deletions pose-detection/src/create_detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import {BlazeposeDetector} from './blazepose/detector';
import {BlazeposeModelConfig} from './blazepose/types';
import {MoveNetDetector} from './movenet/detector';
import {MoveNetModelConfig} from './movenet/types';
import {PoseDetector} from './pose_detector';
import {PosenetDetector} from './posenet/detector';
import {PosenetModelConfig} from './posenet/types';
Expand All @@ -29,13 +31,15 @@ import {SupportedModels} from './types';
*/
export async function createDetector(
model: SupportedModels,
modelConfig: PosenetModelConfig|
BlazeposeModelConfig): Promise<PoseDetector> {
modelConfig: PosenetModelConfig|BlazeposeModelConfig|
MoveNetModelConfig): Promise<PoseDetector> {
switch (model) {
case SupportedModels.PoseNet:
return PosenetDetector.load(modelConfig as PosenetModelConfig);
case SupportedModels.MediapipeBlazepose:
return BlazeposeDetector.load(modelConfig as BlazeposeModelConfig);
case SupportedModels.MoveNet:
return MoveNetDetector.load(modelConfig as MoveNetModelConfig);
default:
throw new Error(`${model} is not a supported model name.`);
}
Expand Down
12 changes: 12 additions & 0 deletions pose-detection/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@
// Entry point to create a new detector instance.
export {BlazeposeEstimationConfig, BlazeposeModelConfig} from './blazepose/types';
export {createDetector} from './create_detector';
export {MoveNetEstimationConfig, MoveNetModelConfig} from './movenet/types';
// PoseDetector class.
export {PoseDetector} from './pose_detector';
export {PoseNetEstimationConfig, PosenetModelConfig} from './posenet/types';

// Supported models enum.
export * from './types';

// Second level exports.
// Utils for rendering.
import * as util from './util';
export {util};

// MoveNet model types.
import {SINGLEPOSE_LIGHTNING, SINGLEPOSE_THUNDER} from './movenet/constants';
const movenet = {
modelType: {
'SINGLEPOSE_LIGHTNING': SINGLEPOSE_LIGHTNING,
'SINGLEPOSE_THUNDER': SINGLEPOSE_THUNDER
}
};
export {movenet};
42 changes: 42 additions & 0 deletions pose-detection/src/movenet/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {MoveNetEstimationConfig, MoveNetModelConfig} from './types';

export const SINGLEPOSE_LIGHTNING = 'SinglePose.Lightning';
export const SINGLEPOSE_THUNDER = 'SinglePose.Thunder';

export const VALID_MODELS = [SINGLEPOSE_LIGHTNING, SINGLEPOSE_THUNDER];

export const MOVENET_SINGLEPOSE_LIGHTNING_URL =
'https://tfhub.dev/google/tfjs-model/movenet/singlepose/lightning/1';
export const MOVENET_SINGLEPOSE_THUNDER_URL =
'https://tfhub.dev/google/tfjs-model/movenet/singlepose/thunder/1';

export const MOVENET_SINGLEPOSE_LIGHTNING_RESOLUTION = 192;
export const MOVENET_SINGLEPOSE_THUNDER_RESOLUTION = 256;

// The default configuration for loading MoveNet.
export const MOVENET_CONFIG: MoveNetModelConfig = {
modelType: SINGLEPOSE_LIGHTNING
};

export const MOVENET_SINGLE_POSE_ESTIMATION_CONFIG: MoveNetEstimationConfig = {
maxPoses: 1
};

export const MIN_CROP_KEYPOINT_SCORE = 0.3;
Loading

0 comments on commit 8717622

Please sign in to comment.