Skip to content

Commit

Permalink
Added CPU and GPU delegates on models
Browse files Browse the repository at this point in the history
  • Loading branch information
simonguest committed Nov 6, 2023
1 parent a27f0bc commit 4ae88ec
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 27 deletions.
10 changes: 9 additions & 1 deletion src/blocks/models/image-segmentation/segment.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Blockly from "blockly";

export let segment = {
init: function () {
this.appendValueInput("IMAGE")
Expand All @@ -6,6 +8,12 @@ export let segment = {
this.appendValueInput("MODEL")
.setCheck("ImageSegmentationModel")
.appendField("using model");
this.appendDummyInput()
.appendField("running on")
.appendField(
new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]),
"DELEGATE"
)
this.setInputsInline(false);
this.setOutput(true, "Segment");
this.setPreviousStatement(false, null);
Expand All @@ -20,6 +28,6 @@ export let segment = {
if (image === "") return "";
if (model === "") return "";

return [`await cv.segment(${image}, ${model})`, generator.ORDER_NONE];
return [`await cv.segment(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE];
}
};
10 changes: 9 additions & 1 deletion src/blocks/models/object-detection/detectObjects.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Blockly from "blockly";

export let detectObjects = {
init: function () {
this.appendValueInput("IMAGE")
Expand All @@ -6,6 +8,12 @@ export let detectObjects = {
this.appendValueInput("MODEL")
.setCheck("ObjectDetectionModel")
.appendField("using model");
this.appendDummyInput()
.appendField("running on")
.appendField(
new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]),
"DELEGATE"
)
this.setInputsInline(false);
this.setOutput(true, "Objects");
this.setPreviousStatement(false, null);
Expand All @@ -20,6 +28,6 @@ export let detectObjects = {
if (image === "") return "";
if (model === "") return "";

return [`await cv.detectObjects(${image}, ${model})`, generator.ORDER_NONE];
return [`await cv.detectObjects(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE];
}
};
10 changes: 9 additions & 1 deletion src/blocks/models/pose-estimation/detectPose.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Blockly from "blockly";

export let detectPose = {
init: function () {
this.appendValueInput("IMAGE")
Expand All @@ -6,6 +8,12 @@ export let detectPose = {
this.appendValueInput("MODEL")
.setCheck("PoseEstimationModel")
.appendField("using model");
this.appendDummyInput()
.appendField("running on")
.appendField(
new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]),
"DELEGATE"
)
this.setInputsInline(false);
this.setOutput(true, "Pose");
this.setPreviousStatement(false, null);
Expand All @@ -20,6 +28,6 @@ export let detectPose = {
if (image === "") return "";
if (model === "") return "";

return [`await cv.detectPose(${image}, ${model})`, generator.ORDER_NONE];
return [`await cv.detectPose(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE];
}
};
4 changes: 2 additions & 2 deletions src/cv/imageSegmentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ export class ImageSegmentation {

}

public async segment(mp: MediaPipe, image: ImageData, model: ModelData) {
public async segment(mp: MediaPipe, image: ImageData, model: ModelData, delegate: string) {
Debug.write("Segment Image");
return await mp.segment(image, model);
return await mp.segment(image, model, delegate);
}
}
12 changes: 6 additions & 6 deletions src/cv/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ export class CV {
Object Detection
*************************/

public async detectObjects(image: ImageData, model: ModelData) {
return this.objectDetection.detectObjects(this.mp, image, model);
public async detectObjects(image: ImageData, model: ModelData, delegate: string) {
return this.objectDetection.detectObjects(this.mp, image, model, delegate);
}

public async drawBoundingBoxes(result: ObjectDetectorResult) {
Expand All @@ -136,8 +136,8 @@ export class CV {
Image Segmentation
*************************/

public async segment(image: ImageData, model: ModelData) {
return this.imageSegmentation.segment(this.mp, image, model);
public async segment(image: ImageData, model: ModelData, delegate: string) {
return this.imageSegmentation.segment(this.mp, image, model, delegate);
}

public async colorSegment(data: { result: ImageSegmenterResult, category: number }, rgb: number[]) {
Expand All @@ -155,8 +155,8 @@ export class CV {
Pose Detection
*************************/

public async detectPose(image: ImageData, model: ModelData) {
return await this.poseEstimation.detectPose(this.mp, image, model);
public async detectPose(image: ImageData, model: ModelData, delegate: string) {
return await this.poseEstimation.detectPose(this.mp, image, model, delegate);
}

public async drawPose(pose: PoseLandmarkerResult) {
Expand Down
26 changes: 14 additions & 12 deletions src/cv/mediapipe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,53 +29,55 @@ export class MediaPipe {
this.models[path] = model;
}

public async detectObjects(image: ImageData, model: ModelData) {
let detector = this.getModel(model.path);
public async detectObjects(image: ImageData, model: ModelData, delegate) {
let detector = this.getModel(`${model.path}_${delegate}`);
if (!detector) {
detector = await ObjectDetector.createFromOptions(this.vision, {
baseOptions: {
modelAssetPath: model.path,
delegate: "GPU",
delegate: delegate,
},
scoreThreshold: 0.5,
runningMode: "IMAGE",
});
this.cacheModel(model.path, detector);
this.cacheModel(`${model.path}_${delegate}`, detector);
}

return detector.detect(image);
}

public async segment(image: ImageData, model: ModelData) {
let segmenter = this.getModel(model.path);
public async segment(image: ImageData, model: ModelData, delegate: string) {
if (delegate !== 'GPU' && delegate !== 'CPU') return;
let segmenter = this.getModel(`${model.path}_${delegate}`);
if (!segmenter) {
segmenter = await ImageSegmenter.createFromOptions(this.vision, {
baseOptions: {
modelAssetPath: model.path,
delegate: 'GPU',
delegate: delegate,
},
runningMode: 'IMAGE',
outputCategoryMask: true,
outputConfidenceMasks: false
});
this.cacheModel(model.path, segmenter);
this.cacheModel(`${model.path}_${delegate}`, segmenter);
}

return { result: segmenter.segment(image), category: model.category };
}

public async detectPose(image: ImageData, model: ModelData) {
let poseLandmarker = this.getModel(model.path);
public async detectPose(image: ImageData, model: ModelData, delegate: string) {
if (delegate !== 'GPU' && delegate !== 'CPU') return;
let poseLandmarker = this.getModel(`${model.path}_${delegate}`);
if (!poseLandmarker) {
poseLandmarker = await PoseLandmarker.createFromOptions(this.vision, {
baseOptions: {
modelAssetPath: model.path,
delegate: 'GPU',
delegate: delegate,
},
runningMode: 'IMAGE',
numPoses: 1,
});
this.cacheModel(model.path, poseLandmarker);
this.cacheModel(`${model.path}_${delegate}`, poseLandmarker);
}
return poseLandmarker.detect(image);
}
Expand Down
4 changes: 2 additions & 2 deletions src/cv/objectDetection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ export class ObjectDetection {
private BOUNDING_BOX_FONT : string = "14px Arial";
private BOUNDING_BOX_FONT_COLOR : string = "#ffffff";

public async detectObjects(mediapipe: MediaPipe, image: ImageData, model: ModelData) {
public async detectObjects(mediapipe: MediaPipe, image: ImageData, model: ModelData, delegate: string) {
Debug.write("Detecting objects");
return await mediapipe.detectObjects(image, model);
return await mediapipe.detectObjects(image, model, delegate);
}

public async displayBoundingBoxes(canvas: HTMLCanvasElement, result: ObjectDetectorResult) {
Expand Down
4 changes: 2 additions & 2 deletions src/cv/poseEstimation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import { PoseLandmarkerResult } from "@mediapipe/tasks-vision";

export class PoseEstimation {

public async detectPose(mp: MediaPipe, image: ImageData, model: ModelData) {
public async detectPose(mp: MediaPipe, image: ImageData, model: ModelData, delegate: string) {
Debug.write("Detecting pose");
return await mp.detectPose(image, model);
return await mp.detectPose(image, model, delegate);
}

public async displayPose(canvas: HTMLCanvasElement, pose: PoseLandmarkerResult, width: number, height: number) {
Expand Down

0 comments on commit 4ae88ec

Please sign in to comment.