Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align proto primitives #187

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/common/carpenter-beggar-creator-celery.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
1 change: 1 addition & 0 deletions .changes/generativeai/breath-brush-achiever-boat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Better align protos in regards to primitive defaults."]}
Original file line number Diff line number Diff line change
Expand Up @@ -235,19 +235,19 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
if (error.details.any { "SERVICE_DISABLED" == it.reason }) {
throw ServiceDisabledException(message)
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
if (candidates.isEmpty() && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
.map { it.finishReason }
.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class InvalidStateException(message: String, cause: Throwable? = null) :
*/
class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) :
GoogleGenerativeAIException(
"Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}",
"Content generation stopped. Reason: ${response.candidates.first().finishReason?.name}",
cause,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
Expand All @@ -22,45 +24,41 @@ import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable

sealed interface Request

@Serializable
data class GenerateContentRequest(
val model: String? = null,
val model: String,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val safetySettings: List<SafetySetting> = emptyList(),
val generationConfig: GenerationConfig? = null,
val tools: List<Tool> = emptyList(),
val toolConfig: ToolConfig? = null,
val systemInstruction: Content? = null,
) : Request

@Serializable
data class CountTokensRequest(
val model: String,
val contents: List<Content> = emptyList(),
val tools: List<Tool> = emptyList(),
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val systemInstruction: Content? = null,
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
CountTokensRequest(
generateContentRequest =
generateContentRequest.model?.let {
generateContentRequest.copy(model = fullModelName(it))
} ?: generateContentRequest
)
fun forGenAI(request: GenerateContentRequest) =
CountTokensRequest(fullModelName(request.model), request.contents, emptyList(), request)

fun forVertexAI(generateContentRequest: GenerateContentRequest) =
fun forVertexAI(request: GenerateContentRequest) =
CountTokensRequest(
model = generateContentRequest.model?.let { fullModelName(it) },
contents = generateContentRequest.contents,
tools = generateContentRequest.tools,
systemInstruction = generateContentRequest.systemInstruction,
fullModelName(request.model),
request.contents,
request.tools,
null,
request.systemInstruction,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ sealed interface Response

@Serializable
data class GenerateContentResponse(
val candidates: List<Candidate>? = null,
val candidates: List<Candidate> = emptyList(),
val promptFeedback: PromptFeedback? = null,
val usageMetadata: UsageMetadata? = null,
) : Response

@Serializable
data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters: Int? = null) :
data class CountTokensResponse(val totalTokens: Int = 0, val totalBillableCharacters: Int = 0) :
Response

@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response

@Serializable
data class UsageMetadata(
val promptTokenCount: Int? = null,
val candidatesTokenCount: Int? = null,
val totalTokenCount: Int? = null,
val promptTokenCount: Int = 0,
val candidatesTokenCount: Int = 0,
val totalTokenCount: Int = 0,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,30 @@ import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
val temperature: Float?,
@SerialName("top_p") val topP: Float?,
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?,
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema? = null,
val temperature: Float = 0f,
val topP: Float = 0f,
val topK: Int = 0,
val candidateCount: Int = 0,
val maxOutputTokens: Int = 0,
val stopSequences: List<String> = emptyList(),
val responseMimeType: String = "",
val presencePenalty: Float = 0f,
val frequencyPenalty: Float = 0f,
val responseSchema: Schema? = null,
)

@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val functionDeclarations: List<FunctionDeclaration> = emptyList(),
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
)

@Serializable
data class ToolConfig(
@SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
)
data class ToolConfig(val functionCallingConfig: FunctionCallingConfig = FunctionCallingConfig())

@Serializable
data class FunctionCallingConfig(val mode: Mode) {
data class FunctionCallingConfig(val mode: Mode? = null) {
@Serializable
enum class Mode {
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
Expand All @@ -58,16 +56,20 @@ data class FunctionCallingConfig(val mode: Mode) {
}

@Serializable
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema? = null,
)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
val description: String = "",
val format: String = "",
val nullable: Boolean = false,
val enum: List<String> = emptyList(),
val properties: Map<String, Schema> = emptyMap(),
val required: List<String> = emptyList(),
val items: Schema? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

@file:OptIn(ExperimentalSerializationApi::class)

package com.google.ai.client.generativeai.common.server

import com.google.ai.client.generativeai.common.shared.Content
Expand All @@ -37,7 +39,7 @@ object FinishReasonSerializer :
@Serializable
data class PromptFeedback(
val blockReason: BlockReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
)

@Serializable(BlockReasonSerializer::class)
Expand All @@ -52,59 +54,51 @@ enum class BlockReason {
data class Candidate(
val content: Content? = null,
val finishReason: FinishReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val safetyRatings: List<SafetyRating> = emptyList(),
val citationMetadata: CitationMetadata? = null,
val groundingMetadata: GroundingMetadata? = null,
)

@Serializable
data class CitationMetadata
@OptIn(ExperimentalSerializationApi::class)
constructor(@JsonNames("citations") val citationSources: List<CitationSources>)
data class CitationMetadata(
@JsonNames("citations") val citationSources: List<CitationSources> = emptyList()
)

@Serializable
data class CitationSources(
val startIndex: Int = 0,
val endIndex: Int,
val uri: String,
val license: String? = null,
val endIndex: Int = 0,
val uri: String = "",
val license: String = "",
)

@Serializable
data class SafetyRating(
val category: HarmCategory,
val probability: HarmProbability,
val blocked: Boolean? = null, // TODO(): any reason not to default to false?
val probabilityScore: Float? = null,
val blocked: Boolean = false,
val probabilityScore: Float = 0f,
val severity: HarmSeverity? = null,
val severityScore: Float? = null,
val severityScore: Float = 0f,
)

@Serializable
data class GroundingMetadata(
@SerialName("web_search_queries") val webSearchQueries: List<String>?,
@SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?,
@SerialName("retrieval_queries") val retrievalQueries: List<String>?,
@SerialName("grounding_attribution") val groundingAttribution: List<GroundingAttribution>?,
val webSearchQueries: List<String> = emptyList(),
val searchEntryPoint: SearchEntryPoint? = null,
val retrievalQueries: List<String> = emptyList(),
val groundingAttribution: List<GroundingAttribution> = emptyList(),
)

@Serializable
data class SearchEntryPoint(
@SerialName("rendered_content") val renderedContent: String?,
@SerialName("sdk_blob") val sdkBlob: String?,
)
data class SearchEntryPoint(val renderedContent: String = "", val sdkBlob: String = "")

// TODO() Has a different definition for labs vs vertex. May need to split into diff types in future
// (when labs supports it)
@Serializable
data class GroundingAttribution(
val segment: Segment,
@SerialName("confidence_score") val confidenceScore: Float?,
)
data class GroundingAttribution(val segment: Segment, val confidenceScore: Float = 0f)

@Serializable
data class Segment(
@SerialName("start_index") val startIndex: Int,
@SerialName("end_index") val endIndex: Int,
)
@Serializable data class Segment(val startIndex: Int = 0, val endIndex: Int = 0)

@Serializable(HarmProbabilitySerializer::class)
enum class HarmProbability {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ typealias Base64 = String

@ExperimentalSerializationApi
@Serializable
data class Content(@EncodeDefault val role: String? = "user", val parts: List<Part>)
data class Content(@EncodeDefault val role: String = "", val parts: List<Part>)
daymxn marked this conversation as resolved.
Show resolved Hide resolved

@Serializable(PartSerializer::class) sealed interface Part

@Serializable data class TextPart(val text: String) : Part
@Serializable data class TextPart(val text: String = "") : Part

@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part
@Serializable data class BlobPart(val inlineData: Blob) : Part
rlazo marked this conversation as resolved.
Show resolved Hide resolved

@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

Expand All @@ -64,17 +64,14 @@ data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult)

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)
rlazo marked this conversation as resolved.
Show resolved Hide resolved

@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)
@Serializable
data class FunctionCall(val name: String, val args: Map<String, String?> = emptyMap())
rlazo marked this conversation as resolved.
Show resolved Hide resolved

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part
@Serializable data class FileDataPart(val fileData: FileData) : Part

@Serializable
data class FileData(
@SerialName("mime_type") val mimeType: String,
@SerialName("file_uri") val fileUri: String,
)
@Serializable data class FileData(val mimeType: String, val fileUri: String)

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)
@Serializable data class Blob(val mimeType: String, val data: Base64)

@Serializable data class ExecutableCode(val language: String, val code: String)

Expand Down
Loading
Loading