diff --git a/lib/credentials.py b/lib/credentials.py index 64d4015aa..24dfec4de 100644 --- a/lib/credentials.py +++ b/lib/credentials.py @@ -665,6 +665,11 @@ def __repr__(self) -> str: def __str__(self) -> str: return ";".join(str(auth_set) for auth_set in self._requirements) + + def __contains__(self, cred_type: Union[CredentialType, str]) -> bool: + if isinstance(cred_type, str): + cred_type = CredentialType.from_string(cred_type) + return any(cred_type in group for group in self._requirements) def load(self, auth_method: str): for group in auth_method.split(";"): @@ -1103,11 +1108,11 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, CredentialError: if the credentials in params don't match what is defined for the auth method """ - auth_method_list = auth_method.split("+") - if not set(auth_method_list) & set(SUPPORTED_AUTH_METHODS): + try: + auth_method = AuthenticationMethod(auth_method) + except CredentialError: logSupport.log.warning( - "None of the supported auth methods %s in provided auth methods: %s" - % (SUPPORTED_AUTH_METHODS, auth_method_list) + f"Invalid authentication method: \"{auth_method}\". Supported methods are: {SUPPORTED_AUTH_METHODS}" ) return @@ -1126,12 +1131,12 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, "AuthFile", } - if "scitoken" in auth_method_list or "frontend_scitoken" in params and scitoken_passthru: + if "scitoken" in auth_method or "frontend_scitoken" in params and scitoken_passthru: # TODO check validity # TODO Specifically, Add checks that no undesired credentials are # sent also when token is used return - if "grid_proxy" in auth_method_list: + if "grid_proxy" in auth_method: if not scitoken_passthru: if "SubmitProxy" in params: # v3+ protocol @@ -1155,7 +1160,7 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, if "GlideinProxy" not in params and not scitoken_passthru: raise CredentialError("Glidein proxy cannot be found for client %s, skipping request" % client_int_name) - if "x509_cert" in auth_method_list: + if "x509_cert" in auth_method: # Validate both the public and private certs were passed if not (("PublicCert" in params) and ("PrivateCert" in params)): # if not ('PublicCert' in params and 'PrivateCert' in params): @@ -1173,7 +1178,7 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, % (client_int_name, entry_name) ) - elif "rsa_key" in auth_method_list: + elif "rsa_key" in auth_method: # Validate both the public and private keys were passed if not (("PublicKey" in params) and ("PrivateKey" in params)): # key pair is required, cannot service request @@ -1190,7 +1195,7 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, % (client_int_name, entry_name) ) - elif "auth_file" in auth_method_list: + elif "auth_file" in auth_method: # Validate auth_file is passed if not ("AuthFile" in params): # auth_file is required, cannot service request @@ -1207,7 +1212,7 @@ def check_security_credentials(auth_method, params, client_int_name, entry_name, % (client_int_name, entry_name) ) - elif "username_password" in auth_method_list: + elif "username_password" in auth_method: # Validate username and password keys were passed if not (("Username" in params) and ("Password" in params)): # username and password is required, cannot service request