Skip to content

Commit

Permalink
Implement Gemini-specific chunk merging logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jcheng5 committed Dec 6, 2024
1 parent 2c1d675 commit 2504ae3
Showing 1 changed file with 156 additions and 1 deletion.
157 changes: 156 additions & 1 deletion R/provider-gemini.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ method(stream_text, ProviderGemini) <- function(provider, event) {
event$candidates[[1]]$content$parts[[1]]$text
}
# Fold one streamed chunk into the accumulated result for a Gemini stream.
#
# `provider` is the dispatch argument only. `result` is the value accumulated
# so far (NULL before the first chunk) and `chunk` is the newly received
# event. Returns the new accumulated value.
#
# The debugging assignment to a global `info` variable and the superseded
# generic `merge_dicts()` call have been removed: only the Gemini-specific
# merge is applied.
method(stream_merge_chunks, ProviderGemini) <- function(provider, result, chunk) {
  if (is.null(result)) {
    # First chunk of the stream: there is nothing to merge with yet.
    chunk
  } else {
    merge_gemini_chunks(result, chunk)
  }
}
method(value_turn, ProviderGemini) <- function(provider, result, has_type = FALSE) {
Expand Down Expand Up @@ -249,3 +250,157 @@ method(as_json, list(ProviderGemini, TypeObject)) <- function(provider, x) {
required = as.list(names2(x@properties)[required])
))
}

# Gemini-specific merge logic --------------------------------------------------

# Merge-function factory: the returned merger keeps the value from the
# earlier (left) chunk and ignores the later (right) one.
merge_first <- function() {
  keep_left <- function(left, right) left
  keep_left
}

# Merge-function factory: the returned merger keeps the value from the
# later (right) chunk and ignores the earlier (left) one.
merge_last <- function() {
  keep_right <- function(left, right) right
  keep_right
}

# Alias: the very same factory as merge_last, bound to a second name.
# NOTE(review): presumably the separate name signals fields that may be
# absent from a chunk (the merger then yields NULL) -- confirm intent.
merge_last_or_null <- merge_last

# Merge-function factory: the returned merger insists that both sides carry
# exactly the same value (per identical()) and returns it; any difference is
# an error that reports both deparsed values.
merge_identical <- function() {
  function(left, right) {
    if (identical(left, right)) {
      return(left)
    }
    stop("Expected identical values, but got ", deparse(left), " and ", deparse(right))
  }
}

# Merge-function factory: the returned merger yields the first side that is
# a non-NULL, non-empty string (left is preferred over right), and falls
# back to "" when neither side qualifies.
merge_any_or_empty <- function() {
  has_text <- function(x) !is.null(x) && nzchar(x)
  function(left, right) {
    if (has_text(left)) {
      return(left)
    }
    if (has_text(right)) {
      return(right)
    }
    ""
  }
}

# Merge-function factory: the returned merger glues the left value and the
# right value together with paste0().
merge_concatenate <- function() {
  concatenate <- function(left, right) {
    # TODO: validate that each side is NULL or a length-one character vector
    paste0(left, right)
  }
  concatenate
}

# Merge-function factory for safety ratings. Not implemented yet: the
# returned merger ignores both inputs and always yields NULL.
merge_safety_ratings <- function() {
  function(left, right) {
    # TODO: https://github.com/google-gemini/generative-ai-python/blob/b8772ed1424a080911151b354764d76a0e7af2af/google/generativeai/types/generation_types.py#L238
    NULL
  }
}

# Wrap `merge_func` so that a field absent from both sides stays absent:
# when left and right are both NULL the result is NULL, otherwise the
# wrapped merger decides (it still sees a NULL for a one-sided absence).
merge_optional <- function(merge_func) {
  function(left, right) {
    both_missing <- is.null(left) && is.null(right)
    if (both_missing) NULL else merge_func(left, right)
  }
}

# Build a merger for named-list "objects" from a per-key specification.
#
# `...` are merge functions, each named after the key it handles. The
# returned function takes two objects and returns a named list with one
# entry per spec key, computed as spec[[key]](left[[key]], right[[key]]).
# Keys that are not in the spec do not appear in the result; keys missing
# from an input are passed to the merge function as NULL.
#
# The spec is validated once, when the merger is built. A spec with no
# names at all is rejected too (names() is NULL in that case, which a bare
# all(nzchar(names(spec))) check lets through).
merge_by_spec <- function(...) {
  spec <- list(...)
  keys <- names(spec)
  if (length(spec) > 0) {
    if (is.null(keys) || !all(nzchar(keys))) {
      stop("All arguments to merge_by_spec() must be named", call. = FALSE)
    }
    if (!all(vapply(spec, is.function, logical(1)))) {
      stop("All arguments to merge_by_spec() must be functions", call. = FALSE)
    }
  }

  function(left, right) {
    # TODO: left and right should be named lists
    Map(
      function(key, merge_func) merge_func(left[[key]], right[[key]]),
      keys,
      spec
    )
  }
}

# Build a merger for lists of objects that are matched up by their
# [["index"]] field rather than by position.
#
# `...` is the per-key spec handed to merge_by_spec() for items present on
# both sides. The returned function takes two lists and returns an unnamed
# list ordered by ascending index: items found on only one side are passed
# through unchanged, items found on both sides are merged.
merge_indexed_list <- function(...) {
  function(left, right) {
    # Items without an index are given their zero-based position.
    ensure_indices <- function(lst) {
      for (i in seq_along(lst)) {
        if (is.null(lst[[i]][["index"]])) {
          lst[[i]]$index <- i - 1L
        }
      }
      lst
    }
    # TODO: We shouldn't need to do this -- why is `index` sometimes missing?
    left <- ensure_indices(left)
    right <- ensure_indices(right)

    index_of <- function(item) as.integer(item[["index"]])
    left_indices <- vapply(left, index_of, integer(1))
    right_indices <- vapply(right, index_of, integer(1))

    # Look an item up by index, yielding NULL when that side has no such
    # item. Subsetting with `[[which(...)]]` would instead fail with a
    # zero-length subscript, so one-sided items could never pass through.
    find_item <- function(lst, indices, index) {
      pos <- match(index, indices)
      if (is.na(pos)) NULL else lst[[pos]]
    }

    all_indices <- sort(unique(c(left_indices, right_indices)))
    lapply(all_indices, function(index) {
      left_item <- find_item(left, left_indices, index)
      right_item <- find_item(right, right_indices, index)
      if (is.null(left_item)) {
        right_item
      } else if (is.null(right_item)) {
        left_item
      } else {
        merge_by_spec(...)(left_item, right_item)
      }
    })
  }
}

# Build a merger for the ordered `parts` lists of two chunks.
#
# `...` is the per-key spec handed to merge_by_spec(). The returned function
# appends `right` to `left`, except that the last part of `left` and the
# first part of `right` are fused into a single part when they carry the
# same keys and every one of those keys is covered by the spec.
#
# The coverage check matters: merge_by_spec() only emits spec keys, so
# fusing two parts whose keys the spec does not know would produce an empty
# part and silently lose both originals. Such parts are kept side by side.
merge_parts <- function(...) {
  spec_keys <- names(list(...))

  function(left, right) {
    if (length(left) == 0) {
      return(right)
    }
    if (length(right) == 0) {
      return(left)
    }

    # Can we merge the last left and first right?
    last_left <- left[[length(left)]]
    first_right <- right[[1]]
    part_keys <- names(last_left)
    mergeable <- length(part_keys) > 0 &&
      identical(part_keys, names(first_right)) &&
      all(part_keys %in% spec_keys)

    if (!mergeable) {
      # Nothing to merge
      return(c(left, right))
    }

    merged <- merge_by_spec(...)(last_left, first_right)
    # Drop NULL properties (spec keys neither part carried)
    merged <- merged[!vapply(merged, is.null, logical(1))]
    # Put everything back together
    c(head(left, -1), list(merged), tail(right, -1))
  }
}

# Merger for two whole Gemini streaming chunks, assembled bottom-up from the
# combinators above. The result contains exactly the keys named here.
#
# NOTE(review): Google's REST reference spells these fields in camelCase
# (finishReason, safetyRatings, usageMetadata, executableCode, ...). Confirm
# that the chunks reaching this merger really use the snake_case keys below;
# a key that does not match is read as NULL on both sides.
merge_gemini_chunks <- local({
  # A single part: text is appended, code blocks are merged field by field.
  executable_code_merger <- merge_by_spec(
    language = merge_first(),
    code = merge_concatenate()
  )
  code_result_merger <- merge_by_spec(
    outcome = merge_last(),
    output = merge_concatenate()
  )
  parts_merger <- merge_parts(
    text = merge_optional(merge_concatenate()),
    executable_code = merge_optional(executable_code_merger),
    code_execution_result = merge_optional(code_result_merger)
  )

  content_merger <- merge_by_spec(
    role = merge_any_or_empty(),
    parts = parts_merger
  )

  # Candidates are paired up by their index before being merged.
  candidates_merger <- merge_indexed_list(
    index = merge_identical(),
    content = content_merger,
    finish_reason = merge_last(),
    safety_ratings = merge_safety_ratings(),
    citation_metadata = merge_last(),
    token_count = merge_last()
  )

  merge_by_spec(
    candidates = candidates_merger,
    prompt_feedback = merge_last(),
    usage_metadata = merge_last_or_null(),
    model_version = merge_last_or_null()
  )
})

0 comments on commit 2504ae3

Please sign in to comment.