Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Optimizing the method getOptionalParams in HasOpenAITextParams #2315

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}

trait HasOpenAITextParams extends HasOpenAISharedParams {

/** Optional service parameter holding the maximum number of tokens to generate.
  *
  * The anonymous subclass overrides `payloadName` so the parameter is keyed as
  * "max_tokens" wherever `payloadName` is used to build a request payload.
  */
val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
  this, "maxTokens",
  "The maximum number of tokens to generate. Has minimum of 0.",
  isRequired = false) {
  override val payloadName: String = "max_tokens"
}

/** Reads `maxTokens` through `getScalarParam` and returns the result as an `Int`. */
def getMaxTokens: Int = getScalarParam(maxTokens)

Expand Down Expand Up @@ -149,7 +150,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." +
" We generally recommend using this or `temperature` but not both." +
" Minimum of 0 and maximum of 1 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "top_p"
}

/** Reads `topP` through `getScalarParam` and returns the result as a `Double`. */
def getTopP: Double = getScalarParam(topP)

Expand Down Expand Up @@ -178,7 +181,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
" So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." +
" If `logprobs` is 0, only the chosen tokens will have logprobs returned." +
" Minimum of 0 and maximum of 100 allowed.",
isRequired = false)
isRequired = false) {
override val payloadName: String = "logprobs"
}

/** Reads `logProbs` through `getScalarParam` and returns the result as an `Int`. */
def getLogProbs: Int = getScalarParam(logProbs)

Expand All @@ -204,7 +209,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
/** Optional service parameter selecting the server-side cache level
  * (0 = no cache, 1 = prompt prefix enabled, 2 = full cache, per its description string).
  *
  * The anonymous subclass overrides `payloadName` so the parameter is keyed as
  * "cache_level" wherever `payloadName` is used to build a request payload.
  */
val cacheLevel: ServiceParam[Int] = new ServiceParam[Int](
  this, "cacheLevel",
  "can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache",
  isRequired = false) {
  override val payloadName: String = "cache_level"
}

/** Reads `cacheLevel` through `getScalarParam` and returns the result as an `Int`. */
def getCacheLevel: Int = getScalarParam(cacheLevel)

Expand All @@ -218,7 +225,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "presencePenalty",
"How much to penalize new tokens based on their existing frequency in the text so far." +
" Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.",
isRequired = false)
isRequired = false){
override val payloadName: String = "presence_penalty"
}

/** Reads `presencePenalty` through `getScalarParam` and returns the result as a `Double`. */
def getPresencePenalty: Double = getScalarParam(presencePenalty)

Expand All @@ -232,7 +241,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "frequencyPenalty",
"How much to penalize new tokens based on whether they appear in the text so far." +
" Increases the likelihood of the model to talk about new topics.",
isRequired = false)
isRequired = false){
override val payloadName: String = "frequency_penalty"
}

/** Reads `frequencyPenalty` through `getScalarParam` and returns the result as a `Double`. */
def getFrequencyPenalty: Double = getScalarParam(frequencyPenalty)

Expand All @@ -246,7 +257,9 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
this, "bestOf",
"How many generations to create server side, and display only the best." +
" Will not stream intermediate progress if best_of > 1. Has maximum value of 128.",
isRequired = false)
isRequired = false){
override val payloadName: String = "best_of"
}

/** Reads `bestOf` through `getScalarParam` and returns the result as an `Int`. */
def getBestOf: Int = getScalarParam(bestOf)

Expand All @@ -256,24 +269,27 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

/** Binds `bestOf` to the column named `v` via `setVectorParam`; returns this instance for chaining. */
def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)

// Shared text parameters that getOptionalParams iterates over to build the optional
// payload entries. The list is fixed, so it is built once here and reused on every call
// rather than being re-created per row.
private val sharedTextParams = Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf,
logProbs
)

/** Builds the optional request-payload entries for a single input row.
  *
  * Each parameter in `sharedTextParams` is resolved against `r` with `getValueOpt`;
  * parameters that yield no value are left out. Every resolved value is keyed by the
  * parameter's own `payloadName`, so no per-call name conversion or per-parameter
  * special case is needed here.
  *
  * @param r the row to resolve parameter values against
  * @return map from payload field name to the resolved value
  */
private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
  sharedTextParams.flatMap { param =>
    getValueOpt(r, param).map { value => param.payloadName -> value }
  }.toMap
}
}

Expand Down