Skip to content

Commit

Permalink
chore: Optimizing the method getOptionalParams in HasOpenAITextParams (
Browse files Browse the repository at this point in the history
…#2315)

* Optimizing the method getOptionalParams

* Fixing failing unit test

* removing param cache

---------

Co-authored-by: Farrukh Masud <[email protected]>
  • Loading branch information
FarrukhMasud and FMasudMsft authored Nov 18, 2024
1 parent a360886 commit 08aab6a
Showing 1 changed file with 41 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}

trait HasOpenAITextParams extends HasOpenAISharedParams {

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
"The maximum number of tokens to generate. Has minimum of 0.",
isRequired = false)
isRequired = false){
override val payloadName: String = "max_tokens"
}

def getMaxTokens: Int = getScalarParam(maxTokens)

Expand Down Expand Up @@ -149,7 +150,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." +
" We generally recommend using this or `temperature` but not both." +
" Minimum of 0 and maximum of 1 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "top_p"
}

def getTopP: Double = getScalarParam(topP)

Expand Down Expand Up @@ -178,7 +181,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." +
" If `logprobs` is 0, only the chosen tokens will have logprobs returned." +
" Minimum of 0 and maximum of 100 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "logprobs"
}

def getLogProbs: Int = getScalarParam(logProbs)

Expand All @@ -204,7 +209,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
val cacheLevel: ServiceParam[Int] = new ServiceParam[Int](
this, "cacheLevel",
"can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache",
isRequired = false)
isRequired = false){
override val payloadName: String = "cache_level"
}

def getCacheLevel: Int = getScalarParam(cacheLevel)

Expand All @@ -218,7 +225,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "presencePenalty",
"How much to penalize new tokens based on their existing frequency in the text so far." +
" Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.",
isRequired = false)
isRequired = false){
override val payloadName: String = "presence_penalty"
}

def getPresencePenalty: Double = getScalarParam(presencePenalty)

Expand All @@ -232,7 +241,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "frequencyPenalty",
"How much to penalize new tokens based on whether they appear in the text so far." +
" Increases the likelihood of the model to talk about new topics.",
isRequired = false)
isRequired = false){
override val payloadName: String = "frequency_penalty"
}

def getFrequencyPenalty: Double = getScalarParam(frequencyPenalty)

Expand All @@ -246,7 +257,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "bestOf",
"How many generations to create server side, and display only the best." +
" Will not stream intermediate progress if best_of > 1. Has maximum value of 128.",
isRequired = false)
isRequired = false){
override val payloadName: String = "best_of"
}

def getBestOf: Int = getScalarParam(bestOf)

Expand All @@ -256,24 +269,27 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf,
logProbs
)

private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf
).flatMap(param =>
getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v))
).++(Seq(
getValueOpt(r, logProbs).map(v => ("logprobs", v))
).flatten).toMap
sharedTextParams.flatMap { param =>
getValueOpt(r, param).map { value => param.payloadName -> value }
}.toMap
}
}

Expand Down

0 comments on commit 08aab6a

Please sign in to comment.