Skip to content

Add EmbeddingValue union type and Base64 support for embeddings #519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import kotlin.jvm.optionals.getOrNull
class Embedding
private constructor(
private val embedding: JsonField<List<Float>>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the new EmbeddingValue class supports List<Float>, having both private val embedding: JsonField<List<Float>> and private val embeddingValue: JsonField<EmbeddingValue>? is unnecessary, no?

We can change the underlying data model while keeping backwards compat by implementing the existing methods in terms of the new data. For example, fun embeding() can be implemented as embeddingValue.getRequired("embedding").asFloatList()?

private val embeddingValue: JsonField<EmbeddingValue>?,
private val index: JsonField<Long>,
private val object_: JsonValue,
private val additionalProperties: MutableMap<String, JsonValue>,
Expand All @@ -31,19 +32,52 @@ private constructor(
private constructor(
@JsonProperty("embedding")
@ExcludeMissing
embedding: JsonField<List<Float>> = JsonMissing.of(),
embedding: JsonField<EmbeddingValue> = JsonMissing.of(),
@JsonProperty("index") @ExcludeMissing index: JsonField<Long> = JsonMissing.of(),
@JsonProperty("object") @ExcludeMissing object_: JsonValue = JsonMissing.of(),
) : this(embedding, index, object_, mutableMapOf())
) : this(
JsonMissing.of(), // Legacy embedding field will be populated from embeddingValue
embedding,
index,
object_,
mutableMapOf(),
)

/**
* The embedding vector, which is a list of floats. The length of vector depends on the model as
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
*
* Important: When Base64 data is received, it is automatically decoded and returned as
* List<Float>
*
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is
* unexpectedly missing or null (e.g. if the server responded with an unexpected value).
*/
fun embedding(): List<Float> =
when {
embeddingValue != null ->
embeddingValue
.getRequired("embedding")
.asFloatList() // Base64→Float auto conversion
!embedding.isMissing() ->
embedding.getRequired("embedding") // Original Float format data
else -> throw OpenAIInvalidDataException("Embedding data is missing")
}

/**
* The embedding data in its original format (either float list or base64 string). This method
* provides efficient access to the embedding data without unnecessary conversions.
*
* @return EmbeddingValue containing the embedding data in its original format
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is
* unexpectedly missing or null (e.g. if the server responded with an unexpected value).
*/
fun embedding(): List<Float> = embedding.getRequired("embedding")
fun embeddingValue(): EmbeddingValue =
when {
embeddingValue != null -> embeddingValue.getRequired("embedding")
!embedding.isMissing() -> EmbeddingValue.ofFloatList(embedding.getRequired("embedding"))
else -> throw OpenAIInvalidDataException("Embedding data is missing")
}

/**
* The index of the embedding in the list of embeddings.
Expand Down Expand Up @@ -71,7 +105,15 @@ private constructor(
*
* Unlike [embedding], this method doesn't throw if the JSON field has an unexpected type.
*/
@JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField<List<Float>> = embedding
@JsonProperty("embedding")
@ExcludeMissing
fun _embedding(): JsonField<EmbeddingValue> =
when {
embeddingValue != null -> embeddingValue
!embedding.isMissing() ->
JsonField.of(EmbeddingValue.ofFloatList(embedding.getRequired("embedding")))
else -> JsonMissing.of()
}

/**
* Returns the raw JSON value of [index].
Expand Down Expand Up @@ -116,7 +158,12 @@ private constructor(

@JvmSynthetic
internal fun from(embedding: Embedding) = apply {
this.embedding = embedding.embedding.map { it.toMutableList() }
try {
this.embedding = JsonField.of(embedding.embedding().toMutableList())
} catch (e: Exception) {
// Fallback to field-level copying if embedding() method fails
this.embedding = embedding.embedding.map { it.toMutableList() }
}
Comment on lines +161 to +166
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what this is for. Is this because we have the nullable and non-nullable fields? If so, then I guess this will go away once we implement my other suggestion?

index = embedding.index
object_ = embedding.object_
additionalProperties = embedding.additionalProperties.toMutableMap()
Expand Down Expand Up @@ -212,6 +259,7 @@ private constructor(
fun build(): Embedding =
Embedding(
checkRequired("embedding", embedding).map { it.toImmutable() },
null, // embeddingValue - will be null for builder-created instances
checkRequired("index", index),
object_,
additionalProperties.toMutableMap(),
Expand All @@ -225,7 +273,7 @@ private constructor(
return@apply
}

embedding()
embedding() // This will call the method that returns List<Float>
index()
_object_().let {
if (it != JsonValue.from("embedding")) {
Expand All @@ -250,7 +298,11 @@ private constructor(
*/
@JvmSynthetic
internal fun validity(): Int =
(embedding.asKnown().getOrNull()?.size ?: 0) +
when {
embeddingValue != null -> embeddingValue.asKnown().getOrNull()?.validity() ?: 0
!embedding.isMissing() -> embedding.asKnown().getOrNull()?.size ?: 0
else -> 0
} +
(if (index.asKnown().isPresent) 1 else 0) +
object_.let { if (it == JsonValue.from("embedding")) 1 else 0 }

Expand All @@ -259,15 +311,43 @@ private constructor(
return true
}

return /* spotless:off */ other is Embedding && embedding == other.embedding && index == other.index && object_ == other.object_ && additionalProperties == other.additionalProperties /* spotless:on */
if (other !is Embedding) {
return false
}

return try {
embedding() == other.embedding() &&
index == other.index &&
object_ == other.object_ &&
additionalProperties == other.additionalProperties
} catch (e: Exception) {
// Fallback to field-level comparison if embedding() methods fail
embedding == other.embedding &&
embeddingValue == other.embeddingValue &&
index == other.index &&
object_ == other.object_ &&
additionalProperties == other.additionalProperties
}
}

/* spotless:off */
private val hashCode: Int by lazy { Objects.hash(embedding, index, object_, additionalProperties) }
private val hashCode: Int by lazy {
try {
Objects.hash(embedding(), index, object_, additionalProperties)
} catch (e: Exception) {
// Fallback to field-level hashing if embedding() method fails
Objects.hash(embedding, embeddingValue, index, object_, additionalProperties)
}
}
/* spotless:on */

override fun hashCode(): Int = hashCode

override fun toString() =
"Embedding{embedding=$embedding, index=$index, object_=$object_, additionalProperties=$additionalProperties}"
when {
embeddingValue != null ->
"Embedding{embedding=${try { embedding() } catch (e: Exception) { "[]" }}, index=$index, object_=$object_, additionalProperties=$additionalProperties}"
else ->
"Embedding{embedding=${embedding.asKnown().getOrNull() ?: emptyList<Float>()}, index=$index, object_=$object_, additionalProperties=$additionalProperties}"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ private constructor(
* The format to return the embeddings in. Can be either `float` or
* [`base64`](https://pypi.org/project/pybase64/).
*
* Returns the encoding format that was set (either explicitly or via default) when this
* EmbeddingCreateParams instance was built.
*
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type (e.g. if the
* server responded with an unexpected value).
*/
Expand Down Expand Up @@ -418,12 +421,18 @@ private constructor(
*
* @throws IllegalStateException if any required field is unset.
*/
fun build(): EmbeddingCreateParams =
EmbeddingCreateParams(
fun build(): EmbeddingCreateParams {
// Apply default encoding format if not explicitly set
if (body._encodingFormat().isMissing()) {
body.encodingFormat(EmbeddingDefaults.defaultEncodingFormat)
}
Comment on lines +425 to +428
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't apply the default like this. We can just update the builder field to start out as:

private var encodingFormat: JsonField<EncodingFormat> = JsonField.of(EmbeddingDefaults.defaultEncodingFormat)

Then we also don't need that new internal method


return EmbeddingCreateParams(
body.build(),
additionalHeaders.build(),
additionalQueryParams.build(),
)
}
}

fun _body(): Body = body
Expand Down Expand Up @@ -724,6 +733,12 @@ private constructor(
keys.forEach(::removeAdditionalProperty)
}

/**
* Internal method to check if encodingFormat has been set. Used by the main Builder to
* determine if default should be applied.
*/
internal fun _encodingFormat(): JsonField<EncodingFormat> = encodingFormat

/**
* Returns an immutable instance of [Body].
*
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems pretty overkill to me. I think we can just swap the default to base64 and people can set the encoding explicitly on the params object if they want floats over the wire?

In that case we can just delete this class and inline EncodingFormat.BASE64 in the params builder default

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// File generated from our OpenAPI spec by Stainless.

package com.openai.models.embeddings

import com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat

/**
* Configuration object for default embedding behavior. This allows users to change the default
* encoding format globally.
*
* By default, Base64 encoding is used for optimal performance and reduced network bandwidth. Users
* can explicitly choose float encoding when direct float access is needed.
*/
object EmbeddingDefaults {

@JvmStatic
@get:JvmName("getDefaultEncodingFormat")
@set:JvmName("setDefaultEncodingFormat")
var defaultEncodingFormat: EncodingFormat = EncodingFormat.BASE64 // Default is Base64
private set

/**
* Set the default encoding format for embeddings. This will be applied when no explicit format
* is specified in EmbeddingCreateParams.
*
* @param format the encoding format to use as default
*/
@JvmStatic
fun setDefaultEncodingFormat(format: EncodingFormat) {
defaultEncodingFormat = format
}

/**
* Reset the default encoding format to Base64 (the recommended default). Base64 encoding
* provides better performance and reduced network bandwidth usage.
*/
@JvmStatic
fun resetToDefaults() {
defaultEncodingFormat = EncodingFormat.BASE64
}

/**
* Configure the system to use float encoding as default. This is primarily for backward
* compatibility scenarios. Note: Float encoding uses more network bandwidth and may impact
* performance. For most use cases, the default base64 encoding is recommended.
*/
@JvmStatic
fun enableLegacyFloatDefaults() {
defaultEncodingFormat = EncodingFormat.FLOAT
}

/** Returns true if the current default encoding format is BASE64. */
@JvmStatic fun isUsingBase64Defaults(): Boolean = defaultEncodingFormat == EncodingFormat.BASE64

/** Returns true if the current default encoding format is FLOAT. */
@JvmStatic fun isUsingFloatDefaults(): Boolean = defaultEncodingFormat == EncodingFormat.FLOAT
}
Loading