Skip to content

Commit

Permalink
chore: check Fabric Tenant (#2175)
Browse files Browse the repository at this point in the history
* check Fabric Tenant

* update check fabric setting

* add missing import

* fix style

* clean up

* clean unused import

* fix typo

* add doc link in error message

* add doc link in error message
  • Loading branch information
JessicaXYWang authored Mar 11, 2024
1 parent b390fd4 commit ce5cc5c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.fabric.OpenAITokenLibrary
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion,
HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._

import scala.language.existentials
Expand Down Expand Up @@ -256,10 +257,21 @@ trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
} else {
providedCustomHeader
}

}
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) {
abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
with HasOpenAISharedParams with OpenAIFabricSetting {
setDefault(timeout -> 360.0)

private def usingDefaultOpenAIEndpoint(): Boolean = {
getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/"
}

override protected def getInternalTransformer(schema: StructType): PipelineModel = {
if (PlatformDetails.runningOnFabric() && usingDefaultOpenAIEndpoint) {
getModelStatus(getDeploymentName)
}
super.getInternalTransformer(schema)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.fabric

import spray.json.{JsValue, JsString}

trait OpenAIFabricSetting extends RESTUtils {

private def getHeaders: Map[String, String] = {
Map(
"Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}",
"Content-Type" -> "application/json"
)
}

def usagePost(url: String, body: String): JsValue = {
usagePost(url, body, getHeaders);
}

def getModelStatus(modelName: String): Boolean = {

val payload =
s"""["${modelName}"]"""

val mlWorkloadEndpointML = FabricClient.MLWorkloadEndpointML
val url = mlWorkloadEndpointML + "cognitive/openai/tenantsetting"
val modelStatus = usagePost(url, payload).asJsObject.fields.get(modelName.toLowerCase).get

// Allowed, Disallowed, DisallowedForCrossGeo, ModelNotFound, InvalidResult
val resultString: String = modelStatus match {
case JsString(value) => value
case _ => throw new RuntimeException("Unexpected result from type conversion " +
"when checking the fabric tenant settings API.")
}

resultString match {
case "Disallowed" => throw new RuntimeException(s"Default OpenAI model ${modelName} is Disallowed, " +
s"please contact your admin if you want to use default fabric LLM model. " +
s"Or you can set your Azure OpenAI credentials.")
case "DisallowedForCrossGeo" => throw new RuntimeException(s"Default OpenAI model ${modelName} is Disallowed " +
s"for Cross Geo, please contact your admin if you want to use default fabric LLM model. " +
s"Or you can set your Azure OpenAI credentials." +
s"Refer to https://learn.microsoft.com/en-us/fabric/data-science/ai-services/ai-services-overview " +
s"for more detials")
case "ModelNotFound" => throw new RuntimeException(s"Default OpenAI model ${modelName} not found, " +
s"please check your deployment name. " +
s"Refer to https://learn.microsoft.com/en-us/fabric/data-science/ai-services/ai-services-overview " +
s"for the models available.")
case "InvalidResult" => throw new RuntimeException("Cannot get tenant admin setting status correctly")
case "Allowed" => true
case _ => throw new RuntimeException("Unexpected result from checking the Fabric tenant settings API.")
}
}

}

0 comments on commit ce5cc5c

Please sign in to comment.