Skip to content

Commit

Permalink
fix: Avoid gesture false positives by adding a processor stabilizatio…
Browse files Browse the repository at this point in the history
…n delay (#46)
  • Loading branch information
mirland authored Oct 7, 2024
1 parent 14615fa commit d3b4e74
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'dart:async';
import 'dart:io';

import 'package:dartx/dartx.dart';
import 'package:rxdart/rxdart.dart';
import 'package:simon_ai/core/common/logger.dart';
import 'package:simon_ai/core/model/hand_gesture_with_position.dart';
Expand All @@ -15,10 +16,11 @@ const _logVerbose = true;
/// the transformer will emit the gesture.
class GameGestureStabilizationTransformer extends StreamTransformerBase<
HandGestureWithPosition, HandGestureWithPosition> {
static const _confidenceForMinWindow = 0.7;
static final _defaultTimeSpan = Platform.isAndroid
? const Duration(seconds: 1)
: const Duration(milliseconds: 300);
static const _defaultWindowSize = 7;
static final _defaultWindowSize = Platform.isAndroid ? 6 : 7;
static final _defaultMinWindowSize = Platform.isAndroid ? 3 : 5;
static final _defaultMaxUnrecognizedGesturesInWindow =
Platform.isAndroid ? 2 : 3;
Expand Down Expand Up @@ -58,14 +60,22 @@ class GameGestureStabilizationTransformer extends StreamTransformerBase<

void _emitBuffer() {
if (_buffer.isNotEmpty) {
if (_buffer.length >= _minWindowSize) {
final confidenceAverage = _buffer.averageBy((it) => it.gestureConfidence);
if (_buffer.length >= _windowSize ||
(_buffer.length >= _minWindowSize &&
confidenceAverage >= _confidenceForMinWindow)) {
_requireEmmit = false;
_controller.add(List.unmodifiable(_buffer));
if (_logEnabled) {
final bufferConfidence = _buffer.joinToString(
transform: (it) => it.gestureConfidence.toString(),
);
Logger.i(
"Emit gesture ${_buffer.first.gesture}, "
"bufferSize: ${_buffer.length}, "
"time: ${_windowTime.elapsedMilliseconds} millis",
"time: ${_windowTime.elapsedMilliseconds} millis, "
"unrecognized count: $_currentUnrecognizedGestures, "
"confidence avg $confidenceAverage ($bufferConfidence)",
);
}
} else {
Expand All @@ -79,14 +89,13 @@ class GameGestureStabilizationTransformer extends StreamTransformerBase<

void _handleNewGesture(HandGestureWithPosition gestureWithPosition) {
if (gestureWithPosition.gesture == HandGesture.unrecognized) {
_currentUnrecognizedGestures++;
if (_currentUnrecognizedGestures >= _maxUnrecognizedGesturesInWindow) {
_resetBuffer();
if (_logVerbose) {
Logger.i("Max unrecognized gestures reached, reset window");
}
} else if (_logVerbose) {
Logger.i("Discard unrecognized gesture");
} else if (_buffer.isNotEmpty) {
_currentUnrecognizedGestures++;
}
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@ typedef Processor<T, R> = FutureOr<R> Function(T);

class ProcessWhileAvailableTransformer<T, R>
extends StreamTransformerBase<T, R> {
static const _defaultStabilizationTime = Duration(milliseconds: 5);

var _isClosed = false;
final List<Processor> _availableProcessors;
final Set<Processor> _processors;

final _mutex = Mutex();
final _stabilizationTimeMutex = Mutex();
final Queue<T> _unprocessedQueue = Queue<T>();

ProcessWhileAvailableTransformer(Iterable<Processor> processors)
: _processors = processors.toSet(),
_availableProcessors = processors.toList();

Future<void> _waitStabilizationTime() => _stabilizationTimeMutex
.protect(() => Future.delayed(_defaultStabilizationTime));

Future<void> _processValue(
Processor processor,
T value,
Expand All @@ -26,6 +32,8 @@ class ProcessWhileAvailableTransformer<T, R>
try {
T? currentValue = value;
do {
await _waitStabilizationTime();

final event = await processor(currentValue);
sink.add(event);
// ignore: avoid-redundant-async
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import 'package:simon_ai/core/common/logger.dart';
import 'package:simon_ai/core/interfaces/model_interface.dart';
import 'package:simon_ai/core/model/hand_classifier_model_data.dart';
import 'package:simon_ai/core/model/hand_gestures.dart';
import 'package:simon_ai/core/model/hand_landmarks_result_data.dart';
import 'package:simon_ai/gen/assets.gen.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

class HandCannedGestureClassifier
implements ModelHandler<TensorBufferFloat, HandGesture> {
implements ModelHandler<TensorBufferFloat, CannedGestureData> {
final ModelMetadata model =
(path: Assets.models.cannedGestureClassifier, inputSize: 128);

late Interpreter _interpreter;

@override
Interpreter get interpreter => _interpreter;

Expand Down Expand Up @@ -63,7 +65,7 @@ class HandCannedGestureClassifier
}

@override
Future<HandGesture> performOperations(
Future<CannedGestureData> performOperations(
TensorBufferFloat tensorBufferFloat,
) async {
stopwatch.start();
Expand Down Expand Up @@ -91,11 +93,14 @@ class HandCannedGestureClassifier
interpreter.runForMultipleInputs(inputs, outputs);
}

HandGesture _processGestureResultData() {
CannedGestureData _processGestureResultData() {
final gesturesScore =
handCannedGestureOutputLocations.first.getDoubleList();
final highestScore = gesturesScore.reduce(max);
final indexOfHighestScore = gesturesScore.indexOf(highestScore);
return HandGesture.values[indexOfHighestScore];
return (
gesture: HandGesture.values[indexOfHighestScore],
confidence: highestScore,
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ class HandClassifier
.performOperations((image: input, cropData: cropData));
final gestureVector = await handGestureEmbedderClassifier
.performOperations(handLandmarksResult.tensors);
final gesture =
final gestureResult =
await handCannedGestureClassifier.performOperations(gestureVector);
return Future.value(
(
confidence: handLandmarksResult.confidence,
keyPoints: handLandmarksResult.keyPoints,
gesture: gesture,
gesture: gestureResult.gesture,
gestureConfidence: gestureResult.confidence,
cropData: cropData,
),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class GestureMobileProcessor implements GestureProcessor {
final processedKeyPoints = _processKeypoints(resultData.keyPoints);
return (
confidence: resultData.confidence,
gestureConfidence: resultData.gestureConfidence,
keyPoints: processedKeyPoints,
gesture: resultData.gesture,
cropData: resultData.cropData,
Expand Down
1 change: 1 addition & 0 deletions lib/core/model/hand_classifier_result_data.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ typedef HandClassifierResultData = ({
List<double> keyPoints,
HandGesture gesture,
HandDetectorResultData cropData,
double gestureConfidence,
});
1 change: 1 addition & 0 deletions lib/core/model/hand_gesture_with_position.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import 'package:simon_ai/core/model/hand_gestures.dart';

typedef HandGestureWithPosition = ({
HandGesture gesture,
double gestureConfidence,
Coordinates gesturePosition,
HandDetectorResultData boundingBox,
});
6 changes: 6 additions & 0 deletions lib/core/model/hand_landmarks_result_data.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ typedef HandLandmarksResultData = ({

typedef HandLandmarksData = ({
double confidence,
double gestureConfidence,
List<Coordinates> keyPoints,
HandGesture gesture,
HandDetectorResultData cropData,
});

typedef CannedGestureData = ({
double confidence,
HandGesture gesture,
});
1 change: 1 addition & 0 deletions lib/core/repository/game_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class GameManager {
_gestureStreamController.add(result);
return (
gesture: result.gesture,
gestureConfidence: result.gestureConfidence,
gesturePosition: result.keyPoints.centerCoordinates,
boundingBox: result.cropData,
);
Expand Down
16 changes: 16 additions & 0 deletions lib/ui/game_screen/game_screen_cubit.freezed.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/ui/widgets/camera/camera_mobile_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class _GestureSection extends StatelessWidget {
keypointsData: gestureData ??
(
confidence: 0.0,
gestureConfidence: 0.0,
keyPoints: [],
gesture: HandGesture.unrecognized,
cropData: (x: 0, y: 0, w: 0, h: 0, confidence: 0.0),
Expand Down
1 change: 1 addition & 0 deletions test/widget_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,6 @@ HandGestureWithPosition _gestureToHandGestureWithPosition(
(
boundingBox: (confidence: 0.0, h: 0.0, w: 0.0, x: 0.0, y: 0.0),
gesture: gesture,
gestureConfidence: 0.1,
gesturePosition: (x: 0.0, y: 0.0)
);

0 comments on commit d3b4e74

Please sign in to comment.