Skip to content

Commit

Permalink
Add support for token object for ProviderAzure. Fixes #195
Browse files Browse the repository at this point in the history
  • Loading branch information
SokolovAnatoliy committed Dec 5, 2024
1 parent 2c1d675 commit 0a072b8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
20 changes: 15 additions & 5 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ NULL
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment
#' variable.
#' @param token Azure token for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it.
#' @param token Azure token object of class AzureToken for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it. The token object is retrieved using the AzureAuth package.#' Using the token object ensures a refresh method is available for the token.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand Down Expand Up @@ -58,7 +58,7 @@ ProviderAzure <- new_class(
parent = ProviderOpenAI,
properties = list(
api_key = prop_string(),
token = prop_string(allow_null = TRUE),
token = prop_azure_token(),
endpoint = prop_string(),
api_version = prop_string()
)
Expand Down Expand Up @@ -86,8 +86,18 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- req_url_query(req, `api-version` = provider@api_version)
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
if (!is.null(provider@token)) {
req <- req_auth_bearer_token(req, provider@token)
}
## uses the token object method to validate the token (for example if it expired)
valid = provider@token$validate()
if(!valid ) {
# uses the token object method to refresh the token
token = provider@token$refresh()
access_token = token$credentials$access_token
}
## retrieves the actual access token from the object for further use.
access_token = provider@token$credentials$access_token

req <- req_auth_bearer_token(req, access_token)
}
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

Expand Down
22 changes: 22 additions & 0 deletions R/utils-S7.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ prop_number_whole <- function(default = NULL, min = NULL, max = NULL, allow_null
}
)
}

prop_azure_token <- function(allow_null = FALSE) {
force(allow_null)
## The call to AzureAuth ensures that the class name of the Azure Token does not change
class_name = AzureAuth::AzureToken$self$classname
## The returned token object is R6 - not currently supported by the S7::as_class()
## so I made a new S3 class to use validator
AzureToken_class <- new_S3_class(class_name)

new_property(
class = if (allow_null) NULL | AzureToken_class else AzureToken_class,
validator = function(value) {
if (allow_null && is.null(value)) {
return()
}

if (!inherits(value, class_name)) {
paste0("must be an object of class <", class_name, ">, not ", obj_type_friendly(value), ".")
}
}
)
}
4 changes: 2 additions & 2 deletions man/chat_azure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0a072b8

Please sign in to comment.