Add Background Blur and Background Replacement processors (aws#463)
richhx authored May 14, 2022
1 parent 95f1cde commit 9096237
Showing 31 changed files with 1,698 additions and 32 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/codecov.yml
@@ -37,6 +37,20 @@ jobs:
          tar -xzf AmazonChimeSDKMedia.tar.gz
          cp -R ./AmazonChimeSDKMedia.framework ./AmazonChimeSDK
      # Download amazon-chime-sdk-machine-learning binary from AWS S3
      - name: Get AmazonChimeSDKMachineLearning from AWS S3
        run: |
          aws configure set aws_access_key_id ${{ secrets.AWS_ACCESS_KEY_ID }} --profile jenkins-automated-test
          aws configure set aws_secret_access_key ${{ secrets.AWS_SECRET_ACCESS_KEY }} --profile jenkins-automated-test
          aws \
            --profile jenkins-automated-test \
            s3api get-object \
            --bucket amazon-chime-sdk-ios-internal \
            --key master/machine-learning/latest/AmazonChimeSDKMachineLearning.tar.gz \
            AmazonChimeSDKMachineLearning.tar.gz
          tar -xzf AmazonChimeSDKMachineLearning.tar.gz
          cp -R ./AmazonChimeSDKMachineLearning.xcframework ./AmazonChimeSDK
      # Execute unit tests
      - name: Build and Run Unit Test
        working-directory: ./AmazonChimeSDK
6 changes: 4 additions & 2 deletions .gitignore
@@ -1,9 +1,11 @@
 **/.DS_Store
 xcuserdata
-AmazonChimeSDKMedia.framework
 AmazonChimeSDK.framework
-AmazonChimeSDKMedia.xcframework
+AmazonChimeSDKMachineLearning.framework
+AmazonChimeSDKMedia.framework
 AmazonChimeSDK.xcframework
+AmazonChimeSDKMachineLearning.xcframework
+AmazonChimeSDKMedia.xcframework
 MockingbirdMocks
 MockingbirdCache
 build
Expand Down
142 changes: 135 additions & 7 deletions AmazonChimeSDK/AmazonChimeSDK.xcodeproj/project.pbxproj

Large diffs are not rendered by default.

@@ -21,6 +21,11 @@
BlueprintName = "AmazonChimeSDKTests"
ReferencedContainer = "container:AmazonChimeSDK.xcodeproj">
</BuildableReference>
<SkippedTests>
<Test
Identifier = "BackgroundFilterTests">
</Test>
</SkippedTests>
</TestableReference>
</Testables>
</TestAction>
16 changes: 16 additions & 0 deletions AmazonChimeSDK/AmazonChimeSDK/AmazonChimeSDK.h
@@ -0,0 +1,16 @@
//
// AmazonChimeSDK.h
// AmazonChimeSDK
//
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//

#ifndef AMAZON_CHIME_SDK_H
#define AMAZON_CHIME_SDK_H

// Umbrella header imports.
#import "TensorFlowSegmentationProcessor.h"
#import "CwtEnum.h"

#endif // AMAZON_CHIME_SDK_H
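
The umbrella imports above are what make the Objective-C segmentation types visible to Swift once the framework is built. A minimal sketch of that visibility (illustrative only; `isAvailable()` is the same runtime check `BackgroundFilterProcessor` performs below):

    import AmazonChimeSDK

    // TensorFlowSegmentationProcessor and the Cwt* types come from the umbrella header,
    // so Swift callers can reach them without any extra bridging.
    let machineLearningIsLinked = TensorFlowSegmentationProcessor.isAvailable()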
@@ -0,0 +1,27 @@
//
// BackgroundFilter.swift
// AmazonChimeSDK
//
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//

import Foundation

/// Enum defining the different background filter options.
@objc public enum BackgroundFilter: Int {
case none
case blur
case replacement

public var description: String {
switch self {
case .none:
return "none"
case .blur:
return "blur"
case .replacement:
return "replacement"
}
}
}
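
A quick usage sketch for this enum (the settings-screen helper below is hypothetical, not part of this commit; `Logger` and `ConsoleLogger` are the SDK's existing logging types):

    import AmazonChimeSDK

    // Hypothetical helper: record which background filter a user selected.
    func logSelectedFilter(_ filter: BackgroundFilter, logger: Logger) {
        // description resolves to "none", "blur", or "replacement".
        logger.info(msg: "Background filter selected: \(filter.description)")
    }

    // Usage: logSelectedFilter(.blur, logger: ConsoleLogger(name: "demo"))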
@@ -0,0 +1,247 @@
//
// BackgroundFilterProcessor.swift
// AmazonChimeSDK
//
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//

import CoreImage
import CoreMedia
import Foundation
import UIKit

/// `BackgroundFilterProcessor` uses a `SegmentationProcessor` to create an alpha mask of the
/// foreground of a frame, blends the mask with the input image to extract the foreground, and
/// renders the result on top of a background image.
public class BackgroundFilterProcessor {
/// Context used for processing and rendering the final output image.
private let context = CIContext(options: [.cacheIntermediates: false])

/// `CVPixelBufferPool` used to store the final output image.
private var bufferPool: CVPixelBufferPool?

/// Used to track buffer pool width.
private var bufferPoolWidth: Int = 0

/// Used to track buffer pool height.
private var bufferPoolHeight: Int = 0

/// A segmentation processor used to predict the foreground of an image.
/// See `SegmentationProcessor` for more details.
private let segmentationProcessor: SegmentationProcessor

/// Custom logger to log any errors or warnings.
let logger: Logger

/// Segmentation processor height.
private var segmentationProcessorHeight = 256

/// Segmentation processor width.
private var segmentationProcessorWidth = 144

/// Static method to check whether BackgroundFilterProcessor can be used. This verifies that the builder
/// has linked the necessary runtime framework (i.e. `AmazonChimeSDKMachineLearning`) to
/// use this class.
///
/// - Returns: true if the class can be used, otherwise false.
public static func isAvailable() -> Bool {
return TensorFlowSegmentationProcessor.isAvailable()
}

/// Public constructor to initialize the processor.
///
/// - Parameters:
/// - logger: Custom logger to log events.
public init(logger: Logger) {
self.logger = logger
if !BackgroundFilterProcessor.isAvailable() {
self.logger.error(msg: "Unable to load TensorFlowSegmentationProcessor. " +
"See `Update Project File` section in README for more information " +
"on how to import `AmazonChimeSDKMachineLearning` framework " +
"and the `selfie_segmentation_landscape.tflite` as a bundle resource " +
"to your project.")
segmentationProcessor = NoopSegmentationProcessor()
} else {
segmentationProcessor = TensorFlowSegmentationProcessor()
}
}

/// Creates the alpha mask [0-255] of the foreground image using `SegmentationProcessor`.
///
/// - Parameters:
/// - inputFrameCG: Input CGImage frame to produce the foreground image.
/// - inputFrameCI: Input CIImage frame to produce the foreground image.
///
/// - Returns: Alpha mask `CIImage` of the foreground, or `nil` on failure.
public func createForegroundAlphaMask(inputFrameCG: CGImage,
inputFrameCI: CIImage) -> CIImage? {
// Verify that the processor is available.
if !BackgroundFilterProcessor.isAvailable() {
return nil
}

// Number of color channels in the input image.
let imageChannels = inputFrameCG.bitsPerPixel / inputFrameCG.bitsPerComponent

// Update the buffer pool dimensions if the new frame does not match the previous frame dimensions.
if bufferPool == nil || inputFrameCG.width != bufferPoolWidth || inputFrameCG.height != bufferPoolHeight {
updateBufferPool(newWidth: inputFrameCG.width, newHeight: inputFrameCG.height)
// Initialize the segmentation processor; this re-runs whenever the frame dimensions change.
let initializeResult: Bool = segmentationProcessor.initialize(height: segmentationProcessorHeight,
width: segmentationProcessorWidth,
channels: imageChannels)
if !initializeResult {
logger.error(msg: "Unable to initialize segmentation processor.")
return nil
}
}

// Check if segmentation model has loaded.
if segmentationProcessor.getModelState() != CwtModelState.LOADED.rawValue {
logger.error(msg: "Segmentation processor failed to start. Unable to perform segmentation.")
return nil
}

// Calculate the scale and aspect ratio factor to downscale.
let downSampleScale = Double(segmentationProcessorHeight) / Double(inputFrameCG.height)
let downSampleAspectRatio: Double = (Double(inputFrameCG.height) / Double(inputFrameCG.width)) /
(Double(segmentationProcessorHeight) / Double(segmentationProcessorWidth))

// Downsample the image.
guard let downSampleImage: CIImage = resizeImage(image: inputFrameCI,
scale: downSampleScale,
aspectRatio: downSampleAspectRatio)
else {
return nil
}

guard let downSampleImageCg = context.createCGImage(downSampleImage,
from: downSampleImage.extent)
else {
return nil
}

// Convert the input CGImage to a UInt8 byte array.
guard var byteArray: [UInt8] = ImageConversionUtils.cgImageToByteArray(cgImage: downSampleImageCg) else {
logger.error(msg: "Error converting CGImage to byte array when creating the foreground mask.")
return nil
}

// Copy the input buffer to the TensorFlow buffer which will be used during predict.
let inputBuffer: UnsafeMutablePointer<UInt8> = segmentationProcessor.getInputBuffer()
inputBuffer.initialize(from: &byteArray, count: byteArray.count)

// Predict the foreground mask.
let predictResult: Bool = segmentationProcessor.predict()
if !predictResult {
logger.error(msg: "Error predicting the foreground mask.")
return nil
}

// Retrieve the foreground mask.
let maskOutputBuffer = segmentationProcessor.getOutputBuffer()

guard let maskImage: CGImage = ImageConversionUtils.byteArrayToCGImage(raw: maskOutputBuffer,
frameWidth: segmentationProcessorWidth,
frameHeight: segmentationProcessorHeight,
bytesPerPixel: imageChannels,
bitsPerComponent: inputFrameCG.bitsPerComponent)
else {
logger.error(msg: "Error creating CGImage of the foreground mask.")
return nil
}

let maskImageCi = CIImage(cgImage: maskImage)

// Calculate the scale and aspect ratio factor to upscale.
let upSampleScale = Double(inputFrameCG.height) / Double(segmentationProcessorHeight)
let upSampleAspectRatio: Double = (Double(segmentationProcessorHeight) / Double(segmentationProcessorWidth)) /
(Double(inputFrameCG.height) / Double(inputFrameCG.width))

// Upsample the mask back to the input frame's original size.
guard let upSampleImage = resizeImage(image: maskImageCi,
scale: upSampleScale,
aspectRatio: upSampleAspectRatio)
else {
return nil
}

return upSampleImage
}

/// Blends the foreground alpha mask with the input image to produce a foreground image, which is
/// then rendered on top of a background image using the `CIBlendWithAlphaMask` CIFilter.
///
/// - Parameters:
/// - inputFrameCI: Input image that is blended with the foreground alpha mask to produce the foreground image.
/// - maskImage: Foreground alpha mask.
/// - backgroundImage: Background image, which can be a blurred or a custom background image.
///
/// - Returns: Blended `CIImage`, or `nil` on failure.
public func blendWithWithAlphaMask(inputFrameCI: CIImage,
maskImage: CIImage,
backgroundImage: CIImage) -> CIImage? {
guard let blendFilter = CIFilter(name: "CIBlendWithAlphaMask") else {
logger.error(msg: "Error creating CIBlendWithAlphaMask CIFilter.")
return nil
}

blendFilter.setValue(backgroundImage, forKey: kCIInputBackgroundImageKey)
blendFilter.setValue(inputFrameCI, forKey: kCIInputImageKey)
blendFilter.setValue(maskImage, forKey: kCIInputMaskImageKey)

// Create the output image.
guard let outputImage = blendFilter.outputImage else {
logger.error(msg: "Error creating the blended output image.")
return nil
}

return outputImage
}

/// Updates the buffer pool if the previous and new frame dimensions don't match.
///
/// - Parameters:
/// - newWidth: New frame width.
/// - newHeight: New frame height.
private func updateBufferPool(newWidth: Int, newHeight: Int) {
var attributes: [NSString: NSObject] = [:]
attributes[kCVPixelBufferPixelFormatTypeKey] = NSNumber(value: Int(kCVPixelFormatType_32BGRA))
attributes[kCVPixelBufferWidthKey] = NSNumber(value: newWidth)
attributes[kCVPixelBufferHeightKey] = NSNumber(value: newHeight)
attributes[kCVPixelBufferIOSurfacePropertiesKey] = [:] as NSObject
CVPixelBufferPoolCreate(nil, nil, attributes as CFDictionary?, &bufferPool)

bufferPoolWidth = newWidth
bufferPoolHeight = newHeight
}

/// - Returns: Buffer pool used to store the final image data.
public func getBufferPool() -> CVPixelBufferPool? {
return bufferPool
}

/// Resizes an image using the `CILanczosScaleTransform` CIFilter.
///
/// - Parameters:
/// - image: Image to scale.
/// - scale: Scaling factor.
/// - aspectRatio: Aspect ratio factor.
///
/// - Returns: Resized `CIImage`, or `nil` on failure.
func resizeImage(image: CIImage, scale: Double, aspectRatio: Double) -> CIImage? {
guard let resizeFilter = CIFilter(name: "CILanczosScaleTransform") else {
logger.error(msg: "Error creating CILanczosScaleTransform CIFilter.")
return nil
}

resizeFilter.setValue(CGFloat(scale), forKey: kCIInputScaleKey)
resizeFilter.setValue(image, forKey: kCIInputImageKey)
resizeFilter.setValue(CGFloat(aspectRatio), forKey: kCIInputAspectRatioKey)

guard let outputResizedImage = resizeFilter.outputImage
else {
logger.error(msg: "Error resizing image.")
return nil
}
return outputResizedImage
}
}
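
Putting the pieces together, a minimal end-to-end sketch of the mask-then-blend pipeline (the Gaussian-blur background, radius, and logger name are stand-ins; the commit's actual blur and replacement processors live in other files of this diff):

    import AmazonChimeSDK
    import CoreImage

    func applyBlurredBackground(to inputCG: CGImage) -> CIImage? {
        // Bail out early when AmazonChimeSDKMachineLearning is not linked.
        guard BackgroundFilterProcessor.isAvailable() else { return nil }

        let processor = BackgroundFilterProcessor(logger: ConsoleLogger(name: "BackgroundFilterDemo"))
        let inputCI = CIImage(cgImage: inputCG)

        // 1. Predict the foreground alpha mask, already upsampled back to frame size.
        guard let mask = processor.createForegroundAlphaMask(inputFrameCG: inputCG,
                                                             inputFrameCI: inputCI) else {
            return nil
        }

        // 2. Build a background by blurring the original frame.
        let background = inputCI
            .applyingFilter("CIGaussianBlur", parameters: [kCIInputRadiusKey: 16.0])
            .cropped(to: inputCI.extent)

        // 3. Composite the foreground over the blurred background.
        return processor.blendWithWithAlphaMask(inputFrameCI: inputCI,
                                                maskImage: mask,
                                                backgroundImage: background)
    }

The resulting `CIImage` would then be rendered into a `CVPixelBuffer` drawn from `getBufferPool()` before being handed back to the video pipeline.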
@@ -0,0 +1,52 @@
//
// CwtEnum.h
// cwt
//
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//

#ifndef CWT_ENUM_H
#define CWT_ENUM_H

#import <Foundation/Foundation.h>

#ifdef __cplusplus
extern "C" {
#endif

// CwtInputModelConfig represents the input model configuration used to
// set up the CWT model. This is the same as TFLiteModel::InputModelConfig.
typedef struct {
int in_height;
int in_width;
int in_channels;

int model_range_min;
int model_range_max;
} CwtInputModelConfig;

#ifdef __cplusplus
}
#endif

// CwtModelState represents the state of the model. This is the same as
// TFLiteModel::ModelState.
typedef NS_ENUM(NSUInteger, CwtModelState) {
EMPTY = 0,
LOADING = 1,
LOADED = 2,

FAILED_TO_INIT_MODEL = 1000,
FAILED_TO_INIT_INTERPRETER,
FAILED_TO_ALLOC_MEMORY,
FAILED_TO_DOWNLOAD_MODEL,
FAILED_TO_PREDICT,
};

// CwtPredictResult represents the result of invoking predict on a model.
typedef NS_ENUM(NSUInteger, CwtPredictResult) {
SUCCESS = 0,
ERROR,
};

#endif // CWT_ENUM_H
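
Since these enums travel through the umbrella header, Swift code compares states by raw value, exactly as `createForegroundAlphaMask` does above. A small sketch (the free-standing helper is illustrative):

    import AmazonChimeSDK

    // Mirrors the guard in BackgroundFilterProcessor: the model is usable only once it
    // reaches CwtModelState.LOADED; values from FAILED_TO_INIT_MODEL up are error states.
    func isModelReady(_ processor: SegmentationProcessor) -> Bool {
        return processor.getModelState() == CwtModelState.LOADED.rawValue
    }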
