Skip to content

Commit

Permalink
Implement generators
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoCoimbra committed Nov 14, 2023
1 parent 7574d09 commit 0ef555e
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 66 deletions.
2 changes: 1 addition & 1 deletion creation/lib/cvWParamDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ def populate_group_security(client_security, params, sub_params, group_name):
client_security["schedd_DNs"] = schedd_dns

pilot_dns = []
exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN"]
exclude_from_pilot_dns = ["SCITOKEN", "IDTOKEN", "GENERATOR"]
for credentials in (params.security.credentials, sub_params.security.credentials):
if is_true(params.groups[group_name].enabled):
for pel in credentials:
Expand Down
2 changes: 1 addition & 1 deletion frontend/glideinFrontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def poll_group_process(group_name, child):
pass # ignore
try:
tempErr = child.stderr.read()
if tempOut:
if tempErr:
logSupport.log.warning(f"[{group_name}]: {tempErr}")
except OSError:
pass # ignore
Expand Down
5 changes: 2 additions & 3 deletions frontend/glideinFrontendInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def renew_and_load_credentials(self):
cred_el = self.x509_proxies_data[i]
cred_el.advertize = True
cred_el.credential.renew()
cred_el.credential.save_to_file(overwrite=False)
cred_el.credential.save_to_file(overwrite=False, continue_if_no_path=True)

return nr_credentials

Expand Down Expand Up @@ -1164,8 +1164,7 @@ def createAdvertizeWorkFile(self, factory_pool, params_obj, key_obj=None, file_i
if nr_credentials == 0:
raise NoCredentialException

cred_types = set(t.credential.cred_type for t in credentials_with_requests)
auth_set = factory_auth.match(cred_types)
auth_set = factory_auth.match([t.credential for t in credentials_with_requests])
if not auth_set:
logSupport.log.warning("No credentials match for factory pool %s, not advertising request" % factory_pool)
raise NoCredentialException
Expand Down
7 changes: 0 additions & 7 deletions frontend/glideinFrontendPlugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,13 +667,6 @@ def createRequestBundle(elementDescript):
return request_bundle


def createRequestBundle(elementDescript):
"""Creates a list of Credentials for a proxy plugin"""
request_bundle = RequestBundle()
request_bundle.load_from_element(elementDescript)
return request_bundle


def fair_split(i, n, p):
"""
Split n requests amongst p proxies
Expand Down
159 changes: 105 additions & 54 deletions lib/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from glideinwms.factory import glideFactoryInterface, glideFactoryLib
from glideinwms.lib import condorMonitor, logSupport, pubCrypto, symCrypto
from glideinwms.lib.generators import Generator, load_generator
from glideinwms.lib.util import hash_nc

sys.path.append("/etc/gwms-frontend/plugin.d")
Expand Down Expand Up @@ -63,6 +64,7 @@ class CredentialType(enum.Enum):
TOKEN = "token"
X509_CERT = "x509_cert"
RSA_KEY = "rsa_key"
GENERATOR = "generator"

def __repr__(self) -> str:
return self.name
Expand Down Expand Up @@ -208,12 +210,15 @@ def save_to_file(
compress: bool = False,
data_pattern: Optional[bytes] = None,
overwrite: bool = True,
continue_if_no_path = False,
) -> None:
if not self.string:
raise CredentialError("Credential not initialized")

path = path or self.path
if not path:
if continue_if_no_path:
return
raise CredentialError("No path specified")

if os.path.isfile(path) and not overwrite:
Expand Down Expand Up @@ -294,20 +299,6 @@ def __str__(self) -> str:
return f"{self.name.value}={self.value}"


class Parameter:
def __init__(self, name: ParameterName, value):
if not isinstance(name, ParameterName):
raise TypeError("Name must be a ParameterName")
self.name = name
self.value = value

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name.value!r}, value={self.value!r})"

def __str__(self) -> str:
return f"{self.name.value}={self.value}"


class ParameterDict(dict):
def __setitem__(self, __k, __v):
if not isinstance(__k, ParameterName):
Expand All @@ -320,6 +311,52 @@ def add(self, parameter: Parameter):
self[parameter.name] = parameter.value


class CredentialGenerator(Credential[Credential]):
cred_type = CredentialType.GENERATOR

def __init__(self, string: Optional[bytes] = None, path: Optional[str] = None) -> None:
if not string:
string = path.encode() if path else None
if not string:
raise CredentialError("No string or path specified")
self._string = string
self.path = None
self.load(string)

def __renew__(self) -> None:
self.load(self._string)

@property
def _payload(self) -> Optional[Credential]:
return self.decode(self._string) if self._string else None

@property
def string(self) -> Optional[bytes]:
return self._payload.string if self._payload else None

@staticmethod
def decode(string: bytes) -> Credential:
generator = load_generator(string.decode())
return create_credential(generator.generate())

def valid(self) -> bool:
if self._payload:
return self._payload.valid()
return False

def load_from_file(self, path: str) -> None:
raise CredentialError("Cannot load CredentialGenerator from file")

def load(self, string: Optional[bytes] = None, path: Optional[str] = None) -> None:
if string:
self.load_from_string(string)
self.cred_type = self._payload.cred_type if self._payload else None
if path:
self.path = path
else:
raise CredentialError("No string specified")


class Token(Credential[Mapping]):
cred_type = CredentialType.TOKEN
classad_attribute = "ScitokenId" # TODO: We might want to change this name to "TokenId" in the future
Expand Down Expand Up @@ -424,16 +461,16 @@ def extract_sym_key(self, enc_sym_key) -> symCrypto.AutoSymKey:
return symCrypto.AutoSymKey(self._payload.decrypt_hex(enc_sym_key))


class TextCredential(Credential[bytes]):
cred_type = CredentialType.TOKEN
classad_attribute = "AuthFile"
# class TextCredential(Credential[bytes]):
# cred_type = CredentialType.TOKEN
# classad_attribute = "AuthFile"

@staticmethod
def decode(string: bytes) -> bytes:
return string
# @staticmethod
# def decode(string: bytes) -> bytes:
# return string

def valid(self) -> bool:
return True
# def valid(self) -> bool:
# return True


class X509Pair(CredentialPair, X509Cert):
Expand All @@ -451,25 +488,25 @@ def __init__(
self.private_credential.classad_attribute = "PrivateCert"


class UsernamePassword(CredentialPair, TextCredential):
cred_type = CredentialPairType.USERNAME_PASSWORD
# class UsernamePassword(CredentialPair, TextCredential):
# cred_type = CredentialPairType.USERNAME_PASSWORD

def __init__(
self,
string: Optional[bytes] = None,
path: Optional[str] = None,
private_string: Optional[bytes] = None,
private_path: Optional[str] = None,
) -> None:
super().__init__(string, path, private_string, private_path)
self.classad_attribute = "Username"
self.private_credential.classad_attribute = "Password"
# def __init__(
# self,
# string: Optional[bytes] = None,
# path: Optional[str] = None,
# private_string: Optional[bytes] = None,
# private_path: Optional[str] = None,
# ) -> None:
# super().__init__(string, path, private_string, private_path)
# self.classad_attribute = "Username"
# self.private_credential.classad_attribute = "Password"


class RequestCredential:
def __init__(
self,
credential: Credential,
credential: Union[Credential, Generator],
purpose: Optional[CredentialPurpose] = None,
trust_domain: Optional[TrustDomain] = None,
security_class: Optional[str] = None,
Expand All @@ -490,10 +527,10 @@ def __str__(self) -> str:

@property
def id(self) -> str:
if not self.credential.string:
if not str(self.credential):
raise CredentialError("Credential not initialized")

return hash_nc(f"{self.credential.string.decode()}{self.purpose}{self.trust_domain}{self.security_class}", 8)
return hash_nc(f"{str(self.credential)}{self.purpose}{self.trust_domain}{self.security_class}", 8)

@property
def private_id(self) -> str:
Expand Down Expand Up @@ -589,7 +626,7 @@ def add_parameter(self, param_id: ParameterName, param_value):


class AuthenticationSet:
_required_types: Set[CredentialType] = set()
_required_types: Set[Union[CredentialType, ParameterName]] = set()

def __init__(self, cred_types: Iterable[CredentialType]):
for cred_type in cred_types:
Expand All @@ -613,31 +650,45 @@ def satisfied_by(self, cred_types: Iterable[CredentialType]) -> bool:


class AuthenticationMethod:
_supported_sets: List[AuthenticationSet] = []

def __init__(self, auth_method: str):
self._requirements: List[List[Union[CredentialType, ParameterName]]] = []
self.load(auth_method)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._supported_sets!r})"
return f"{self.__class__.__name__}({self._requirements!r})"

def __str__(self) -> str:
return ";".join(str(auth_set) for auth_set in self._supported_sets)
return ";".join(str(auth_set) for auth_set in self._requirements)

def load(self, auth_method: str):
for auth_set in auth_method.split(";"):
if auth_set.lower() == "any":
self._supported_sets.append(AuthenticationSet([]))
for group in auth_method.split(";"):
if group.lower() == "any":
self._requirements.append([]) # type: ignore
else:
self._supported_sets.append(
AuthenticationSet([CredentialType.from_string(cred_type) for cred_type in auth_set.split(",")])
)

def match(self, cred_types: Iterable[CredentialType]) -> Optional[AuthenticationSet]:
for auth_set in self._supported_sets:
if auth_set.satisfied_by(cred_types):
return auth_set
return None
options = []
for option in group.split(","):
try:
options.append(CredentialType.from_string(option))
except CredentialError:
try:
options.append(ParameterName.from_string(option))
except CredentialError:
raise CredentialError(f"Unknown authentication requirement: {option}")
self._requirements.append(options)

def match(self, credentials: Iterable[Credential]) -> Optional[AuthenticationSet]:
if not self._requirements:
return AuthenticationSet([])

auth_set = []
cred_types = {credential.cred_type for credential in credentials if credential.valid()}
for group in self._requirements:
for option in group:
if option in cred_types:
auth_set.append(option)
break
return None
return AuthenticationSet(auth_set)


def credential_of_type(
Expand Down
84 changes: 84 additions & 0 deletions lib/generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC
# SPDX-License-Identifier: Apache-2.0

#
# Project:
# glideinWMS
#
# File Version:
#
# Description:
# Contains the Generator base class and built-in generators
#

import sys
import inspect

from abc import ABC, abstractmethod
from importlib import import_module
from typing import Generic, TypeVar, Optional, List, Mapping

sys.path.append("/etc/gwms-frontend/plugin.d")
_loaded_generators = {}

T = TypeVar("T")


class Generator(ABC, Generic[T]):
def __str__(self):
return f"{self.__class__.__name__}()"

def __repr__(self):
return str(self)

@abstractmethod
def generate(self, arguments: Optional[Mapping] = None) -> T:
pass


def load_generator(module: str) -> Generator:
"""Load a generator from a module
Args:
module (str): module that exports a generator
Raises:
ImportError: when a `Generator` object cannot be imported from `module`
Returns:
Generator: generator object
"""

try:
if not module in _loaded_generators:
imported_module = import_module(module)
if module not in _loaded_generators:
del imported_module
raise ImportError(f"Module {module} does not export a generator")
except ImportError as e:
raise ImportError(f"Failed to import module {module}") from e
return _loaded_generators[module]


def export_generator(generator: Generator):
"""Make a Generator object available to the genearators module"""

if not isinstance(generator, Generator):
raise TypeError("generator must be a Generator object")
module = inspect.getmodule(inspect.stack()[1][0])
if not module:
raise RuntimeError("Failed to get module name")
_loaded_generators[module.__name__] = generator


class RoundRobinGenerator(Generator[T]):

def __init__(self, items: List[T]) -> None:
self._items = items

def generate(self) -> T:
item = self._items.pop()
self._items.insert(0, item)
return item

0 comments on commit 0ef555e

Please sign in to comment.