forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-21866][ML][PYSPARK] Adding spark image reader
## What changes were proposed in this pull request? Adding spark image reader, an implementation of schema for representing images in spark DataFrames The code is taken from the spark package located here: (https://github.com/Microsoft/spark-images) Please see the JIRA for more information (https://issues.apache.org/jira/browse/SPARK-21866) Please see mailing list for SPIP vote and approval information: (http://apache-spark-developers-list.1001551.n3.nabble.com/VOTE-SPIP-SPARK-21866-Image-support-in-Apache-Spark-td22510.html) # Background and motivation As Apache Spark is being used more and more in the industry, some new use cases are emerging for different data formats beyond the traditional SQL types or the numerical types (vectors and matrices). Deep Learning applications commonly deal with image processing. A number of projects add some Deep Learning capabilities to Spark (see list below), but they struggle to communicate with each other or with MLlib pipelines because there is no standard way to represent an image in Spark DataFrames. We propose to federate efforts for representing images in Spark by defining a representation that caters to the most common needs of users and library developers. This SPIP proposes a specification to represent images in Spark DataFrames and Datasets (based on existing industrial standards), and an interface for loading sources of images. It is not meant to be a full-fledged image processing library, but rather the core description that other libraries and users can rely on. Several packages already offer various processing facilities for transforming images or doing more complex operations, and each has various design tradeoffs that make them better as standalone solutions. This project is a joint collaboration between Microsoft and Databricks, which have been testing this design in two open source packages: MMLSpark and Deep Learning Pipelines. The proposed image format is an in-memory, decompressed representation that targets low-level applications. It is significantly more liberal in memory usage than compressed image representations such as JPEG, PNG, etc., but it allows easy communication with popular image processing libraries and has no decoding overhead. ## How was this patch tested? Unit tests in scala ImageSchemaSuite, unit tests in python Author: Ilya Matiach <[email protected]> Author: hyukjinkwon <[email protected]> Closes apache#19439 from imatiach-msft/ilmat/spark-images.
- Loading branch information
1 parent
0605ad7
commit 1edb317
Showing
16 changed files
with
721 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
not an image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved: | ||
https://creativecommons.org/share-your-work/public-domain/cc0/ | ||
The images are taken from: | ||
https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q== | ||
https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA== | ||
https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ== | ||
https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw== | ||
|
||
The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from: | ||
https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw== | ||
|
||
The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from: | ||
https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.image | ||
|
||
import scala.language.existentials | ||
import scala.util.Random | ||
|
||
import org.apache.commons.io.FilenameUtils | ||
import org.apache.hadoop.conf.{Configuration, Configured} | ||
import org.apache.hadoop.fs.{Path, PathFilter} | ||
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat | ||
|
||
import org.apache.spark.sql.SparkSession | ||
|
||
private object RecursiveFlag { | ||
/** | ||
* Sets the spark recursive flag and then restores it. | ||
* | ||
* @param value Value to set | ||
* @param spark Existing spark session | ||
* @param f The function to evaluate after setting the flag | ||
* @return Returns the evaluation result T of the function | ||
*/ | ||
def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = { | ||
val flagName = FileInputFormat.INPUT_DIR_RECURSIVE | ||
val hadoopConf = spark.sparkContext.hadoopConfiguration | ||
val old = Option(hadoopConf.get(flagName)) | ||
hadoopConf.set(flagName, value.toString) | ||
try f finally { | ||
old match { | ||
case Some(v) => hadoopConf.set(flagName, v) | ||
case None => hadoopConf.unset(flagName) | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Filter that allows loading a fraction of HDFS files. | ||
*/ | ||
private class SamplePathFilter extends Configured with PathFilter { | ||
val random = new Random() | ||
|
||
// Ratio of files to be read from disk | ||
var sampleRatio: Double = 1 | ||
|
||
override def setConf(conf: Configuration): Unit = { | ||
if (conf != null) { | ||
sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1) | ||
val seed = conf.getLong(SamplePathFilter.seedParam, 0) | ||
random.setSeed(seed) | ||
} | ||
} | ||
|
||
override def accept(path: Path): Boolean = { | ||
// Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead | ||
!SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio | ||
} | ||
} | ||
|
||
private object SamplePathFilter { | ||
val ratioParam = "sampleRatio" | ||
val seedParam = "seed" | ||
|
||
def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != "" | ||
|
||
/** | ||
* Sets the HDFS PathFilter flag and then restores it. | ||
* Only applies the filter if sampleRatio is less than 1. | ||
* | ||
* @param sampleRatio Fraction of the files that the filter picks | ||
* @param spark Existing Spark session | ||
* @param seed Random number seed | ||
* @param f The function to evaluate after setting the flag | ||
* @return Returns the evaluation result T of the function | ||
*/ | ||
def withPathFilter[T]( | ||
sampleRatio: Double, | ||
spark: SparkSession, | ||
seed: Long)(f: => T): T = { | ||
val sampleImages = sampleRatio < 1 | ||
if (sampleImages) { | ||
val flagName = FileInputFormat.PATHFILTER_CLASS | ||
val hadoopConf = spark.sparkContext.hadoopConfiguration | ||
val old = Option(hadoopConf.getClass(flagName, null)) | ||
hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) | ||
hadoopConf.setLong(SamplePathFilter.seedParam, seed) | ||
hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter]) | ||
try f finally { | ||
hadoopConf.unset(SamplePathFilter.ratioParam) | ||
hadoopConf.unset(SamplePathFilter.seedParam) | ||
old match { | ||
case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter]) | ||
case None => hadoopConf.unset(flagName) | ||
} | ||
} | ||
} else { | ||
f | ||
} | ||
} | ||
} |
257 changes: 257 additions & 0 deletions
257
mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.image | ||
|
||
import java.awt.Color | ||
import java.awt.color.ColorSpace | ||
import java.io.ByteArrayInputStream | ||
import javax.imageio.ImageIO | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import org.apache.spark.annotation.{Experimental, Since} | ||
import org.apache.spark.input.PortableDataStream | ||
import org.apache.spark.sql.{DataFrame, Row, SparkSession} | ||
import org.apache.spark.sql.types._ | ||
|
||
/** | ||
* :: Experimental :: | ||
* Defines the image schema and methods to read and manipulate images. | ||
*/ | ||
@Experimental | ||
@Since("2.3.0") | ||
object ImageSchema { | ||
|
||
val undefinedImageType = "Undefined" | ||
|
||
/** | ||
* (Scala-specific) OpenCV type mapping supported | ||
*/ | ||
val ocvTypes: Map[String, Int] = Map( | ||
undefinedImageType -> -1, | ||
"CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24 | ||
) | ||
|
||
/** | ||
* (Java-specific) OpenCV type mapping supported | ||
*/ | ||
val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava | ||
|
||
/** | ||
* Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte]) | ||
*/ | ||
val columnSchema = StructType( | ||
StructField("origin", StringType, true) :: | ||
StructField("height", IntegerType, false) :: | ||
StructField("width", IntegerType, false) :: | ||
StructField("nChannels", IntegerType, false) :: | ||
// OpenCV-compatible type: CV_8UC3 in most cases | ||
StructField("mode", IntegerType, false) :: | ||
// Bytes in OpenCV-compatible order: row-wise BGR in most cases | ||
StructField("data", BinaryType, false) :: Nil) | ||
|
||
val imageFields: Array[String] = columnSchema.fieldNames | ||
|
||
/** | ||
* DataFrame with a single column of images named "image" (nullable) | ||
*/ | ||
val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil) | ||
|
||
/** | ||
* Gets the origin of the image | ||
* | ||
* @return The origin of the image | ||
*/ | ||
def getOrigin(row: Row): String = row.getString(0) | ||
|
||
/** | ||
* Gets the height of the image | ||
* | ||
* @return The height of the image | ||
*/ | ||
def getHeight(row: Row): Int = row.getInt(1) | ||
|
||
/** | ||
* Gets the width of the image | ||
* | ||
* @return The width of the image | ||
*/ | ||
def getWidth(row: Row): Int = row.getInt(2) | ||
|
||
/** | ||
* Gets the number of channels in the image | ||
* | ||
* @return The number of channels in the image | ||
*/ | ||
def getNChannels(row: Row): Int = row.getInt(3) | ||
|
||
/** | ||
* Gets the OpenCV representation as an int | ||
* | ||
* @return The OpenCV representation as an int | ||
*/ | ||
def getMode(row: Row): Int = row.getInt(4) | ||
|
||
/** | ||
* Gets the image data | ||
* | ||
* @return The image data | ||
*/ | ||
def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5) | ||
|
||
/** | ||
* Default values for the invalid image | ||
* | ||
* @param origin Origin of the invalid image | ||
* @return Row with the default values | ||
*/ | ||
private[spark] def invalidImageRow(origin: String): Row = | ||
Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0))) | ||
|
||
/** | ||
* Convert the compressed image (jpeg, png, etc.) into OpenCV | ||
* representation and store it in DataFrame Row | ||
* | ||
* @param origin Arbitrary string that identifies the image | ||
* @param bytes Image bytes (for example, jpeg) | ||
* @return DataFrame Row or None (if the decompression fails) | ||
*/ | ||
private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = { | ||
|
||
val img = ImageIO.read(new ByteArrayInputStream(bytes)) | ||
|
||
if (img == null) { | ||
None | ||
} else { | ||
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY | ||
val hasAlpha = img.getColorModel.hasAlpha | ||
|
||
val height = img.getHeight | ||
val width = img.getWidth | ||
val (nChannels, mode) = if (isGray) { | ||
(1, ocvTypes("CV_8UC1")) | ||
} else if (hasAlpha) { | ||
(4, ocvTypes("CV_8UC4")) | ||
} else { | ||
(3, ocvTypes("CV_8UC3")) | ||
} | ||
|
||
val imageSize = height * width * nChannels | ||
assert(imageSize < 1e9, "image is too large") | ||
val decoded = Array.ofDim[Byte](imageSize) | ||
|
||
// Grayscale images in Java require special handling to get the correct intensity | ||
if (isGray) { | ||
var offset = 0 | ||
val raster = img.getRaster | ||
for (h <- 0 until height) { | ||
for (w <- 0 until width) { | ||
decoded(offset) = raster.getSample(w, h, 0).toByte | ||
offset += 1 | ||
} | ||
} | ||
} else { | ||
var offset = 0 | ||
for (h <- 0 until height) { | ||
for (w <- 0 until width) { | ||
val color = new Color(img.getRGB(w, h)) | ||
|
||
decoded(offset) = color.getBlue.toByte | ||
decoded(offset + 1) = color.getGreen.toByte | ||
decoded(offset + 2) = color.getRed.toByte | ||
if (nChannels == 4) { | ||
decoded(offset + 3) = color.getAlpha.toByte | ||
} | ||
offset += nChannels | ||
} | ||
} | ||
} | ||
|
||
// the internal "Row" is needed, because the image is a single DataFrame column | ||
Some(Row(Row(origin, height, width, nChannels, mode, decoded))) | ||
} | ||
} | ||
|
||
/** | ||
* Read the directory of images from the local or remote source | ||
* | ||
* @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, | ||
* there may be a race condition where one job overwrites the hadoop configs of another. | ||
* @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but | ||
* potentially non-deterministic. | ||
* | ||
* @param path Path to the image directory | ||
* @return DataFrame with a single column "image" of images; | ||
* see ImageSchema for the details | ||
*/ | ||
def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) | ||
|
||
/** | ||
* Read the directory of images from the local or remote source | ||
* | ||
* @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, | ||
* there may be a race condition where one job overwrites the hadoop configs of another. | ||
* @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but | ||
* potentially non-deterministic. | ||
* | ||
* @param path Path to the image directory | ||
* @param sparkSession Spark Session, if omitted gets or creates the session | ||
* @param recursive Recursive path search flag | ||
* @param numPartitions Number of the DataFrame partitions, | ||
* if omitted uses defaultParallelism instead | ||
* @param dropImageFailures Drop the files that are not valid images from the result | ||
* @param sampleRatio Fraction of the files loaded | ||
* @return DataFrame with a single column "image" of images; | ||
* see ImageSchema for the details | ||
*/ | ||
def readImages( | ||
path: String, | ||
sparkSession: SparkSession, | ||
recursive: Boolean, | ||
numPartitions: Int, | ||
dropImageFailures: Boolean, | ||
sampleRatio: Double, | ||
seed: Long): DataFrame = { | ||
require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1") | ||
|
||
val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate | ||
val partitions = | ||
if (numPartitions > 0) { | ||
numPartitions | ||
} else { | ||
session.sparkContext.defaultParallelism | ||
} | ||
|
||
RecursiveFlag.withRecursiveFlag(recursive, session) { | ||
SamplePathFilter.withPathFilter(sampleRatio, session, seed) { | ||
val binResult = session.sparkContext.binaryFiles(path, partitions) | ||
val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions) | ||
val convert = (origin: String, bytes: PortableDataStream) => | ||
decode(origin, bytes.toArray()) | ||
val images = if (dropImageFailures) { | ||
streams.flatMap { case (origin, bytes) => convert(origin, bytes) } | ||
} else { | ||
streams.map { case (origin, bytes) => | ||
convert(origin, bytes).getOrElse(invalidImageRow(origin)) | ||
} | ||
} | ||
session.createDataFrame(images, imageSchema) | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.