diff --git a/mig/server/createuser.py b/mig/server/createuser.py index 8087e4f76..fb2c701a4 100755 --- a/mig/server/createuser.py +++ b/mig/server/createuser.py @@ -91,8 +91,7 @@ def usage(name='createuser.py'): """ % {'name': name, 'cert_warn': cert_warn}) -if '__main__' == __name__: - (args, app_dir, db_path) = init_user_adm() +def main(args, cwd, db_path=keyword_auto): conf_path = None auth_type = 'custom' expire = None @@ -111,6 +110,7 @@ def usage(name='createuser.py'): user_dict = {} override_fields = {} opt_args = 'a:c:d:e:fhi:o:p:rR:s:u:v' + try: (opts, args) = getopt.getopt(args, opt_args) except getopt.GetoptError as err: @@ -138,13 +138,8 @@ def usage(name='createuser.py'): parsed = True break except ValueError: - pass - if parsed: - override_fields['expire'] = expire - override_fields['status'] = 'temporal' - else: - print('Failed to parse expire value: %s' % val) - sys.exit(1) + print('Failed to parse expire value: %s' % val) + sys.exit(1) elif opt == '-f': force = True elif opt == '-h': @@ -154,17 +149,13 @@ def usage(name='createuser.py'): user_id = val elif opt == '-o': short_id = val - override_fields['short_id'] = short_id elif opt == '-p': peer_pattern = val - override_fields['peer_pattern'] = peer_pattern - override_fields['status'] = 'temporal' elif opt == '-r': default_renew = True ask_renew = False elif opt == '-R': role = val - override_fields['role'] = role elif opt == '-s': # Translate slack days into seconds as slack_secs = int(float(val)*24*3600) @@ -190,8 +181,52 @@ def usage(name='createuser.py'): if verbose: print('using configuration from MIG_CONF (or default)') - configuration = get_configuration_object(config_file=conf_path) + _main(None, args, + conf_path=conf_path, + db_path=db_path, + expire=expire, + force=force, + verbose=verbose, + ask_renew=ask_renew, + default_renew=default_renew, + ask_change_pw=ask_change_pw, + user_file=user_file, + user_id=user_id, + short_id=short_id, + role=role, + peer_pattern=peer_pattern, + slack_secs=slack_secs, + hash_password=hash_password + ) + + +def _main(configuration, args, + conf_path=keyword_auto, + db_path=keyword_auto, + auth_type='custom', + expire=None, + force=False, + verbose=False, + ask_renew=True, + default_renew=False, + ask_change_pw=True, + user_file=None, + user_id=None, + short_id=None, + role=None, + peer_pattern=None, + slack_secs=0, + hash_password=True + ): + if configuration is None: + if conf_path == keyword_auto: + config_file = None + else: + config_file = conf_path + configuration = get_configuration_object(config_file=config_file) + logger = configuration.logger + # NOTE: we need explicit db_path lookup here for load_user_dict call if db_path == keyword_auto: db_path = default_db_path(configuration) @@ -211,9 +246,6 @@ def usage(name='createuser.py'): if auth_type == 'cert': hash_password = False - if expire is None: - expire = default_account_expire(configuration, auth_type) - raw_user = {} if args: try: @@ -291,9 +323,19 @@ def usage(name='createuser.py'): fill_user(user_dict) - # Make sure account expire is set with local certificate or OpenID login - + # assemble the fields to be explicitly overriden + override_fields = {} + if peer_pattern: + override_fields['peer_pattern'] = peer_pattern + override_fields['status'] = 'temporal' + if role: + override_fields['role'] = role + if short_id: + override_fields['short_id'] = short_id if 'expire' not in user_dict: + # Make sure account expire is set with local certificate or OpenID login + if not expire: + expire = default_account_expire(configuration, auth_type) override_fields['expire'] = expire # NOTE: let non-ID command line values override loaded values @@ -305,8 +347,10 @@ def usage(name='createuser.py'): if verbose: print('using user dict: %s' % user_dict) try: - create_user(user_dict, conf_path, db_path, force, verbose, ask_renew, - default_renew, verify_peer=peer_pattern, + conf_path = configuration.config_file + create_user(user_dict, conf_path, db_path, configuration, force, verbose, ask_renew, + default_renew, + verify_peer=peer_pattern, peer_expire_slack=slack_secs, ask_change_pw=ask_change_pw) if configuration.site_enable_gdp: (success_here, msg) = ensure_gdp_user(configuration, @@ -326,3 +370,8 @@ def usage(name='createuser.py'): if verbose: print('Cleaning up tmp file: %s' % user_file) os.remove(user_file) + + +if __name__ == '__main__': + (args, cwd, db_path) = init_user_adm() + main(args, cwd, db_path=db_path) diff --git a/mig/shared/accountstate.py b/mig/shared/accountstate.py index 21330ada7..ddf827795 100644 --- a/mig/shared/accountstate.py +++ b/mig/shared/accountstate.py @@ -33,6 +33,7 @@ from __future__ import absolute_import from past.builtins import basestring +from past.builtins import basestring import os import time diff --git a/mig/shared/base.py b/mig/shared/base.py index b21d4ae6f..006271c2a 100644 --- a/mig/shared/base.py +++ b/mig/shared/base.py @@ -295,7 +295,9 @@ def canonical_user(configuration, user_dict, limit_fields): if key == 'full_name': # IMPORTANT: we get utf8 coded bytes here and title() treats such # chars as word termination. Temporarily force to unicode. - val = force_utf8(force_unicode(val).title()) + val = force_unicode(val).title() + if PY2: + val = force_utf8(val) elif key == 'email': val = val.lower() elif key == 'country': diff --git a/mig/shared/compat.py b/mig/shared/compat.py index ac5ab0f75..09935f18c 100644 --- a/mig/shared/compat.py +++ b/mig/shared/compat.py @@ -55,6 +55,13 @@ def _is_unicode(val): return (type(val) == _TYPE_UNICODE) +def _unicode_string_to_escaped_unicode(unicode_string): + """Convert utf8 bytes to escaped unicode string.""" + + utf8_bytes = dn_utf8_bytes = codecs.encode(unicode_string, 'utf8') + return codecs.decode(utf8_bytes, 'unicode_escape') + + def ensure_native_string(string_or_bytes): """Given a supplied input which can be either a string or bytes return a representation providing string operations while ensuring that diff --git a/mig/shared/useradm.py b/mig/shared/useradm.py index 03603a2ef..ec8958d36 100644 --- a/mig/shared/useradm.py +++ b/mig/shared/useradm.py @@ -30,8 +30,11 @@ from __future__ import print_function from __future__ import absolute_import +from past.builtins import basestring from email.utils import parseaddr +import codecs import datetime +import errno import fnmatch import os import re @@ -44,6 +47,7 @@ from mig.shared.base import client_id_dir, client_dir_id, client_alias, \ get_client_id, extract_field, fill_user, fill_distinguished_name, \ is_gdp_user, mask_creds, sandbox_resource +from mig.shared.compat import _unicode_string_to_escaped_unicode from mig.shared.conf import get_configuration_object from mig.shared.configuration import Configuration from mig.shared.defaults import user_db_filename, keyword_auto, ssh_conf_dir, \ @@ -97,6 +101,10 @@ https_authdigests = user_db_filename +_USERADM_CONFIG_DIR_KEYS = ('user_db_home', 'user_home', 'user_settings', + 'user_cache', 'mrsl_files_dir', 'resource_pending') + + def init_user_adm(dynamic_db_path=True): """Shared init function for all user administration scripts. The optional dynamic_db_path argument toggles dynamic user db path lookup @@ -451,6 +459,21 @@ def verify_user_peers(configuration, db_path, client_id, user, now, verify_peer, return accepted_peer_list, effective_expire +def _check_directories_unprovisioned(configuration, db_path): + user_db_home = os.path.dirname(db_path) + return not os.path.exists(db_path) and not os.path.exists(user_db_home) + + +def _provision_directories(configuration): + for config_attr in _USERADM_CONFIG_DIR_KEYS: + try: + dir_to_create = getattr(configuration, config_attr) + os.mkdir(dir_to_create) + except OSError as oserr: + if oserr.errno != errno.ENOENT: # FileNotFoundError + raise + + def create_user_in_db(configuration, db_path, client_id, user, now, authorized, reset_token, reset_auth_type, accepted_peer_list, force, verbose, ask_renew, default_renew, do_lock, @@ -463,8 +486,25 @@ def create_user_in_db(configuration, db_path, client_id, user, now, authorized, flock = None user_db = {} renew = default_renew + + retry_lock = False if do_lock: + try: + flock = lock_user_db(db_path) + except (IOError, OSError) as oserr: + if oserr.errno != errno.ENOENT: # FileNotFoundError + raise + + if _check_directories_unprovisioned(configuration, db_path=db_path): + _provision_directories(configuration) + retry_lock = True + else: + raise Exception("Failed to lock user DB: '%s'" % db_path) + + if retry_lock: flock = lock_user_db(db_path) + if not flock: + raise Exception("Failed to lock user DB: '%s'" % db_path) if not os.path.exists(db_path): # Auto-create missing user DB if either auto_create_db or force is set @@ -859,7 +899,7 @@ def create_user_in_fs(configuration, client_id, user, now, renew, force, verbose # match in htaccess dn_plain = info['distinguished_name'] - dn_enc = dn_plain.encode('string_escape') + dn_enc = _unicode_string_to_escaped_unicode(dn_plain) def upper_repl(match): """Translate hex codes to upper case form""" @@ -1013,7 +1053,7 @@ def upper_repl(match): raise Exception('could not create custom css file: %s' % css_path) -def create_user(user, conf_path, db_path, force=False, verbose=False, +def create_user(user, conf_path, db_path, configuration=None, force=False, verbose=False, ask_renew=True, default_renew=False, do_lock=True, verify_peer=None, peer_expire_slack=0, from_edit_user=False, ask_change_pw=False, auto_create_db=True, create_backup=True): @@ -1021,7 +1061,10 @@ def create_user(user, conf_path, db_path, force=False, verbose=False, format as a first step. """ - if conf_path: + if configuration is not None: + # use it + pass + elif conf_path: if isinstance(conf_path, basestring): # has been checked for accessibility above... diff --git a/tests/test_mig_server_createuser.py b/tests/test_mig_server_createuser.py new file mode 100644 index 000000000..d2c0825a6 --- /dev/null +++ b/tests/test_mig_server_createuser.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# +# --- BEGIN_HEADER --- +# +# test_mig_server-createuser - unit tests for the migrid createuser CLI +# Copyright (C) 2003-2024 The MiG Project by the Science HPC Center at UCPH +# +# This file is part of MiG. +# +# MiG is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# MiG is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, +# USA. +# +# --- END_HEADER --- +# + +"""Unit tests for the migrid createuser CLI""" + +from __future__ import print_function +import os +import shutil +import sys + +from tests.support import MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, testmain + +from mig.server.createuser import _main as createuser +from mig.shared.useradm import _USERADM_CONFIG_DIR_KEYS + + +class TestBooleans(MigTestCase): + def before_each(self): + configuration = self.configuration + test_state_path = configuration.state_path + + for config_key in _USERADM_CONFIG_DIR_KEYS: + dir_path = getattr(configuration, config_key)[0:-1] + try: + shutil.rmtree(dir_path) + except: + pass + + self.expected_user_db_home = configuration.user_db_home[0:-1] + + def _provide_configuration(self): + return 'testconfig' + + def test_user_db_is_created_and_user_is_added(self): + args = [ + "Test User", + "Test Org", + "NA", + "DK", + "dummy-user", + "This is the create comment", + "password" + ] + print("") # acount for output generated by the logic + createuser(self.configuration, args, default_renew=True) + + # presence of user home + path_kind = MigTestCase._absolute_path_kind(self.expected_user_db_home) + self.assertEqual(path_kind, 'dir') + + # presence of user db + expected_user_db_file = os.path.join( + self.expected_user_db_home, 'MiG-users.db') + path_kind = MigTestCase._absolute_path_kind(expected_user_db_file) + self.assertEqual(path_kind, 'file') + + +if __name__ == '__main__': + testmain()