Skip to content

Commit

Permalink
Moved validation into chat_azure and removed AzureAuth dependency.
Browse files Browse the repository at this point in the history
  • Loading branch information
SokolovAnatoliy committed Dec 6, 2024
1 parent 0a072b8 commit 7b1aadb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 42 deletions.
43 changes: 25 additions & 18 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ NULL
#' @param api_key The API key to use for authentication. You generally should
#' not supply this directly, but instead set the `AZURE_OPENAI_API_KEY` environment
#' variable.
#' @param token Azure token object of class AzureToken for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it. The token object is retrieved using the AzureAuth package.#' Using the token object ensures a refresh method is available for the token.
#' @param azure_token token object of class AzureToken for authentication. This is typically not required for
#' Azure OpenAI API calls, but can be used if your setup requires it. The azure_token object is retrieved using the AzureAuth package.#' Using the azure_token object ensures a refresh method is available for the token.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -30,7 +30,7 @@ chat_azure <- function(endpoint = azure_endpoint(),
system_prompt = NULL,
turns = NULL,
api_key = azure_key(),
token = NULL,
azure_token = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
check_string(endpoint)
Expand All @@ -41,12 +41,29 @@ chat_azure <- function(endpoint = azure_endpoint(),

base_url <- paste0(endpoint, "/openai/deployments/", deployment_id)

if(is.null(azure_token)){
access_token = azure_token
} else if(is_azure_token(azure_token)) {
# uses the token object method to validate the azure_token (for example if it expired)
valid = azure_token$validate()
if(!valid ) {
# uses the token object method to refresh the azure_token
azure_token = azure_token$refresh()
}
# retrieves the actual access token from the azure_token object for further use.
access_token = azure_token$credentials$access_token

} else {
cli::cli_abort("azure_token must be of class <AzureToken> or NULL. Please consider using the AzureAuth package to create a token object.")
return()
}

provider <- ProviderAzure(
base_url = base_url,
endpoint = endpoint,
model = deployment_id,
api_version = api_version,
token = token,
access_token = access_token,
extra_args = api_args,
api_key = api_key
)
Expand All @@ -58,7 +75,7 @@ ProviderAzure <- new_class(
parent = ProviderOpenAI,
properties = list(
api_key = prop_string(),
token = prop_azure_token(),
access_token = prop_string(allow_null = TRUE),
endpoint = prop_string(),
api_version = prop_string()
)
Expand All @@ -85,19 +102,9 @@ method(chat_request, ProviderAzure) <- function(provider,
req <- req_url_path_append(req, "/chat/completions")
req <- req_url_query(req, `api-version` = provider@api_version)
req <- req_headers(req, `api-key` = provider@api_key, .redact = "api-key")
if (!is.null(provider@token)) {
## uses the token object method to validate the token (for example if it expired)
valid = provider@token$validate()
if(!valid ) {
# uses the token object method to refresh the token
token = provider@token$refresh()
access_token = token$credentials$access_token
}
## retrieves the actual access token from the object for further use.
access_token = provider@token$credentials$access_token

req <- req_auth_bearer_token(req, access_token)
}
if (!is.null(provider@access_token)) {
req <- req_auth_bearer_token(req, provider@access_token)
}
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

Expand Down
21 changes: 0 additions & 21 deletions R/utils-S7.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,3 @@ prop_number_whole <- function(default = NULL, min = NULL, max = NULL, allow_null
)
}

prop_azure_token <- function(allow_null = FALSE) {
force(allow_null)
## The call to AzureAuth ensures that the class name of the Azure Token does not change
class_name = AzureAuth::AzureToken$self$classname
## The returned token object is R6 - not currently supported by the S7::as_class()
## so I made a new S3 class to use validator
AzureToken_class <- new_S3_class(class_name)

new_property(
class = if (allow_null) NULL | AzureToken_class else AzureToken_class,
validator = function(value) {
if (allow_null && is.null(value)) {
return()
}

if (!inherits(value, class_name)) {
paste0("must be an object of class <", class_name, ">, not ", obj_type_friendly(value), ".")
}
}
)
}
8 changes: 8 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,11 @@ dots_named <- function(...) {
x[[length(x) + 1]] <- value
x
}

is_azure_token <- function (object)
{
R6::is.R6(object) && inherits(object, "AzureToken")
}



6 changes: 3 additions & 3 deletions man/chat_azure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7b1aadb

Please sign in to comment.