-
Notifications
You must be signed in to change notification settings - Fork 133
Add EmbeddingValue union type and Base64 support for embeddings #519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ import kotlin.jvm.optionals.getOrNull | |
class Embedding | ||
private constructor( | ||
private val embedding: JsonField<List<Float>>, | ||
private val embeddingValue: JsonField<EmbeddingValue>?, | ||
private val index: JsonField<Long>, | ||
private val object_: JsonValue, | ||
private val additionalProperties: MutableMap<String, JsonValue>, | ||
|
@@ -31,19 +32,52 @@ private constructor( | |
private constructor( | ||
@JsonProperty("embedding") | ||
@ExcludeMissing | ||
embedding: JsonField<List<Float>> = JsonMissing.of(), | ||
embedding: JsonField<EmbeddingValue> = JsonMissing.of(), | ||
@JsonProperty("index") @ExcludeMissing index: JsonField<Long> = JsonMissing.of(), | ||
@JsonProperty("object") @ExcludeMissing object_: JsonValue = JsonMissing.of(), | ||
) : this(embedding, index, object_, mutableMapOf()) | ||
) : this( | ||
JsonMissing.of(), // Legacy embedding field will be populated from embeddingValue | ||
embedding, | ||
index, | ||
object_, | ||
mutableMapOf(), | ||
) | ||
|
||
/** | ||
* The embedding vector, which is a list of floats. The length of vector depends on the model as | ||
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). | ||
* | ||
* Important: When Base64 data is received, it is automatically decoded and returned as | ||
* List<Float> | ||
* | ||
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is | ||
* unexpectedly missing or null (e.g. if the server responded with an unexpected value). | ||
*/ | ||
fun embedding(): List<Float> = | ||
when { | ||
embeddingValue != null -> | ||
embeddingValue | ||
.getRequired("embedding") | ||
.asFloatList() // Base64→Float auto conversion | ||
!embedding.isMissing() -> | ||
embedding.getRequired("embedding") // Original Float format data | ||
else -> throw OpenAIInvalidDataException("Embedding data is missing") | ||
} | ||
|
||
/** | ||
* The embedding data in its original format (either float list or base64 string). This method | ||
* provides efficient access to the embedding data without unnecessary conversions. | ||
* | ||
* @return EmbeddingValue containing the embedding data in its original format | ||
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is | ||
* unexpectedly missing or null (e.g. if the server responded with an unexpected value). | ||
*/ | ||
fun embedding(): List<Float> = embedding.getRequired("embedding") | ||
fun embeddingValue(): EmbeddingValue = | ||
when { | ||
embeddingValue != null -> embeddingValue.getRequired("embedding") | ||
!embedding.isMissing() -> EmbeddingValue.ofFloatList(embedding.getRequired("embedding")) | ||
else -> throw OpenAIInvalidDataException("Embedding data is missing") | ||
} | ||
|
||
/** | ||
* The index of the embedding in the list of embeddings. | ||
|
@@ -71,7 +105,15 @@ private constructor( | |
* | ||
* Unlike [embedding], this method doesn't throw if the JSON field has an unexpected type. | ||
*/ | ||
@JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField<List<Float>> = embedding | ||
@JsonProperty("embedding") | ||
@ExcludeMissing | ||
fun _embedding(): JsonField<EmbeddingValue> = | ||
when { | ||
embeddingValue != null -> embeddingValue | ||
!embedding.isMissing() -> | ||
JsonField.of(EmbeddingValue.ofFloatList(embedding.getRequired("embedding"))) | ||
else -> JsonMissing.of() | ||
} | ||
|
||
/** | ||
* Returns the raw JSON value of [index]. | ||
|
@@ -116,7 +158,12 @@ private constructor( | |
|
||
@JvmSynthetic | ||
internal fun from(embedding: Embedding) = apply { | ||
this.embedding = embedding.embedding.map { it.toMutableList() } | ||
try { | ||
this.embedding = JsonField.of(embedding.embedding().toMutableList()) | ||
} catch (e: Exception) { | ||
// Fallback to field-level copying if embedding() method fails | ||
this.embedding = embedding.embedding.map { it.toMutableList() } | ||
} | ||
Comment on lines
+161
to
+166
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand what this is for. Is this because we have the nullable and non-nullable fields? If so, then I guess this will go away once we implement my other suggestion? |
||
index = embedding.index | ||
object_ = embedding.object_ | ||
additionalProperties = embedding.additionalProperties.toMutableMap() | ||
|
@@ -212,6 +259,7 @@ private constructor( | |
fun build(): Embedding = | ||
Embedding( | ||
checkRequired("embedding", embedding).map { it.toImmutable() }, | ||
null, // embeddingValue - will be null for builder-created instances | ||
checkRequired("index", index), | ||
object_, | ||
additionalProperties.toMutableMap(), | ||
|
@@ -225,7 +273,7 @@ private constructor( | |
return@apply | ||
} | ||
|
||
embedding() | ||
embedding() // This will call the method that returns List<Float> | ||
index() | ||
_object_().let { | ||
if (it != JsonValue.from("embedding")) { | ||
|
@@ -250,7 +298,11 @@ private constructor( | |
*/ | ||
@JvmSynthetic | ||
internal fun validity(): Int = | ||
(embedding.asKnown().getOrNull()?.size ?: 0) + | ||
when { | ||
embeddingValue != null -> embeddingValue.asKnown().getOrNull()?.validity() ?: 0 | ||
!embedding.isMissing() -> embedding.asKnown().getOrNull()?.size ?: 0 | ||
else -> 0 | ||
} + | ||
(if (index.asKnown().isPresent) 1 else 0) + | ||
object_.let { if (it == JsonValue.from("embedding")) 1 else 0 } | ||
|
||
|
@@ -259,15 +311,43 @@ private constructor( | |
return true | ||
} | ||
|
||
return /* spotless:off */ other is Embedding && embedding == other.embedding && index == other.index && object_ == other.object_ && additionalProperties == other.additionalProperties /* spotless:on */ | ||
if (other !is Embedding) { | ||
return false | ||
} | ||
|
||
return try { | ||
embedding() == other.embedding() && | ||
index == other.index && | ||
object_ == other.object_ && | ||
additionalProperties == other.additionalProperties | ||
} catch (e: Exception) { | ||
// Fallback to field-level comparison if embedding() methods fail | ||
embedding == other.embedding && | ||
embeddingValue == other.embeddingValue && | ||
index == other.index && | ||
object_ == other.object_ && | ||
additionalProperties == other.additionalProperties | ||
} | ||
} | ||
|
||
/* spotless:off */ | ||
private val hashCode: Int by lazy { Objects.hash(embedding, index, object_, additionalProperties) } | ||
private val hashCode: Int by lazy { | ||
try { | ||
Objects.hash(embedding(), index, object_, additionalProperties) | ||
} catch (e: Exception) { | ||
// Fallback to field-level hashing if embedding() method fails | ||
Objects.hash(embedding, embeddingValue, index, object_, additionalProperties) | ||
} | ||
} | ||
/* spotless:on */ | ||
|
||
override fun hashCode(): Int = hashCode | ||
|
||
override fun toString() = | ||
"Embedding{embedding=$embedding, index=$index, object_=$object_, additionalProperties=$additionalProperties}" | ||
when { | ||
embeddingValue != null -> | ||
"Embedding{embedding=${try { embedding() } catch (e: Exception) { "[]" }}, index=$index, object_=$object_, additionalProperties=$additionalProperties}" | ||
else -> | ||
"Embedding{embedding=${embedding.asKnown().getOrNull() ?: emptyList<Float>()}, index=$index, object_=$object_, additionalProperties=$additionalProperties}" | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,6 +78,9 @@ private constructor( | |
* The format to return the embeddings in. Can be either `float` or | ||
* [`base64`](https://pypi.org/project/pybase64/). | ||
* | ||
* Returns the encoding format that was set (either explicitly or via default) when this | ||
* EmbeddingCreateParams instance was built. | ||
* | ||
* @throws OpenAIInvalidDataException if the JSON field has an unexpected type (e.g. if the | ||
* server responded with an unexpected value). | ||
*/ | ||
|
@@ -418,12 +421,18 @@ private constructor( | |
* | ||
* @throws IllegalStateException if any required field is unset. | ||
*/ | ||
fun build(): EmbeddingCreateParams = | ||
EmbeddingCreateParams( | ||
fun build(): EmbeddingCreateParams { | ||
// Apply default encoding format if not explicitly set | ||
if (body._encodingFormat().isMissing()) { | ||
body.encodingFormat(EmbeddingDefaults.defaultEncodingFormat) | ||
} | ||
Comment on lines
+425
to
+428
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We shouldn't apply the default like this. We can just update the builder field to start out as:
Then we also don't need that new internal `_encodingFormat()` method. |
||
|
||
return EmbeddingCreateParams( | ||
body.build(), | ||
additionalHeaders.build(), | ||
additionalQueryParams.build(), | ||
) | ||
} | ||
} | ||
|
||
fun _body(): Body = body | ||
|
@@ -724,6 +733,12 @@ private constructor( | |
keys.forEach(::removeAdditionalProperty) | ||
} | ||
|
||
/** | ||
* Internal method to check if encodingFormat has been set. Used by the main Builder to | ||
* determine if default should be applied. | ||
*/ | ||
internal fun _encodingFormat(): JsonField<EncodingFormat> = encodingFormat | ||
|
||
/** | ||
* Returns an immutable instance of [Body]. | ||
* | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems pretty overkill to me. I think we can just swap the default to base64 and people can set the encoding explicitly on the params object if they want floats over the wire? In that case we can just delete this class and inline |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// File generated from our OpenAPI spec by Stainless. | ||
|
||
package com.openai.models.embeddings | ||
|
||
import com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat | ||
|
||
/** | ||
* Configuration object for default embedding behavior. This allows users to change the default | ||
* encoding format globally. | ||
* | ||
* By default, Base64 encoding is used for optimal performance and reduced network bandwidth. Users | ||
* can explicitly choose float encoding when direct float access is needed. | ||
*/ | ||
object EmbeddingDefaults { | ||
|
||
@JvmStatic | ||
@get:JvmName("getDefaultEncodingFormat") | ||
@set:JvmName("setDefaultEncodingFormat") | ||
var defaultEncodingFormat: EncodingFormat = EncodingFormat.BASE64 // Default is Base64 | ||
private set | ||
|
||
/** | ||
* Set the default encoding format for embeddings. This will be applied when no explicit format | ||
* is specified in EmbeddingCreateParams. | ||
* | ||
* @param format the encoding format to use as default | ||
*/ | ||
@JvmStatic | ||
fun setDefaultEncodingFormat(format: EncodingFormat) { | ||
defaultEncodingFormat = format | ||
} | ||
|
||
/** | ||
* Reset the default encoding format to Base64 (the recommended default). Base64 encoding | ||
* provides better performance and reduced network bandwidth usage. | ||
*/ | ||
@JvmStatic | ||
fun resetToDefaults() { | ||
defaultEncodingFormat = EncodingFormat.BASE64 | ||
} | ||
|
||
/** | ||
* Configure the system to use float encoding as default. This is primarily for backward | ||
* compatibility scenarios. Note: Float encoding uses more network bandwidth and may impact | ||
* performance. For most use cases, the default base64 encoding is recommended. | ||
*/ | ||
@JvmStatic | ||
fun enableLegacyFloatDefaults() { | ||
defaultEncodingFormat = EncodingFormat.FLOAT | ||
} | ||
|
||
/** Returns true if the current default encoding format is BASE64. */ | ||
@JvmStatic fun isUsingBase64Defaults(): Boolean = defaultEncodingFormat == EncodingFormat.BASE64 | ||
|
||
/** Returns true if the current default encoding format is FLOAT. */ | ||
@JvmStatic fun isUsingFloatDefaults(): Boolean = defaultEncodingFormat == EncodingFormat.FLOAT | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the new `EmbeddingValue` class supports `List<Float>`, having both `private val embedding: JsonField<List<Float>>` and `private val embeddingValue: JsonField<EmbeddingValue>?` is unnecessary, no?

We can change the underlying data model while keeping backwards compat by implementing the existing methods in terms of the new data. For example, `fun embedding()` can be implemented as `embeddingValue.getRequired("embedding").asFloatList()`?