diff --git a/src/blocks/models/image-segmentation/segment.ts b/src/blocks/models/image-segmentation/segment.ts index c3742c9..65ec041 100644 --- a/src/blocks/models/image-segmentation/segment.ts +++ b/src/blocks/models/image-segmentation/segment.ts @@ -1,3 +1,5 @@ +import Blockly from "blockly"; + export let segment = { init: function () { this.appendValueInput("IMAGE") @@ -6,6 +8,12 @@ export let segment = { this.appendValueInput("MODEL") .setCheck("ImageSegmentationModel") .appendField("using model"); + this.appendDummyInput() + .appendField("running on") + .appendField( + new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]), + "DELEGATE" + ) this.setInputsInline(false); this.setOutput(true, "Segment"); this.setPreviousStatement(false, null); @@ -20,6 +28,6 @@ export let segment = { if (image === "") return ""; if (model === "") return ""; - return [`await cv.segment(${image}, ${model})`, generator.ORDER_NONE]; + return [`await cv.segment(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE]; } }; \ No newline at end of file diff --git a/src/blocks/models/object-detection/detectObjects.ts b/src/blocks/models/object-detection/detectObjects.ts index 850e87f..97c67ae 100644 --- a/src/blocks/models/object-detection/detectObjects.ts +++ b/src/blocks/models/object-detection/detectObjects.ts @@ -1,3 +1,5 @@ +import Blockly from "blockly"; + export let detectObjects = { init: function () { this.appendValueInput("IMAGE") @@ -6,6 +8,12 @@ export let detectObjects = { this.appendValueInput("MODEL") .setCheck("ObjectDetectionModel") .appendField("using model"); + this.appendDummyInput() + .appendField("running on") + .appendField( + new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]), + "DELEGATE" + ) this.setInputsInline(false); this.setOutput(true, "Objects"); this.setPreviousStatement(false, null); @@ -20,6 +28,6 @@ export let detectObjects = { if (image === "") return ""; if (model === "") return ""; - return [`await cv.detectObjects(${image}, ${model})`, generator.ORDER_NONE]; + return [`await cv.detectObjects(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE]; } }; \ No newline at end of file diff --git a/src/blocks/models/pose-estimation/detectPose.ts b/src/blocks/models/pose-estimation/detectPose.ts index b3b2dee..1fb151f 100644 --- a/src/blocks/models/pose-estimation/detectPose.ts +++ b/src/blocks/models/pose-estimation/detectPose.ts @@ -1,3 +1,5 @@ +import Blockly from "blockly"; + export let detectPose = { init: function () { this.appendValueInput("IMAGE") @@ -6,6 +8,12 @@ export let detectPose = { this.appendValueInput("MODEL") .setCheck("PoseEstimationModel") .appendField("using model"); + this.appendDummyInput() + .appendField("running on") + .appendField( + new Blockly.FieldDropdown([["GPU", "GPU"], ["CPU", "CPU"]]), + "DELEGATE" + ) this.setInputsInline(false); this.setOutput(true, "Pose"); this.setPreviousStatement(false, null); @@ -20,6 +28,6 @@ export let detectPose = { if (image === "") return ""; if (model === "") return ""; - return [`await cv.detectPose(${image}, ${model})`, generator.ORDER_NONE]; + return [`await cv.detectPose(${image}, ${model}, "${block.getFieldValue("DELEGATE")}")`, generator.ORDER_NONE]; } }; \ No newline at end of file diff --git a/src/cv/imageSegmentation.ts b/src/cv/imageSegmentation.ts index 1ca1ba0..32b0dab 100644 --- a/src/cv/imageSegmentation.ts +++ b/src/cv/imageSegmentation.ts @@ -59,8 +59,8 @@ export class ImageSegmentation { } - public async segment(mp: MediaPipe, image: ImageData, model: ModelData) { + public async segment(mp: MediaPipe, image: ImageData, model: ModelData, delegate: string) { Debug.write("Segment Image"); - return await mp.segment(image, model); + return await mp.segment(image, model, delegate); } } \ No newline at end of file diff --git a/src/cv/index.ts b/src/cv/index.ts index 8d3c1a4..0d86d44 100644 --- a/src/cv/index.ts +++ b/src/cv/index.ts @@ -120,8 +120,8 @@ export class CV { Object Detection *************************/ - public async detectObjects(image: ImageData, model: ModelData) { - return this.objectDetection.detectObjects(this.mp, image, model); + public async detectObjects(image: ImageData, model: ModelData, delegate: string) { + return this.objectDetection.detectObjects(this.mp, image, model, delegate); } public async drawBoundingBoxes(result: ObjectDetectorResult) { @@ -136,8 +136,8 @@ export class CV { Image Segmentation *************************/ - public async segment(image: ImageData, model: ModelData) { - return this.imageSegmentation.segment(this.mp, image, model); + public async segment(image: ImageData, model: ModelData, delegate: string) { + return this.imageSegmentation.segment(this.mp, image, model, delegate); } public async colorSegment(data: { result: ImageSegmenterResult, category: number }, rgb: number[]) { @@ -155,8 +155,8 @@ export class CV { Pose Detection *************************/ - public async detectPose(image: ImageData, model: ModelData) { - return await this.poseEstimation.detectPose(this.mp, image, model); + public async detectPose(image: ImageData, model: ModelData, delegate: string) { + return await this.poseEstimation.detectPose(this.mp, image, model, delegate); } public async drawPose(pose: PoseLandmarkerResult) { diff --git a/src/cv/mediapipe.ts b/src/cv/mediapipe.ts index c4fd600..959a3a5 100644 --- a/src/cv/mediapipe.ts +++ b/src/cv/mediapipe.ts @@ -29,53 +29,55 @@ export class MediaPipe { this.models[path] = model; } - public async detectObjects(image: ImageData, model: ModelData) { - let detector = this.getModel(model.path); + public async detectObjects(image: ImageData, model: ModelData, delegate) { + let detector = this.getModel(`${model.path}_${delegate}`); if (!detector) { detector = await ObjectDetector.createFromOptions(this.vision, { baseOptions: { modelAssetPath: model.path, - delegate: "GPU", + delegate: delegate, }, scoreThreshold: 0.5, runningMode: "IMAGE", }); - this.cacheModel(model.path, detector); + this.cacheModel(`${model.path}_${delegate}`, detector); } return detector.detect(image); } - public async segment(image: ImageData, model: ModelData) { - let segmenter = this.getModel(model.path); + public async segment(image: ImageData, model: ModelData, delegate: string) { + if (delegate !== 'GPU' && delegate !== 'CPU') return; + let segmenter = this.getModel(`${model.path}_${delegate}`); if (!segmenter) { segmenter = await ImageSegmenter.createFromOptions(this.vision, { baseOptions: { modelAssetPath: model.path, - delegate: 'GPU', + delegate: delegate, }, runningMode: 'IMAGE', outputCategoryMask: true, outputConfidenceMasks: false }); - this.cacheModel(model.path, segmenter); + this.cacheModel(`${model.path}_${delegate}`, segmenter); } return { result: segmenter.segment(image), category: model.category }; } - public async detectPose(image: ImageData, model: ModelData) { - let poseLandmarker = this.getModel(model.path); + public async detectPose(image: ImageData, model: ModelData, delegate: string) { + if (delegate !== 'GPU' && delegate !== 'CPU') return; + let poseLandmarker = this.getModel(`${model.path}_${delegate}`); if (!poseLandmarker) { poseLandmarker = await PoseLandmarker.createFromOptions(this.vision, { baseOptions: { modelAssetPath: model.path, - delegate: 'GPU', + delegate: delegate, }, runningMode: 'IMAGE', numPoses: 1, }); - this.cacheModel(model.path, poseLandmarker); + this.cacheModel(`${model.path}_${delegate}`, poseLandmarker); } return poseLandmarker.detect(image); } diff --git a/src/cv/objectDetection.ts b/src/cv/objectDetection.ts index 14ad2a9..7527961 100644 --- a/src/cv/objectDetection.ts +++ b/src/cv/objectDetection.ts @@ -8,9 +8,9 @@ export class ObjectDetection { private BOUNDING_BOX_FONT : string = "14px Arial"; private BOUNDING_BOX_FONT_COLOR : string = "#ffffff"; - public async detectObjects(mediapipe: MediaPipe, image: ImageData, model: ModelData) { + public async detectObjects(mediapipe: MediaPipe, image: ImageData, model: ModelData, delegate: string) { Debug.write("Detecting objects"); - return await mediapipe.detectObjects(image, model); + return await mediapipe.detectObjects(image, model, delegate); } public async displayBoundingBoxes(canvas: HTMLCanvasElement, result: ObjectDetectorResult) { diff --git a/src/cv/poseEstimation.ts b/src/cv/poseEstimation.ts index fd62927..cde16ba 100644 --- a/src/cv/poseEstimation.ts +++ b/src/cv/poseEstimation.ts @@ -4,9 +4,9 @@ import { PoseLandmarkerResult } from "@mediapipe/tasks-vision"; export class PoseEstimation { - public async detectPose(mp: MediaPipe, image: ImageData, model: ModelData) { + public async detectPose(mp: MediaPipe, image: ImageData, model: ModelData, delegate: string) { Debug.write("Detecting pose"); - return await mp.detectPose(image, model); + return await mp.detectPose(image, model, delegate); } public async displayPose(canvas: HTMLCanvasElement, pose: PoseLandmarkerResult, width: number, height: number) {