Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add AAD auth to azure search writer #2285

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,33 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging
private def prepareDF(df: DataFrame, //scalastyle:ignore method.length
options: Map[String, String] = Map()): DataFrame = {
val applicableOptions = Set(
"subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols"
"subscriptionKey", "AADToken", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols", "url"
)

options.keys.foreach(k =>
assert(applicableOptions(k), s"$k not an applicable option ${applicableOptions.toList}"))

val subscriptionKey = options("subscriptionKey")
val subscriptionKey = options.get("subscriptionKey")
val aadToken = options.get("AADToken")

val actionCol = options.getOrElse("actionCol", "@search.action")

val serviceName = options("serviceName")
val indexJsonOpt = options.get("indexJson")
val apiVersion = options.getOrElse("apiVersion", AzureSearchAPIConstants.DefaultAPIVersion)

val batchSize = options.getOrElse("batchSize", "100").toInt
val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
val filterNulls = options.getOrElse("filterNulls", "false").toBoolean
val vectorColsInfo = options.get("vectorCols")


assert(!(subscriptionKey.isEmpty && aadToken.isEmpty),
"No auth found: Please set either subscriptionKey or AADToken")
assert(!(subscriptionKey.isDefined && aadToken.isDefined),
"Both subscriptionKey and AADToken is set. Please set either subscriptionKey or AADToken")

val keyCol = options.get("keyCol")
val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
if (indexJsonOpt.isDefined) {
Expand All @@ -260,12 +270,13 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging
}
}

val (indexJson, preppedDF) = if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) {
val existingIndices = getExisting(subscriptionKey, aadToken, serviceName, apiVersion)
val (indexJson, preppedDF) = if (existingIndices.contains(indexName)) {
if (indexJsonOpt.isDefined) {
println(f"indexJsonOpt is specified, however an index for $indexName already exists," +
f"we will use the index definition obtained from the existing index instead")
}
val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName)
val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, aadToken, serviceName, indexName)
val vectorColNameTypeTuple = getVectorColConf(existingIndexJson)
(existingIndexJson, makeColsCompatible(vectorColNameTypeTuple, df))
} else if (indexJsonOpt.isDefined) {
Expand All @@ -283,7 +294,7 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging
// Throws an exception if any nested field is a vector in the schema
parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)
SearchIndex.createIfNoneExists(subscriptionKey, aadToken, serviceName, indexJson, apiVersion)

logInfo("checking schema parity")
checkSchemaParity(preppedDF.schema, indexJson, actionCol)
Expand All @@ -297,15 +308,17 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging
preppedDF
}

new AddDocuments()
.setSubscriptionKey(subscriptionKey)
val ad = new AddDocuments()
.setServiceName(serviceName)
.setIndexName(indexName)
.setActionCol(actionCol)
.setBatchSize(batchSize)
.setOutputCol("out")
.setErrorCol("error")
.transform(df1)
val ad1 = subscriptionKey.map(key => ad.setSubscriptionKey(key)).getOrElse(ad)
val ad2 = aadToken.map(token => ad1.setAADToken(token)).getOrElse(ad1)

ad2.transform(df1)
.withColumn("error",
UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import org.apache.http.entity.StringEntity
import org.apache.log4j.{LogManager, Logger}
import spray.json._

import java.util.UUID
import scala.util.{Failure, Success, Try}

/** Constants shared by the Azure Search helpers; the rest of this file imports them wholesale. */
object AzureSearchAPIConstants {
  /** API version used when a caller does not pass `apiVersion` explicitly. */
  val DefaultAPIVersion: String = "2023-07-01-Preview"

  // NOTE(review): presumably the name given to the vector search configuration in generated
  // index definitions -- usage is not visible here, confirm against the writer.
  val VectorConfigName: String = "vectorConfig"

  // NOTE(review): presumably the algorithm kind for that vector configuration -- confirm.
  val VectorSearchAlgorithm: String = "hnsw"

  /** Name of the HTTP header that carries the "Bearer <token>" value for AAD auth. */
  val AADHeaderName: String = "Authorization"
}
import com.microsoft.azure.synapse.ml.services.search.AzureSearchAPIConstants._

Expand All @@ -27,34 +29,44 @@ trait IndexParser {
}

trait IndexLister {
def getExisting(key: String,

def getExisting(key: Option[String],
AADToken: Option[String],
serviceName: String,
apiVersion: String = DefaultAPIVersion): Seq[String] = {
val indexListRequest = new HttpGet(
val req = new HttpGet(
s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion&$$select=name"
)
indexListRequest.setHeader("api-key", key)
val indexListResponse = safeSend(indexListRequest, close = false)
val indexList = IOUtils.toString(indexListResponse.getEntity.getContent, "utf-8").parseJson.convertTo[IndexList]
indexListResponse.close()
key.foreach(k => req.setHeader("api-key", k))
AADToken.foreach { token =>
req.setHeader(AADHeaderName, "Bearer " + token)
}

val response = safeSend(req, close = false)
val indexList = IOUtils.toString(response.getEntity.getContent, "utf-8").parseJson.convertTo[IndexList]
response.close()
for (i <- indexList.value.seq) yield i.name
}
}

trait IndexJsonGetter extends IndexLister {
def getIndexJsonFromExistingIndex(key: String,
def getIndexJsonFromExistingIndex(key: Option[String],
AADToken: Option[String],
serviceName: String,
indexName: String,
apiVersion: String = DefaultAPIVersion): String = {
val existingIndexNames = getExisting(key, serviceName, apiVersion)
val existingIndexNames = getExisting(key, AADToken, serviceName, apiVersion)
assert(existingIndexNames.contains(indexName), s"Cannot find an existing index name with $indexName")

val indexJsonRequest = new HttpGet(
val req = new HttpGet(
s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion"
)
indexJsonRequest.setHeader("api-key", key)
indexJsonRequest.setHeader("Content-Type", "application/json")
val indexJsonResponse = safeSend(indexJsonRequest, close = false)
key.foreach(k => req.setHeader("api-key", k))
AADToken.foreach { token =>
req.setHeader(AADHeaderName, "Bearer " + token)
}
req.setHeader("Content-Type", "application/json")
val indexJsonResponse = safeSend(req, close = false)
val indexJson = IOUtils.toString(indexJsonResponse.getEntity.getContent, "utf-8")
indexJsonResponse.close()
indexJson
Expand All @@ -67,20 +79,24 @@ object SearchIndex extends IndexParser with IndexLister {

val Logger: Logger = LogManager.getRootLogger

def createIfNoneExists(key: String,
def createIfNoneExists(key: Option[String],
AADToken: Option[String],
serviceName: String,
indexJson: String,
apiVersion: String = DefaultAPIVersion): Unit = {
val indexName = parseIndexJson(indexJson).name.get

val existingIndexNames = getExisting(key, serviceName, apiVersion)
val existingIndexNames = getExisting(key, AADToken, serviceName, apiVersion)

if (!existingIndexNames.contains(indexName)) {
val createRequest = new HttpPost(s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion")
createRequest.setHeader("Content-Type", "application/json")
createRequest.setHeader("api-key", key)
createRequest.setEntity(prepareEntity(indexJson))
val response = safeSend(createRequest)
val req = new HttpPost(s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion")
req.setHeader("Content-Type", "application/json")
key.foreach(k => req.setHeader("api-key", k))
AADToken.foreach { token =>
req.setHeader(AADHeaderName, "Bearer " + token)
}
req.setEntity(prepareEntity(indexJson))
val response = safeSend(req)
val status = response.getStatusLine.getStatusCode
assert(status == 201)
()
Expand Down Expand Up @@ -133,7 +149,7 @@ object SearchIndex extends IndexParser with IndexLister {
}

private def validType(t: String, fields: Option[Seq[IndexField]]): Try[String] = {
val tdt = Try(AzureSearchWriter.edmTypeToSparkType(t,fields))
val tdt = Try(AzureSearchWriter.edmTypeToSparkType(t, fields))
tdt.map(_ => t)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.microsoft.azure.synapse.ml.services.search

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.services._
import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIEmbedding}
import com.microsoft.azure.synapse.ml.services.vision.AnalyzeImage
Expand Down Expand Up @@ -132,9 +133,13 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette

override def beforeAll(): Unit = {
println("WARNING CREATING SEARCH ENGINE!")
SearchIndex.createIfNoneExists(azureSearchKey,
SearchIndex.createIfNoneExists(
Some(azureSearchKey),
None,
testServiceName,
createSimpleIndexJson(indexName))
val aadToken = getAccessToken("https://search.azure.com")
println(s"Triggering token creation early ${aadToken.length}")
}

def deleteIndex(indexName: String): Int = {
Expand All @@ -148,7 +153,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
override def afterAll(): Unit = {
//TODO make this existing search indices when multiple builds are allowed
println("Cleaning up services")
val successfulCleanup = getExisting(azureSearchKey, testServiceName)
val successfulCleanup = getExisting(Some(azureSearchKey), None, testServiceName)
.intersect(createdIndexes).map { n =>
deleteIndex(n)
}.forall(_ == 204)
Expand All @@ -163,7 +168,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette

val twoDaysAgo = LocalDateTime.now().minusDays(2)
val endingDatePattern: Regex = "^.*-(\\d{17})$".r
val e = getExisting(azureSearchKey, testServiceName)
val e = getExisting(Some(azureSearchKey), None, testServiceName)
e.foreach { name =>
name match {
case endingDatePattern(dateString) =>
Expand Down Expand Up @@ -235,7 +240,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
}

ignore("clean up all search indexes") {
getExisting(azureSearchKey, testServiceName)
getExisting(Some(azureSearchKey), None, testServiceName)
.foreach { n =>
val deleteRequest = new HttpDelete(
s"https://$testServiceName.search.windows.net/indexes/$n?api-version=2017-11-11")
Expand Down Expand Up @@ -266,7 +271,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
dependsOn(2, writeHelper(dfA, in2, isVectorField=false))

dependsOn(2, retryWithBackoff({
if (getExisting(azureSearchKey, testServiceName).contains(in2)) {
if (getExisting(Some(azureSearchKey), None, testServiceName).contains(in2)) {
writeHelper(dfB, in2, isVectorField=false)
} else {
throw new RuntimeException("No existing service found")
Expand Down Expand Up @@ -315,7 +320,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
""".stripMargin

assertThrows[IllegalArgumentException] {
SearchIndex.createIfNoneExists(azureSearchKey, testServiceName, badJson)
SearchIndex.createIfNoneExists(Some(azureSearchKey), None, testServiceName, badJson)
}
}

Expand Down Expand Up @@ -370,7 +375,9 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
("upload", "1", "file1", Array("p4", null, "p6")))
.toDF("searchAction", "id", "fileName", "phrases")

SearchIndex.createIfNoneExists(azureSearchKey,
SearchIndex.createIfNoneExists(
Some(azureSearchKey),
None,
testServiceName,
phraseIndex)

Expand Down Expand Up @@ -404,6 +411,27 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
retryWithBackoff(assertSize(in, 2))
}

test("Use AAD") {
  // Writes two documents authenticating with an AAD token only (no subscriptionKey option),
  // then waits until the freshly named index reports both of them.
  val indexName = generateIndexName()

  // The second row deliberately carries a null phrase; the write below sets filterNulls.
  val rows = Seq(
    ("upload", "0", "file0", Array("p1", "p2", "p3")),
    ("upload", "1", "file1", Array("p4", null, "p6")))
  val phrases = rows.toDF("searchAction", "id", "fileName", "phrases")

  val token = getAccessToken("https://search.azure.com")

  val writerOptions = Map(
    "AADToken" -> token,
    "actionCol" -> "searchAction",
    "serviceName" -> testServiceName,
    "filterNulls" -> "true",
    "indexName" -> indexName,
    "keyCol" -> "id"
  )
  AzureSearchWriter.write(phrases, writerOptions)

  retryWithBackoff(assertSize(indexName, 2))
}

test("pipeline with analyze image") {
val in = generateIndexName()

Expand Down Expand Up @@ -449,7 +477,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
writeHelper(dfA, in2, isVectorField=true)

retryWithBackoff({
if (getExisting(azureSearchKey, testServiceName).contains(in2)) {
if (getExisting(Some(azureSearchKey), None, testServiceName).contains(in2)) {
writeHelper(dfB, in2, isVectorField=true)
} else {
throw new RuntimeException("No existing service found")
Expand All @@ -459,7 +487,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
retryWithBackoff(assertSize(in1, 4))
retryWithBackoff(assertSize(in2, 10))

val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in1))
val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in1))
// assert if vectorCol is a vector field
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol").get.vectorSearchConfiguration.nonEmpty)
}
Expand Down Expand Up @@ -496,7 +524,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
retryWithBackoff(assertSize(in, 2))

// assert if vectorCols are a vector field
val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in))
val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in))
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol1").get.vectorSearchConfiguration.nonEmpty)
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol2").get.vectorSearchConfiguration.nonEmpty)
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol3").get.vectorSearchConfiguration.nonEmpty)
Expand Down Expand Up @@ -578,7 +606,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
""".stripMargin

assertThrows[IllegalArgumentException] {
SearchIndex.createIfNoneExists(azureSearchKey, testServiceName, badJson)
SearchIndex.createIfNoneExists(Some(azureSearchKey), None, testServiceName, badJson)
}
}

Expand Down Expand Up @@ -661,7 +689,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette
))

retryWithBackoff(assertSize(in, 2))
val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in))
val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in))
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorContent").get.vectorSearchConfiguration.nonEmpty)
}
}
Loading