@@ -185,6 +185,7 @@ def __init__(
185185 client_credentials = None ,
186186 client_scopes = None ,
187187 ):
188+ logging .debug ("Initatializing auth.." )
188189 self .endpoint = remove_trailing_whitespace_and_slashes_in_url (endpoint )
189190 # note - `_refresh_token` is not actually a JWT refresh token - it's a
190191 # gen3 api key with a token as the "api_key" property
@@ -280,6 +281,9 @@ def __init__(
280281 )
281282 ):
282283 try :
284+ logging .debug (
285+ "Switch to using WTS and set external WTS host url.."
286+ )
283287 self ._use_wts = True
284288 self ._external_wts_host = (
285289 endpoint_from_token (self ._refresh_token ["api_key" ])
@@ -479,23 +483,72 @@ def get_access_token_from_wts(self, endpoint=None):
479483 in the given namespace. If idp is not set, then default to "local"
480484 """
481485 # attempt to get a token from the workspace-token-service
486+ logging .debug ("getting access token from wts.." )
482487 auth_url = get_wts_endpoint (self ._wts_namespace ) + "/token/"
483- if self ._wts_idp and self ._wts_idp != "local" :
484- auth_url += "?idp={}" .format (self ._wts_idp )
485488
486- try :
487- resp = requests .get (auth_url )
488- if (resp and resp .status_code == 200 ) or (not self ._external_wts_host ):
489- return _handle_access_token_response (resp , "token" )
490- except Exception as e :
491- if not self ._external_wts_host :
492- raise e
493- else :
494- # Try to obtain token from external wts
495- pass
489+ # If non "local" idp value exists, append to auth url
490+ # If user specified endpoint value, then first attempt to determine idp value.
491+ if self .endpoint or (self ._wts_idp and self ._wts_idp != "local" ):
492+ # If user supplied endpoint value and not idp, figure out the idp value
493+ if self .endpoint :
494+ logging .debug (
495+ "First try to use the local WTS to figure out idp name for the supplied endpoint.."
496+ )
497+ try :
498+ provider_List = get_wts_idps (self ._wts_namespace )
499+ matchProviders = list (
500+ filter (
501+ lambda provider : provider ["base_url" ] == endpoint ,
502+ provider_List ["providers" ],
503+ )
504+ )
505+ if len (matchProviders ) == 1 :
506+ logging .debug ("Found matching idp from local WTS." )
507+ self ._wts_idp = matchProviders [0 ]["idp" ]
508+ elif len (matchProviders ) > 1 :
509+ raise ValueError (
510+ "Multiple idps matched with endpoint value provided."
511+ )
512+ else :
513+ logging .debug ("Could not find matching idp from local WTS." )
514+ except Exception as e :
515+ logging .debug (
516+ "Exception occured when making network call to local WTS."
517+ )
518+ if not self ._external_wts_host :
519+ raise e
520+ else :
521+ logging .debug ("Since external WTS host exists, continuing on.." )
522+ pass
523+
524+ if self ._wts_idp and self ._wts_idp != "local" :
525+ auth_url += "?idp={}" .format (self ._wts_idp )
526+
527+ # If endpoint value exists, only get WTS token if idp value has been successfully determined
528+ # Otherwise skip to querying external WTS
529+ # This is to prevent local WTS from supplying an incorrect token to user
530+ if (
531+ not self ._external_wts_host
532+ or not self .endpoint
533+ or (self .endpoint and self ._wts_idp != "local" )
534+ ):
535+ try :
536+ logging .debug ("Try to get access token from local WTS.." )
537+ logging .debug (f"{ auth_url = } " )
538+ resp = requests .get (auth_url )
539+ if (resp and resp .status_code == 200 ) or (not self ._external_wts_host ):
540+ return _handle_access_token_response (resp , "token" )
541+ except Exception as e :
542+ if not self ._external_wts_host :
543+ raise e
544+ else :
545+ # Try to obtain token from external wts
546+ logging .debug ("Could get obtain token from Local WTS." )
547+ pass
496548
497549 # local workspace wts call failed, try using a network call
498550 # First get access token with WTS host
551+ logging .debug ("Trying to get access token from external WTS Host.." )
499552 wts_token = get_access_token_with_key (self ._refresh_token )
500553 auth_url = self ._external_wts_host + "token/"
501554
@@ -523,6 +576,7 @@ def get_access_token_from_wts(self, endpoint=None):
523576
524577 if len (matchProviders ) == 1 :
525578 self ._wts_idp = matchProviders [0 ]["idp" ]
579+ logging .debug ("Succesfully determined idp value: {}" .format (self ._wts_idp ))
526580 else :
527581 idp_list = "\n "
528582
@@ -559,7 +613,7 @@ def get_access_token_from_wts(self, endpoint=None):
559613 + idp_list
560614 + "Query /wts/external_oidc/ for more information."
561615 )
562-
616+ logging . debug ( "Finally getting access token.." )
563617 auth_url += "?idp={}" .format (self ._wts_idp )
564618 header = {"Authorization" : "Bearer " + wts_token }
565619 resp = requests .get (auth_url , headers = header )
0 commit comments