diff --git a/creation/lib/cvWParamDict.py b/creation/lib/cvWParamDict.py index a6e63b2cb..c0df15a13 100644 --- a/creation/lib/cvWParamDict.py +++ b/creation/lib/cvWParamDict.py @@ -1355,7 +1355,7 @@ def populate_group_security(client_security, params, sub_params, group_name): client_security["schedd_DNs"] = schedd_dns pilot_dns = [] - exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN"] + exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN", "GENERATOR"] for credentials in (params.security.credentials, sub_params.security.credentials): if is_true(params.groups[group_name].enabled): for pel in credentials: diff --git a/frontend/glideinFrontend.py b/frontend/glideinFrontend.py index b790a7cb7..031754a43 100755 --- a/frontend/glideinFrontend.py +++ b/frontend/glideinFrontend.py @@ -95,7 +95,7 @@ def poll_group_process(group_name, child): pass # ignore try: tempErr = child.stderr.read() - if tempOut: + if tempErr: logSupport.log.warning(f"[{group_name}]: {tempErr}") except OSError: pass # ignore diff --git a/frontend/glideinFrontendInterface.py b/frontend/glideinFrontendInterface.py index 7e06b5a04..fb9a32c63 100644 --- a/frontend/glideinFrontendInterface.py +++ b/frontend/glideinFrontendInterface.py @@ -857,7 +857,7 @@ def renew_and_load_credentials(self): cred_el = self.x509_proxies_data[i] cred_el.advertize = True cred_el.credential.renew() - cred_el.credential.save_to_file(overwrite=False) + cred_el.credential.save_to_file(overwrite=False, continue_if_no_path=True) return nr_credentials @@ -1164,8 +1164,7 @@ def createAdvertizeWorkFile(self, factory_pool, params_obj, key_obj=None, file_i if nr_credentials == 0: raise NoCredentialException - cred_types = set(t.credential.cred_type for t in credentials_with_requests) - auth_set = factory_auth.match(cred_types) + auth_set = factory_auth.match([t.credential for t in credentials_with_requests]) if not auth_set: logSupport.log.warning("No credentials match for factory pool %s, not advertising request" % factory_pool) raise NoCredentialException diff --git a/frontend/glideinFrontendPlugins.py b/frontend/glideinFrontendPlugins.py index 446571198..f61753fc3 100644 --- a/frontend/glideinFrontendPlugins.py +++ b/frontend/glideinFrontendPlugins.py @@ -667,13 +667,6 @@ def createRequestBundle(elementDescript): return request_bundle -def createRequestBundle(elementDescript): - """Creates a list of Credentials for a proxy plugin""" - request_bundle = RequestBundle() - request_bundle.load_from_element(elementDescript) - return request_bundle - - def fair_split(i, n, p): """ Split n requests amongst p proxies diff --git a/lib/credentials.py b/lib/credentials.py index 3bf3946f1..b4b18b093 100644 --- a/lib/credentials.py +++ b/lib/credentials.py @@ -35,6 +35,7 @@ from glideinwms.factory import glideFactoryInterface, glideFactoryLib from glideinwms.lib import condorMonitor, logSupport, pubCrypto, symCrypto +from glideinwms.lib.generators import Generator, load_generator from glideinwms.lib.util import hash_nc sys.path.append("/etc/gwms-frontend/plugin.d") @@ -63,6 +64,7 @@ class CredentialType(enum.Enum): TOKEN = "token" X509_CERT = "x509_cert" RSA_KEY = "rsa_key" + GENERATOR = "generator" def __repr__(self) -> str: return self.name @@ -208,12 +210,15 @@ def save_to_file( compress: bool = False, data_pattern: Optional[bytes] = None, overwrite: bool = True, + continue_if_no_path = False, ) -> None: if not self.string: raise CredentialError("Credential not initialized") path = path or self.path if not path: + if continue_if_no_path: + return raise CredentialError("No path specified") if os.path.isfile(path) and not overwrite: @@ -294,20 +299,6 @@ def __str__(self) -> str: return f"{self.name.value}={self.value}" -class Parameter: - def __init__(self, name: ParameterName, value): - if not isinstance(name, ParameterName): - raise TypeError("Name must be a ParameterName") - self.name = name - self.value = value - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(name={self.name.value!r}, value={self.value!r})" - - def __str__(self) -> str: - return f"{self.name.value}={self.value}" - - class ParameterDict(dict): def __setitem__(self, __k, __v): if not isinstance(__k, ParameterName): @@ -320,6 +311,52 @@ def add(self, parameter: Parameter): self[parameter.name] = parameter.value +class CredentialGenerator(Credential[Credential]): + cred_type = CredentialType.GENERATOR + + def __init__(self, string: Optional[bytes] = None, path: Optional[str] = None) -> None: + if not string: + string = path.encode() if path else None + if not string: + raise CredentialError("No string or path specified") + self._string = string + self.path = None + self.load(string) + + def __renew__(self) -> None: + self.load(self._string) + + @property + def _payload(self) -> Optional[Credential]: + return self.decode(self._string) if self._string else None + + @property + def string(self) -> Optional[bytes]: + return self._payload.string if self._payload else None + + @staticmethod + def decode(string: bytes) -> Credential: + generator = load_generator(string.decode()) + return create_credential(generator.generate()) + + def valid(self) -> bool: + if self._payload: + return self._payload.valid() + return False + + def load_from_file(self, path: str) -> None: + raise CredentialError("Cannot load CredentialGenerator from file") + + def load(self, string: Optional[bytes] = None, path: Optional[str] = None) -> None: + if string: + self.load_from_string(string) + self.cred_type = self._payload.cred_type if self._payload else None + if path: + self.path = path + else: + raise CredentialError("No string specified") + + class Token(Credential[Mapping]): cred_type = CredentialType.TOKEN classad_attribute = "ScitokenId" # TODO: We might want to change this name to "TokenId" in the future @@ -424,16 +461,16 @@ def extract_sym_key(self, enc_sym_key) -> symCrypto.AutoSymKey: return symCrypto.AutoSymKey(self._payload.decrypt_hex(enc_sym_key)) -class TextCredential(Credential[bytes]): - cred_type = CredentialType.TOKEN - classad_attribute = "AuthFile" +# class TextCredential(Credential[bytes]): +# cred_type = CredentialType.TOKEN +# classad_attribute = "AuthFile" - @staticmethod - def decode(string: bytes) -> bytes: - return string +# @staticmethod +# def decode(string: bytes) -> bytes: +# return string - def valid(self) -> bool: - return True +# def valid(self) -> bool: +# return True class X509Pair(CredentialPair, X509Cert): @@ -451,25 +488,25 @@ def __init__( self.private_credential.classad_attribute = "PrivateCert" -class UsernamePassword(CredentialPair, TextCredential): - cred_type = CredentialPairType.USERNAME_PASSWORD +# class UsernamePassword(CredentialPair, TextCredential): +# cred_type = CredentialPairType.USERNAME_PASSWORD - def __init__( - self, - string: Optional[bytes] = None, - path: Optional[str] = None, - private_string: Optional[bytes] = None, - private_path: Optional[str] = None, - ) -> None: - super().__init__(string, path, private_string, private_path) - self.classad_attribute = "Username" - self.private_credential.classad_attribute = "Password" +# def __init__( +# self, +# string: Optional[bytes] = None, +# path: Optional[str] = None, +# private_string: Optional[bytes] = None, +# private_path: Optional[str] = None, +# ) -> None: +# super().__init__(string, path, private_string, private_path) +# self.classad_attribute = "Username" +# self.private_credential.classad_attribute = "Password" class RequestCredential: def __init__( self, - credential: Credential, + credential: Union[Credential, Generator], purpose: Optional[CredentialPurpose] = None, trust_domain: Optional[TrustDomain] = None, security_class: Optional[str] = None, @@ -490,10 +527,10 @@ def __str__(self) -> str: @property def id(self) -> str: - if not self.credential.string: + if not str(self.credential): raise CredentialError("Credential not initialized") - return hash_nc(f"{self.credential.string.decode()}{self.purpose}{self.trust_domain}{self.security_class}", 8) + return hash_nc(f"{str(self.credential)}{self.purpose}{self.trust_domain}{self.security_class}", 8) @property def private_id(self) -> str: @@ -589,7 +626,7 @@ def add_parameter(self, param_id: ParameterName, param_value): class AuthenticationSet: - _required_types: Set[CredentialType] = set() + _required_types: Set[Union[CredentialType, ParameterName]] = set() def __init__(self, cred_types: Iterable[CredentialType]): for cred_type in cred_types: @@ -613,31 +650,45 @@ def satisfied_by(self, cred_types: Iterable[CredentialType]) -> bool: class AuthenticationMethod: - _supported_sets: List[AuthenticationSet] = [] - def __init__(self, auth_method: str): + self._requirements: List[List[Union[CredentialType, ParameterName]]] = [] self.load(auth_method) def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._supported_sets!r})" + return f"{self.__class__.__name__}({self._requirements!r})" def __str__(self) -> str: - return ";".join(str(auth_set) for auth_set in self._supported_sets) + return ";".join(str(auth_set) for auth_set in self._requirements) def load(self, auth_method: str): - for auth_set in auth_method.split(";"): - if auth_set.lower() == "any": - self._supported_sets.append(AuthenticationSet([])) + for group in auth_method.split(";"): + if group.lower() == "any": + self._requirements.append([]) # type: ignore else: - self._supported_sets.append( - AuthenticationSet([CredentialType.from_string(cred_type) for cred_type in auth_set.split(",")]) - ) - - def match(self, cred_types: Iterable[CredentialType]) -> Optional[AuthenticationSet]: - for auth_set in self._supported_sets: - if auth_set.satisfied_by(cred_types): - return auth_set - return None + options = [] + for option in group.split(","): + try: + options.append(CredentialType.from_string(option)) + except CredentialError: + try: + options.append(ParameterName.from_string(option)) + except CredentialError: + raise CredentialError(f"Unknown authentication requirement: {option}") + self._requirements.append(options) + + def match(self, credentials: Iterable[Credential]) -> Optional[AuthenticationSet]: + if not self._requirements: + return AuthenticationSet([]) + + auth_set = [] + cred_types = {credential.cred_type for credential in credentials if credential.valid()} + for group in self._requirements: + for option in group: + if option in cred_types: + auth_set.append(option) + break + return None + return AuthenticationSet(auth_set) def credential_of_type( diff --git a/lib/generators.py b/lib/generators.py new file mode 100644 index 000000000..4e8ad901b --- /dev/null +++ b/lib/generators.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC +# SPDX-License-Identifier: Apache-2.0 + +# +# Project: +# glideinWMS +# +# File Version: +# +# Description: +# Contains the Generator base class and built-in generators +# + +import sys +import inspect + +from abc import ABC, abstractmethod +from importlib import import_module +from typing import Generic, TypeVar, Optional, List, Mapping + +sys.path.append("/etc/gwms-frontend/plugin.d") +_loaded_generators = {} + +T = TypeVar("T") + + +class Generator(ABC, Generic[T]): + def __str__(self): + return f"{self.__class__.__name__}()" + + def __repr__(self): + return str(self) + + @abstractmethod + def generate(self, arguments: Optional[Mapping] = None) -> T: + pass + + +def load_generator(module: str) -> Generator: + """Load a generator from a module + + Args: + module (str): module that exports a generator + + Raises: + ImportError: when a `Generator` object cannot be imported from `module` + + Returns: + Generator: generator object + """ + + try: + if not module in _loaded_generators: + imported_module = import_module(module) + if module not in _loaded_generators: + del imported_module + raise ImportError(f"Module {module} does not export a generator") + except ImportError as e: + raise ImportError(f"Failed to import module {module}") from e + return _loaded_generators[module] + + +def export_generator(generator: Generator): + """Make a Generator object available to the genearators module""" + + if not isinstance(generator, Generator): + raise TypeError("generator must be a Generator object") + module = inspect.getmodule(inspect.stack()[1][0]) + if not module: + raise RuntimeError("Failed to get module name") + _loaded_generators[module.__name__] = generator + + +class RoundRobinGenerator(Generator[T]): + + def __init__(self, items: List[T]) -> None: + self._items = items + + def generate(self) -> T: + item = self._items.pop() + self._items.insert(0, item) + return item