Skip to content

Commit

Permalink
Merge branch 'shyamsai/GlobalParamObject2' of https://github.com/sss0…
Browse files Browse the repository at this point in the history
…4/SynapseML into shyamsai/GlobalParamObject2
  • Loading branch information
sss04 committed Nov 22, 2024
2 parents ca44c8f + f4f72cc commit 6e488e3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,12 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}

trait HasOpenAITextParams extends HasOpenAISharedParams {

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
"The maximum number of tokens to generate. Has minimum of 0.",
isRequired = false)
isRequired = false){
override val payloadName: String = "max_tokens"
}

def getMaxTokens: Int = getScalarParam(maxTokens)

Expand Down Expand Up @@ -153,7 +154,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." +
" We generally recommend using this or `temperature` but not both." +
" Minimum of 0 and maximum of 1 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "top_p"
}

def getTopP: Double = getScalarParam(topP)

Expand Down Expand Up @@ -182,7 +185,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." +
" If `logprobs` is 0, only the chosen tokens will have logprobs returned." +
" Minimum of 0 and maximum of 100 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "logprobs"
}

def getLogProbs: Int = getScalarParam(logProbs)

Expand All @@ -208,7 +213,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
val cacheLevel: ServiceParam[Int] = new ServiceParam[Int](
this, "cacheLevel",
"can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache",
isRequired = false)
isRequired = false){
override val payloadName: String = "cache_level"
}

def getCacheLevel: Int = getScalarParam(cacheLevel)

Expand All @@ -222,7 +229,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "presencePenalty",
"How much to penalize new tokens based on their existing frequency in the text so far." +
" Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.",
isRequired = false)
isRequired = false){
override val payloadName: String = "presence_penalty"
}

def getPresencePenalty: Double = getScalarParam(presencePenalty)

Expand All @@ -236,7 +245,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "frequencyPenalty",
"How much to penalize new tokens based on whether they appear in the text so far." +
" Increases the likelihood of the model to talk about new topics.",
isRequired = false)
isRequired = false){
override val payloadName: String = "frequency_penalty"
}

def getFrequencyPenalty: Double = getScalarParam(frequencyPenalty)

Expand All @@ -250,7 +261,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "bestOf",
"How many generations to create server side, and display only the best." +
" Will not stream intermediate progress if best_of > 1. Has maximum value of 128.",
isRequired = false)
isRequired = false){
override val payloadName: String = "best_of"
}

def getBestOf: Int = getScalarParam(bestOf)

Expand All @@ -260,24 +273,27 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf,
logProbs
)

private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf
).flatMap(param =>
getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v))
).++(Seq(
getValueOpt(r, logProbs).map(v => ("logprobs", v))
).flatten).toMap
sharedTextParams.flatMap { param =>
getValueOpt(r, param).map { value => param.payloadName -> value }
}.toMap
}
}

Expand Down
10 changes: 6 additions & 4 deletions pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ jobs:
TEST-CLASS: "com.microsoft.azure.synapse.ml.nbtest.DatabricksCPUTests"
databricks-gpu:
TEST-CLASS: "com.microsoft.azure.synapse.ml.nbtest.DatabricksGPUTests"
databricks-rapids:
TEST-CLASS: "com.microsoft.azure.synapse.ml.nbtest.DatabricksRapidsTests"
# databricks-rapids tests have been disabled because these tests are failing.
# This test will be re-enabled once the issue is fixed.
# databricks-rapids:
# TEST-CLASS: "com.microsoft.azure.synapse.ml.nbtest.DatabricksRapidsTests"
synapse:
TEST-CLASS: "com.microsoft.azure.synapse.ml.nbtest.SynapseTests"
# ${{ if eq(parameters.runSynapseExtensionE2ETests, true) }}:
Expand Down Expand Up @@ -264,7 +266,7 @@ jobs:
chmod +x git-chglog_linux_amd64
./git-chglog_linux_amd64 -o CHANGELOG.md $TAG
condition: and(eq(variables.isMaster, true), startsWith(variables['tag'], 'v'))
- task: GitHubRelease@0
- task: GitHubRelease@1
condition: and(eq(variables.isMaster, true), startsWith(variables['tag'], 'v'))
inputs:
gitHubConnection: 'MMLSpark Github'
Expand Down Expand Up @@ -294,7 +296,7 @@ jobs:
conda env create --force -f environment.yml -v
condition: and(eq(variables.isMaster, true), and(startsWith(variables['tag'], 'v'), eq(variables.CONDA_CACHE_RESTORED, 'false')))
displayName: Create Anaconda environment
- task: AzureKeyVault@1
- task: AzureKeyVault@2
condition: and(eq(variables.isMaster, true), startsWith(variables['tag'], 'v'))
inputs:
azureSubscription: 'SynapseML Build'
Expand Down
2 changes: 1 addition & 1 deletion templates/kv.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
steps:
- task: AzureKeyVault@1
- task: AzureKeyVault@2
retryCountOnTaskFailure: 3
inputs:
azureSubscription: 'SynapseML Build'
Expand Down

0 comments on commit 6e488e3

Please sign in to comment.