Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 authored Jul 22, 2024
2 parents 3888968 + d453ba2 commit db5e950
Show file tree
Hide file tree
Showing 50 changed files with 1,087 additions and 841 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ trait HasCustomAuthHeader extends HasServiceParams {
}
}

/** Mixin that declares a `customHeaders` service parameter: a map from header name to header value.
  *
  * Offers a Scala-map setter, a `java.util.HashMap` setter intended for PySpark callers, and a getter.
  */
trait HasCustomHeaders extends HasServiceParams {

  /** Service parameter that carries the custom header key-value pairs. */
  val customHeaders: ServiceParam[Map[String, String]] =
    new ServiceParam[Map[String, String]](
      this, "customHeaders", "Map of Custom Header Key-Value Tuples.")

  /** Hands `v` to `setScalarParam` for [[customHeaders]]; the result is typed as this instance. */
  def setCustomHeaders(v: Map[String, String]): this.type = setScalarParam(customHeaders, v)

  // PySpark compatibility: Py4J only natively converts a Python dict into a Java HashMap,
  // so this overload accepts that type and forwards an immutable Scala copy to the Map setter.
  def setCustomHeaders(v: java.util.HashMap[String, String]): this.type = {
    val scalaHeaders: Map[String, String] = v.asScala.toMap
    setCustomHeaders(scalaHeaders)
  }

  /** Reads [[customHeaders]] back through `getScalarParam`. */
  def getCustomHeaders: Map[String, String] = getScalarParam(customHeaders)
}

trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
def setCustomServiceName(v: String): this.type = {
setUrl(s"https://$v.cognitiveservices.azure.com/" + urlPath.stripPrefix("/"))
Expand Down Expand Up @@ -256,7 +275,15 @@ object URLEncodingUtils {
}

trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with SynapseMLLogging {
with HasCustomHeaders with SynapseMLLogging {

/** String parameter holding a custom URL root for the service (see the parameter's doc text). */
val customUrlRoot: Param[String] = new Param[String](
  this,
  "customUrlRoot",
  "The custom URL root for the service. " +
    "This will not append OpenAI specific model path completions (i.e. /chat/completions) to the URL.")

/** Resolves [[customUrlRoot]] with `$` and returns the resulting string. */
def getCustomUrlRoot: String = {
  $(customUrlRoot)
}

/** Assigns `v` to [[customUrlRoot]] via `set`; the result is typed as this instance. */
def setCustomUrlRoot(v: String): this.type = {
  set(customUrlRoot, v)
}

protected def paramNameToPayloadName(p: Param[_]): String = p match {
case p: ServiceParam[_] => p.payloadName
Expand All @@ -281,7 +308,11 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
} else {
""
}
prepareUrlRoot(row) + appended
if (get(customUrlRoot).nonEmpty) {
$(customUrlRoot)
} else {
prepareUrlRoot(row) + appended
}
}
}

Expand All @@ -296,20 +327,25 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
protected def contentType: Row => String = { _ => "application/json" }

protected def getCustomAuthHeader(row: Row): Option[String] = {
val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomHeader .isEmpty && PlatformDetails.runningOnFabric()) {
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
logInfo("Using Default AAD Token On Fabric")
Option(TokenLibrary.getAuthHeader)
} else {
providedCustomHeader
providedCustomAuthHeader
}
}

/** Looks up the `customHeaders` parameter for one row by delegating to `getValueOpt`.
  *
  * @param row the row forwarded to `getValueOpt`
  * @return the optional header map produced by `getValueOpt`
  */
protected def getCustomHeaders(row: Row): Option[Map[String, String]] =
  getValueOpt(row, customHeaders)

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
contentType: String = "",
customAuthHeader: Option[String] = None): Unit = {
customAuthHeader: Option[String] = None,
customHeaders: Option[Map[String, String]] = None): Unit = {

if (subscriptionKey.nonEmpty) {
req.setHeader(subscriptionKeyHeaderName, subscriptionKey.get)
Expand All @@ -326,6 +362,13 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
req.setHeader("x-ms-workload-resource-moniker", UUID.randomUUID().toString)
})
}
if (customHeaders.nonEmpty) {
customHeaders.foreach(m => {
m.foreach {
case (headerName, headerValue) => req.setHeader(headerName, headerValue)
}
})
}
if (contentType != "") req.setHeader("Content-Type", contentType)
}

Expand All @@ -342,7 +385,8 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row))
getCustomAuthHeader(row),
getCustomHeaders(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import org.apache.spark.sql.types.{DataType, StringType}
import spray.json.DefaultJsonProtocol._
import spray.json._

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
abstract class FormRecognizerBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasImageInput with HasSetLocation with HasSetLinkedService {
Expand Down Expand Up @@ -99,6 +101,8 @@ trait HasLocale extends HasServiceParams {

}

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object FormsFlatteners {

import FormsJsonProtocol._
Expand Down Expand Up @@ -183,8 +187,12 @@ object FormsFlatteners {
}
}

/** Companion object of the deprecated [[AnalyzeLayout]] class; it carries the same deprecation
  * notice and takes its members from `ComplexParamsReadable[AnalyzeLayout]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeLayout extends ComplexParamsReadable[AnalyzeLayout]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeLayout(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasPages {
logClass(FeatureNames.AiServices.Form)
Expand Down Expand Up @@ -216,8 +224,12 @@ class AnalyzeLayout(override val uid: String) extends FormRecognizerBase(uid)

}

/** Companion object of the deprecated [[AnalyzeReceipts]] class; it carries the same deprecation
  * notice and takes its members from `ComplexParamsReadable[AnalyzeReceipts]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeReceipts extends ComplexParamsReadable[AnalyzeReceipts]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeReceipts(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasPages with HasTextDetails with HasLocale {
logClass(FeatureNames.AiServices.Form)
Expand All @@ -230,8 +242,12 @@ class AnalyzeReceipts(override val uid: String) extends FormRecognizerBase(uid)

}

/** Companion object of the deprecated [[AnalyzeBusinessCards]] class; it carries the same
  * deprecation notice and takes its members from `ComplexParamsReadable[AnalyzeBusinessCards]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeBusinessCards extends ComplexParamsReadable[AnalyzeBusinessCards]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeBusinessCards(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasPages with HasTextDetails with HasLocale {
logClass(FeatureNames.AiServices.Form)
Expand All @@ -244,8 +260,12 @@ class AnalyzeBusinessCards(override val uid: String) extends FormRecognizerBase(

}

/** Companion object of the deprecated [[AnalyzeInvoices]] class; it carries the same deprecation
  * notice and takes its members from `ComplexParamsReadable[AnalyzeInvoices]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeInvoices extends ComplexParamsReadable[AnalyzeInvoices]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeInvoices(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasPages with HasTextDetails with HasLocale {
logClass(FeatureNames.AiServices.Form)
Expand All @@ -258,8 +278,12 @@ class AnalyzeInvoices(override val uid: String) extends FormRecognizerBase(uid)

}

/** Companion object of the deprecated [[AnalyzeIDDocuments]] class; it carries the same
  * deprecation notice and takes its members from `ComplexParamsReadable[AnalyzeIDDocuments]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeIDDocuments extends ComplexParamsReadable[AnalyzeIDDocuments]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeIDDocuments(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasPages with HasTextDetails {
logClass(FeatureNames.AiServices.Form)
Expand All @@ -272,8 +296,12 @@ class AnalyzeIDDocuments(override val uid: String) extends FormRecognizerBase(ui

}

/** Companion object of the deprecated [[ListCustomModels]] class; it carries the same deprecation
  * notice and takes its members from `ComplexParamsReadable[ListCustomModels]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object ListCustomModels extends ComplexParamsReadable[ListCustomModels]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class ListCustomModels(override val uid: String) extends CognitiveServicesBase(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser
with HasSetLocation with HasSetLinkedService with SynapseMLLogging {
Expand All @@ -297,8 +325,12 @@ class ListCustomModels(override val uid: String) extends CognitiveServicesBase(u
override protected def responseDataType: DataType = ListCustomModelsResponse.schema
}

/** Companion object of the deprecated [[GetCustomModel]] class; it carries the same deprecation
  * notice and takes its members from `ComplexParamsReadable[GetCustomModel]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object GetCustomModel extends ComplexParamsReadable[GetCustomModel]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser
with HasSetLocation with HasSetLinkedService with SynapseMLLogging with HasModelID {
Expand Down Expand Up @@ -326,8 +358,12 @@ class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid
override protected def responseDataType: DataType = GetCustomModelResponse.schema
}

/** Companion object of the deprecated [[AnalyzeCustomModel]] class; it carries the same
  * deprecation notice and takes its members from `ComplexParamsReadable[AnalyzeCustomModel]`.
  */
@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
object AnalyzeCustomModel extends ComplexParamsReadable[AnalyzeCustomModel]

@deprecated("The Form Recognition v2.1 API is deprecated please use " +
"com.microsoft.azure.synapse.ml.services.form.AnalyzeDocument", "v1.0.4")
class AnalyzeCustomModel(override val uid: String) extends FormRecognizerBase(uid)
with SynapseMLLogging with HasTextDetails with HasModelID {
logClass(FeatureNames.AiServices.Form)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoH
with HasImageInput with HasSetLocation with SynapseMLLogging with HasSetLinkedService {
logClass(FeatureNames.AiServices.Anomaly)

setDefault(apiVersion -> Left("2022-08-31"))
setDefault(apiVersion -> Left("2023-07-31"))

def this() = this(Identifiable.randomUID("AnalyzeDocument"))

Expand All @@ -60,6 +60,30 @@ class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoH

def getStringIndexTypeCol: String = getVectorParam(stringIndexType)


/** URL service parameter listing the optional analysis features to request.
  *
  * A scalar (`Left`) value is accepted only when every entry is one of the six feature names
  * enumerated below; a `Right` value is always accepted by this check.
  */
val features: ServiceParam[Seq[String]] = new ServiceParam[Seq[String]](
  this,
  "features",
  "List of optional analysis features. (barcodes,formulas,keyValuePairs,languages,ocrHighResolution,styleFont)",
  {
    case Left(requested) =>
      // Name the allowed values once, then require every requested entry to be among them.
      val supported = Set(
        "barcodes",
        "formulas",
        "keyValuePairs",
        "languages",
        "ocrHighResolution",
        "styleFont")
      requested.forall(supported.contains)
    case Right(_) => true
  },
  isURLParam = true)

/** Passes `v` to `setScalarParam` for [[features]]. */
def setFeatures(v: Seq[String]): this.type = {
  setScalarParam(features, v)
}

/** Passes the column name `v` to `setVectorParam` for [[features]]. */
def setFeaturesCol(v: String): this.type = {
  setVectorParam(features, v)
}

/** Reads [[features]] through `getScalarParam`. */
def getFeatures: Seq[String] = {
  getScalarParam(features)
}

/** Reads [[features]] through `getVectorParam`. */
def getFeaturesCol: String = {
  getVectorParam(features)
}


override protected def responseDataType: DataType = AnalyzeDocumentResponse.schema

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ case class PageResultV3(pageNumber: Int,
spans: Seq[FormSpan],
words: Option[Seq[FormWord]],
selectionMarks: Option[Seq[FormSelectionMark]],
lines: Option[Seq[FormLine]])
lines: Option[Seq[FormLine]],
barcodes: Option[Seq[FormBarcode]])

case class DocumentParagraph(role: Option[String],
content: String,
Expand All @@ -50,6 +51,12 @@ case class FormSelectionMark(state: String, polygon: Option[Seq[Double]], confid

case class FormLine(content: String, polygon: Option[Seq[Double]], spans: Option[Seq[FormSpan]])

/** One barcode entry; every field is optional.
  *
  * NOTE(review): the field meanings are presumed to mirror the service's barcode JSON
  * (confidence score, barcode kind, bounding polygon, text span, decoded value) — confirm
  * against the service API reference before relying on them.
  */
case class FormBarcode(confidence: Option[Double],
kind: Option[String],
polygon: Option[Seq[Double]],
span: Option[FormSpan],
value: Option[String])

case class TableResultV3(rowCount: Int,
columnCount: Int,
boundingRegions: Option[Seq[BoundingRegion]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -40,10 +41,20 @@ trait HasPromptInputs extends HasServiceParams {

}

/** Mixin that declares the `messagesCol` parameter together with its getter and setter. */
trait HasMessagesInput extends Params {

  /** Name of the column holding the chat messages; the expected column type is given in the doc text. */
  val messagesCol: Param[String] = new Param[String](
    this,
    "messagesCol",
    "The column messages to generate chat completions for," +
      " in the chat format. This column should have type Array(Struct(role: String, content: String)).")

  /** Resolves [[messagesCol]] with `$` and returns the column name. */
  def getMessagesCol: String = {
    $(messagesCol)
  }

  /** Assigns `v` to [[messagesCol]] via `set`; the result is typed as this instance. */
  def setMessagesCol(v: String): this.type = {
    set(messagesCol, v)
  }
}

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = true)
this, "deploymentName", "The name of the deployment", isRequired = false)

def getDeploymentName: String = getScalarParam(deploymentName)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,10 @@ import scala.language.existentials
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)

def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))

def urlPath: String = ""
Expand Down
Loading

0 comments on commit db5e950

Please sign in to comment.