diff --git a/.gitignore b/.gitignore index 81d3cfafe..b2f79a899 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ downloads/ eggs/ .eggs/ lib/ +!/mig/lib/ lib64/ parts/ sdist/ diff --git a/mig/lib/__init__.py b/mig/lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mig/lib/coreapi/__init__.py b/mig/lib/coreapi/__init__.py new file mode 100644 index 000000000..995053549 --- /dev/null +++ b/mig/lib/coreapi/__init__.py @@ -0,0 +1,110 @@ +import codecs +import json +import werkzeug.exceptions as httpexceptions + +from tests.support._env import PY2 + +if PY2: + from urllib2 import HTTPError, Request, urlopen + from urllib import urlencode +else: + from urllib.error import HTTPError + from urllib.parse import urlencode + from urllib.request import urlopen, Request + +from mig.lib.coresvc.payloads import PAYLOAD_POST_USER + + +httpexceptions_by_code = { + exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')} + + +def attempt_to_decode_response_data(data, response_encoding=None): + if data is None: + return None + elif response_encoding == 'textual': + data = codecs.decode(data, 'utf8') + + try: + return json.loads(data) + except Exception as e: + return data + elif response_encoding == 'binary': + return data + else: + raise AssertionError( + 'issue_POST: unknown response_encoding "%s"' % (response_encoding,)) + + +def http_error_from_status_code(http_status_code, description=None): + return httpexceptions_by_code[http_status_code](description) + + +class CoreApiClient: + def __init__(self, base_url): + self._base_url = base_url + + def _issue_GET(self, request_path, query_dict=None, response_encoding='textual'): + request_url = ''.join((self._base_url, request_path)) + + if query_dict is not None: + query_string = urlencode(query_dict) + request_url = ''.join((request_url, '?', query_string)) + + status = 0 + data = None + + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = None + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def _issue_POST(self, request_path, request_data=None, request_json=None, response_encoding='textual'): + request_url = ''.join((self._base_url, request_path)) + + if request_data and request_json: + raise ValueError( + "only one of data or json request data may be specified") + + status = 0 + data = None + + try: + if request_json is not None: + request_data = codecs.encode(json.dumps(request_json), 'utf8') + request_headers = { + 'Content-Type': 'application/json' + } + request = Request(request_url, request_data, + headers=request_headers) + elif request_data is not None: + request = Request(request_url, request_data) + else: + request = Request(request_url) + + response = urlopen(request, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = httpexc.fp.read() + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def createUser(self, user_dict): + payload = PAYLOAD_POST_USER.ensure(user_dict) + + status, output = self._issue_POST('/user', request_json=dict(payload)) + if status != 201: + description = output if isinstance(output, str) else None + raise http_error_from_status_code(status, description) + return output diff --git a/mig/lib/coresvc/__init__.py b/mig/lib/coresvc/__init__.py new file mode 100644 index 000000000..412a33c62 --- /dev/null +++ b/mig/lib/coresvc/__init__.py @@ -0,0 +1,2 @@ +from mig.lib.coresvc.server import ThreadedApiHttpServer, \ + _create_and_expose_server diff --git a/mig/lib/coresvc/__main__.py b/mig/lib/coresvc/__main__.py new file mode 100644 index 000000000..1a8155104 --- /dev/null +++ b/mig/lib/coresvc/__main__.py @@ -0,0 +1,30 @@ +from argparse import ArgumentError +from getopt import getopt +import sys + +from mig.shared.conf import get_configuration_object +from mig.services.coreapi.server import main as server_main + + +def _getopt_opts_to_options(opts): + options = {} + for k, v in opts: + options[k[1:]] = v + return options + + +def _required_argument_error(option, argument_name): + raise ArgumentError(None, 'Missing required argument: %s %s' % + (option, argument_name.upper())) + + +if __name__ == '__main__': + (opts, args) = getopt(sys.argv[1:], 'c:') + opts_dict = _getopt_opts_to_options(opts) + + if 'c' not in opts_dict: + raise _required_argument_error('-c', 'config_file') + + configuration = get_configuration_object(opts_dict['c'], + skip_log=True, disable_auth_log=True) + server_main(configuration) diff --git a/mig/lib/coresvc/payloads.py b/mig/lib/coresvc/payloads.py new file mode 100644 index 000000000..c7caf7028 --- /dev/null +++ b/mig/lib/coresvc/payloads.py @@ -0,0 +1,212 @@ +from collections import defaultdict, namedtuple, OrderedDict + +from mig.shared.safeinput import validate_helper + + +_EMPTY_LIST = {} +_REQUIRED_FIELD = object() + + +def _is_not_none(value): + """value is not None""" + assert value is not None, _is_not_none.__doc__ + + +def _is_string_and_non_empty(value): + """value is a non-empty string""" + assert isinstance(value, str) and len(value) > 0, _is_string_and_non_empty.__doc__ + + +class PayloadException(ValueError): + def __str__(self): + return self.serialize(output_format='text') + + def serialize(self, output_format='text'): + error_message = self.args[0] + + if output_format == 'json': + return dict(error=error_message) + else: + return error_message + + +class PayloadReport(PayloadException): + def __init__(self, errors_by_field): + self.errors_by_field = errors_by_field + + def serialize(self, output_format='text'): + if output_format == 'json': + return dict(errors=self.errors_by_field) + else: + lines = ["- %s: %s" % + (k, v) for k, v in self.errors_by_field.items()] + lines.insert(0, '') + return 'payload failed to validate:%s' % ('\n'.join(lines),) + + +class _MissingField: + def __init__(self, field, message=None): + assert message is not None + self._field = field + self._message = message + + def replace(self, _, __): + return self._field + + @classmethod + def assert_not_instance(cls, value): + assert not isinstance(value, cls), value._message + + +class Payload(OrderedDict): + def __init__(self, definition, dictionary): + super(Payload, self).__init__(dictionary) + self._definition = definition + + @property + def _fields(self): + return self._definition._fields + + @property + def name(self): + return self._definition._definition_name + + def __iter__(self): + return iter(self.values()) + + def items(self): + return zip(self._definition._item_names, self.values()) + + @staticmethod + def define(payload_name, payload_fields, validators_by_field): + positionals = list((field, validators_by_field[field]) for field in payload_fields) + return PayloadDefinition(payload_name, positionals) + + +class PayloadDefinition: + def __init__(self, name, positionals=_EMPTY_LIST): + self._definition_name = name + self._expected_positions = 0 + self._item_checks = [] + self._item_names = [] + + if positionals is not _EMPTY_LIST: + for positional in positionals: + self._define_positional(positional) + + @property + def _fields(self): + return self._item_names + + def __call__(self, *args): + return self._extract_and_bundle(args, extract_by='position') + + def _define_positional(self, positional): + assert len(positional) == 2 + + name, validator_fn = positional + + self._item_names.append(name) + self._item_checks.append(validator_fn) + + self._expected_positions += 1 + + def _extract_and_bundle(self, args, extract_by=None): + definition = self + + if extract_by == 'position': + actual_positions = len(args) + expected_positions = definition._expected_positions + if actual_positions < expected_positions: + raise PayloadException('Error: too few arguments given (expected %d got %d)' % ( + expected_positions, actual_positions)) + positions = list(range(actual_positions)) + dictionary = {definition._item_names[position]: args[position] for position in positions} + elif extract_by == 'name': + dictionary = {key: args.get(key, None) for key in definition._item_names} + else: + raise RuntimeError() + + return Payload(definition, dictionary) + + def ensure(self, bundle_or_args): + bundle_definition = self + + if isinstance(bundle_or_args, Payload): + assert bundle_or_args.name == bundle_definition._definition_name + return bundle_or_args + elif isinstance(bundle_or_args, dict): + bundle = self._extract_and_bundle(bundle_or_args, extract_by='name') + else: + bundle = bundle_definition(*bundle_or_args) + + return _validate_bundle(self, bundle) + + def ensure_bundle(self, bundle_or_args): + return self.ensure(bundle_or_args) + + def to_checks(self): + type_checks = {} + for key in self._fields: + type_checks[key] = _MissingField.assert_not_instance + + value_checks = dict(zip(self._item_names, self._item_checks)) + + return type_checks, value_checks + + +def _extract_field_error(bad_value): + try: + message = bad_value[0][1] + if not message: + raise IndexError + return message + except IndexError: + return 'required' + + +def _prepare_validate_helper_input(definition, payload): + def _covert_field_value(payload, field): + value = payload.get(field, _REQUIRED_FIELD) + if value is _REQUIRED_FIELD: + return _MissingField(field, 'required') + if value is None: + return _MissingField(field, 'missing') + return value + return {field: _covert_field_value(payload, field) + for field in definition._fields} + + +def _validate_bundle(definition, payload): + assert isinstance(payload, Payload) + + input_dict = _prepare_validate_helper_input(definition, payload) + type_checks, value_checks = definition.to_checks() + _, bad_values = validate_helper(input_dict, definition._fields, + type_checks, value_checks, list_wrap=True) + + if bad_values: + errors_by_field = {field_name: _extract_field_error(bad_value) + for field_name, bad_value in bad_values.items()} + raise PayloadReport(errors_by_field) + + return payload + + +PAYLOAD_POST_USER = Payload.define('PostUserArgs', [ + 'full_name', + 'organization', + 'state', + 'country', + 'email', + 'comment', + 'password', +], defaultdict(lambda: _is_not_none, dict( + full_name=_is_string_and_non_empty, + organization=_is_string_and_non_empty, + state=_is_string_and_non_empty, + country=_is_string_and_non_empty, + email=_is_string_and_non_empty, + comment=_is_string_and_non_empty, + password=_is_string_and_non_empty, +))) diff --git a/mig/lib/coresvc/server.py b/mig/lib/coresvc/server.py new file mode 100755 index 000000000..f96f4fa28 --- /dev/null +++ b/mig/lib/coresvc/server.py @@ -0,0 +1,251 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# --- BEGIN_HEADER --- +# +# mig/services/coreapi/server - coreapi service server internals +# Copyright (C) 2003-2025 The MiG Project by the Science HPC Center at UCPH +# +# This file is part of MiG. +# +# MiG is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# MiG is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +# +# -- END_HEADER --- +# + + +"""HTTP server parts of the coreapi service.""" + +from __future__ import print_function +from __future__ import absolute_import + +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn + +import base64 +from collections import defaultdict, namedtuple +from flask import Flask, request, Response +import json +import os +import sys +import threading +import time +import werkzeug.exceptions as httpexceptions +from wsgiref.simple_server import WSGIRequestHandler + +from mig.lib.coresvc.payloads import PayloadException, \ + PAYLOAD_POST_USER as _REQUEST_ARGS_POST_USER +from mig.shared.base import canonical_user, keyword_auto, force_native_str_rec +from mig.shared.useradm import fill_user, \ + create_user as useradm_create_user, search_users as useradm_search_users +from mig.shared.userdb import default_db_path + + +httpexceptions_by_code = { + exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')} + + +def http_error_from_status_code(http_status_code, http_url, description=None): + return httpexceptions_by_code[http_status_code](description) + + +def json_reponse_from_status_code(http_status_code, content): + json_content = json.dumps(content) + return Response(json_content, http_status_code, { 'Content-Type': 'application/json' }) + + +def _create_user(configuration, payload): + user_dict = canonical_user( + configuration, payload, _REQUEST_ARGS_POST_USER._fields) + fill_user(user_dict) + force_native_str_rec(user_dict) + + try: + useradm_create_user(user_dict, configuration, keyword_auto, default_renew=True) + except: + raise http_error_from_status_code(500, None) + user_email = user_dict['email'] + objects = search_users(configuration, { + 'email': user_email + }) + if len(objects) != 1: + raise http_error_from_status_code(400, None) + return objects[0] + + +def search_users(configuration, search_filter): + _, hits = useradm_search_users(search_filter, configuration, keyword_auto) + return list((obj for _, obj in hits)) + + +def _create_and_expose_server(server, configuration): + app = Flask('coreapi') + + @app.get('/user') + def GET_user(): + raise http_error_from_status_code(400, None) + + @app.get('/user/') + def GET_user_username(username): + return 'FOOBAR' + + @app.get('/user/find') + def GET_user_find(): + query_params = request.args + + objects = search_users(configuration, { + 'email': query_params['email'] + }) + + if len(objects) != 1: + raise http_error_from_status_code(404, None) + + return dict(objects=objects) + + @app.post('/user') + def POST_user(): + payload = request.get_json() + + try: + payload = _REQUEST_ARGS_POST_USER.ensure(payload) + except PayloadException as vr: + return http_error_from_status_code(400, None, vr.serialize()) + + user = _create_user(configuration, payload) + return json_reponse_from_status_code(201, user) + + return app + + +class ApiHttpServer(HTTPServer): + """ + http(s) server that contains a reference to an OpenID Server and + knows its base URL. + Extended to fork on requests to avoid one slow or broken login stalling + the rest. + """ + + def __init__(self, configuration, logger=None, host=None, port=None, **kwargs): + self.configuration = configuration + self.logger = logger if logger else configuration.logger + self.server_app = None + self._on_start = kwargs.pop('on_start', lambda _: None) + + addr = (host, port) + HTTPServer.__init__(self, addr, ApiHttpRequestHandler, **kwargs) + + @property + def base_environ(self): + return {} + + def get_app(self): + return self.server_app + + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) + + +class ThreadedApiHttpServer(ThreadingMixIn, ApiHttpServer): + """Multi-threaded version of the ApiHttpServer""" + + @property + def base_url(self): + proto = 'http' + return '%s://%s:%d/' % (proto, self.server_name, self.server_port) + + +class ApiHttpRequestHandler(WSGIRequestHandler): + """TODO: docstring""" + + def __init__(self, socket, addr, server, **kwargs): + self.server = server + + # NOTE: drop idle clients after N seconds to clean stale connections. + # Does NOT include clients that connect and do nothing at all :-( + self.timeout = 120 + + self._http_url = None + self.parsed_uri = None + self.path_parts = None + self.retry_url = '' + + WSGIRequestHandler.__init__(self, socket, addr, server, **kwargs) + + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.logger + + +def start_service(configuration, host=None, port=None): + assert host is not None, "required kwarg: host" + assert port is not None, "required kwarg: port" + + logger = configuration.logger + + def _on_start(server, *args, **kwargs): + server.server_app = _create_and_expose_server( + None, server.configuration) + + httpserver = ThreadedApiHttpServer( + configuration, host=host, port=port, on_start=_on_start) + + serve_msg = 'Server running at: %s' % httpserver.base_url + logger.info(serve_msg) + print(serve_msg) + while True: + logger.debug('handle next request') + httpserver.handle_request() + logger.debug('done handling request') + httpserver.expire_volatile() + + +def main(configuration=None): + if not configuration: + from mig.shared.conf import get_configuration_object + # Force no log init since we use separate logger + configuration = get_configuration_object(skip_log=True) + + logger = configuration.logger + + # Allow e.g. logrotate to force log re-open after rotates + #register_hangup_handler(configuration) + + # FIXME: + host = 'localhost' # configuration.user_openid_address + port = 5555 # configuration.user_openid_port + server_address = (host, port) + + info_msg = "Starting coreapi..." + logger.info(info_msg) + print(info_msg) + + try: + start_service(configuration, host=host, port=port) + except KeyboardInterrupt: + info_msg = "Received user interrupt" + logger.info(info_msg) + print(info_msg) + info_msg = "Leaving with no more workers active" + logger.info(info_msg) + print(info_msg) diff --git a/mig/shared/useradm.py b/mig/shared/useradm.py index 2812dc5f5..a144f8c0a 100644 --- a/mig/shared/useradm.py +++ b/mig/shared/useradm.py @@ -1027,7 +1027,9 @@ def create_user(user, conf_path, db_path, force=False, verbose=False, format as a first step. """ - if conf_path: + if isinstance(conf_path, Configuration): + configuration = conf_path + elif conf_path: if isinstance(conf_path, basestring): # has been checked for accessibility above... @@ -2318,7 +2320,9 @@ def search_users(search_filter, conf_path, db_path, fnmatch for. """ - if conf_path: + if isinstance(conf_path, Configuration): + configuration = conf_path + elif conf_path: if isinstance(conf_path, basestring): configuration = Configuration(conf_path) else: diff --git a/requirements.txt b/requirements.txt index 5c2b1bc8f..8c398a2ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # migrid core dependencies on a format suitable for pip install as described on # https://pip.pypa.io/en/stable/reference/requirement-specifiers/ +flask future # cgi was removed from the standard library in Python 3.13 diff --git a/tests/support/httpsupp.py b/tests/support/httpsupp.py new file mode 100644 index 000000000..4a115c489 --- /dev/null +++ b/tests/support/httpsupp.py @@ -0,0 +1,98 @@ +import codecs +import json + +from tests.support._env import PY2 + +if PY2: + from urllib2 import HTTPError, Request, urlopen + from urllib import urlencode +else: + from urllib.error import HTTPError + from urllib.parse import urlencode + from urllib.request import urlopen, Request + + +def attempt_to_decode_response_data(data, response_encoding=None): + if data is None: + return None + elif response_encoding == 'textual': + data = codecs.decode(data, 'utf8') + + try: + return json.loads(data) + except Exception as e: + return data + elif response_encoding == 'binary': + return data + else: + raise AssertionError( + 'issue_POST: unknown response_encoding "%s"' % (response_encoding,)) + + +class HttpAssertMixin: + + def _issue_GET(self, server_address, request_path, query_dict=None, response_encoding='textual'): + assert isinstance(server_address, tuple) and len( + server_address) == 2, "require server address tuple" + assert isinstance(request_path, str) and request_path.startswith( + '/'), "require http path starting with /" + request_url = ''.join( + ('http://', server_address[0], ':', str(server_address[1]), request_path)) + + if query_dict is not None: + query_string = urlencode(query_dict) + request_url = ''.join((request_url, '?', query_string)) + + status = 0 + data = None + + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = None + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def _issue_POST(self, server_address, request_path, request_data=None, request_json=None, response_encoding='textual'): + assert isinstance(server_address, tuple) and len( + server_address) == 2, "require server address tuple" + assert isinstance(request_path, str) and request_path.startswith( + '/'), "require http path starting with /" + request_url = ''.join( + ('http://', server_address[0], ':', str(server_address[1]), request_path)) + + if request_data and request_json: + raise ValueError( + "only one of data or json request data may be specified") + + status = 0 + data = None + + try: + if request_json is not None: + request_data = codecs.encode(json.dumps(request_json), 'utf8') + request_headers = { + 'Content-Type': 'application/json' + } + request = Request(request_url, request_data, + headers=request_headers) + elif request_data is not None: + request = Request(request_url, request_data) + else: + request = Request(request_url) + + response = urlopen(request, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = httpexc.file.read() + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 0e0fd4b94..f7d78bffd 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -41,11 +41,16 @@ class ServerWithinThreadExecutor: def __init__(self, ServerClass, *args, **kwargs): self._serverclass = ServerClass + self._serverclass_on_instance = kwargs.pop('on_instance', None) self._arguments = (args, kwargs) self._started = ThreadEvent() self._thread = None self._wrapped = None + def __getattr__(self, attr): + assert self._wrapped, "wrapped instance was not created" + return getattr(self._wrapped, attr) + def run(self): """Mimic the same method from the standard thread API""" server_args, server_kwargs = self._arguments @@ -53,6 +58,8 @@ def run(self): server_kwargs['on_start'] = lambda _: self._started.set() self._wrapped = self._serverclass(*server_args, **server_kwargs) + if self._serverclass_on_instance: + self._serverclass_on_instance(self._wrapped) try: self._wrapped.serve_forever() @@ -73,14 +80,16 @@ def start_wait_until_ready(self): def stop(self): """Mimic the same method from the standard thread API""" self.stop_server() - self._wrapped = None - self._thread.join() - self._thread = None + if self._thread: + self._thread.join() + self._thread = None def stop_server(self): """Stop server thread""" - self._wrapped.shutdown() - self._wrapped.server_close() + if self._wrapped: + self._wrapped.shutdown() + self._wrapped.server_close() + self._wrapped = None def make_wrapped_server(ServerClass, *args, **kwargs): diff --git a/tests/test_mig_lib_coreapi.py b/tests/test_mig_lib_coreapi.py new file mode 100644 index 000000000..06a2e0a8c --- /dev/null +++ b/tests/test_mig_lib_coreapi.py @@ -0,0 +1,104 @@ +import codecs +import json +from http.server import HTTPServer, BaseHTTPRequestHandler + +from tests.support import MigTestCase, testmain +from tests.support.serversupp import make_wrapped_server + +from mig.lib.coreapi import CoreApiClient + + +class TestRequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + test_server = self.server + + if test_server._programmed_response: + status, content = test_server._programmed_response + elif test_server._programmed_error: + status, content = test_server._programmed_error + + self.send_response(status) + self.end_headers() + self.wfile.write(content) + + +class TestHTTPServer(HTTPServer): + def __init__(self, addr, **kwargs): + self._programmed_error = None + self._programmed_response = None + self._on_start = kwargs.pop('on_start', lambda _: None) + + HTTPServer.__init__(self, addr, TestRequestHandler, **kwargs) + + def clear_programmed(self): + self._programmed_error = None + + def set_programmed_error(self, status, content): + assert self._programmed_response is None + assert isinstance(content, bytes) + self._programmed_error = (status, content) + + def set_programmed_response(self, status, content): + assert self._programmed_error is None + assert isinstance(content, bytes) + self._programmed_response = (status, content) + + def set_programmed_json_response(self, status, content): + self.set_programmed_response(status, codecs.encode(json.dumps(content), 'utf8')) + + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) + + +class TestMigLibCoreapi(MigTestCase): + def before_each(self): + self.server_addr = ('localhost', 4567) + self.server = make_wrapped_server(TestHTTPServer, self.server_addr) + + def after_each(self): + server = getattr(self, 'server', None) + setattr(self, 'server', None) + if server: + server.stop() + + def test_raises_in_the_absence_of_success(self): + self.server.start_wait_until_ready() + self.server.set_programmed_error(418, b'tea; earl grey; hot') + instance = CoreApiClient("http://%s:%s/" % self.server_addr) + + with self.assertRaises(Exception): + instance.createUser({ + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + }) + + def test_returs_a_user_object(self): + test_content = { + 'foo': 1, + 'bar': True + } + self.server.start_wait_until_ready() + self.server.set_programmed_json_response(201, test_content) + instance = CoreApiClient("http://%s:%s/" % self.server_addr) + + content = instance.createUser({ + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + }) + + self.assertIsInstance(content, dict) + self.assertEqual(content, test_content) + +if __name__ == '__main__': + testmain() diff --git a/tests/test_mig_lib_coresvc.py b/tests/test_mig_lib_coresvc.py new file mode 100644 index 000000000..75ca71fa2 --- /dev/null +++ b/tests/test_mig_lib_coresvc.py @@ -0,0 +1,245 @@ +from __future__ import print_function +import codecs +import errno +import json +import os +import shutil +import sys +import unittest +from threading import Thread +from unittest import skip + +from tests.support import PY2, MigTestCase, testmain, temppath, \ + make_wrapped_server +from tests.support.httpsupp import HttpAssertMixin + +from mig.shared.base import keyword_auto +from mig.shared.useradm import create_user +from mig.lib.coresvc import ThreadedApiHttpServer, \ + _create_and_expose_server + +_TAG_P_OPEN = '

' +_TAG_P_CLOSE = '

' +_USERADM_PATH_KEYS = ('user_cache', 'user_db_home', 'user_home', + 'user_settings', 'mrsl_files_dir', 'resource_pending') + + +def _extend_configuration(*args): + pass + + +def ensure_dirs_needed_by_create_user(configuration): + for config_key in _USERADM_PATH_KEYS: + dir_path = getattr(configuration, config_key)[0:-1] + try: + os.mkdir(dir_path) + except OSError as exc: + pass + + +def extract_error_description_from_html(content): + open_tag_index = content.find(_TAG_P_OPEN) + start_index = open_tag_index + len(_TAG_P_OPEN) + end_index = content.find(_TAG_P_CLOSE) + error_desription = content[start_index:end_index] + return error_desription + + +class MigServerGrid_openid(MigTestCase, HttpAssertMixin): + def before_each(self): + self.server_addr = None + self.server_thread = None + + ensure_dirs_needed_by_create_user(self.configuration) + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + + def _provide_configuration(self): + return 'testconfig' + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + def issue_GET(self, request_path): + return self._issue_GET(self.server_addr, request_path) + + def issue_POST(self, request_path, **kwargs): + return self._issue_POST(self.server_addr, request_path, **kwargs) + + @unittest.skipIf(PY2, "Python 3 only") + def test__GET_returns_not_found_for_missing_path(self): + self.server_thread.start_wait_until_ready() + + status, _ = self.issue_GET('/nonexistent') + + self.assertEqual(status, 404) + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_user__top_level_request(self): + self.server_thread.start_wait_until_ready() + + status, _ = self.issue_GET('/user') + + self.assertEqual(status, 400) + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET__user_userid_request_succeeds_with_status_ok(self): + example_username = 'dummy-user' + example_username_home_dir = temppath( + 'state/user_home/%s' % example_username, self, ensure_dir=True) + test_user_home = os.path.dirname( + example_username_home_dir) # strip user from path + test_state_dir = os.path.dirname(test_user_home) + test_user_db_home = os.path.join(test_state_dir, "user_db_home") + self.server_thread.start_wait_until_ready() + + status, content = self.issue_GET('/user/dummy-user') + + self.assertEqual(status, 200) + self.assertEqual(content, 'FOOBAR') + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_openid_user_username(self): + self.server_thread.start_wait_until_ready() + + status, content = self.issue_GET('/user/dummy-user') + + self.assertEqual(status, 200) + self.assertEqual(content, 'FOOBAR') + + @unittest.skipIf(PY2, "Python 3 only") + def test_POST_user__bad_input_data(self): + self.server_thread.start_wait_until_ready() + + status, content = self.issue_POST('/user', request_json={ + 'greeting': 'provocation' + }) + + self.assertEqual(status, 400) + error_description = extract_error_description_from_html(content) + error_description_lines = error_description.split('
') + self.assertEqual( + error_description_lines[0], 'payload failed to validate:') + + @unittest.skipIf(PY2, "Python 3 only") + def test_POST_user(self): + self.server_thread.start_wait_until_ready() + + status, content = self.issue_POST('/user', response_encoding='textual', request_json=dict( + full_name="Test User", + organization="Test Org", + state="NA", + country="DK", + email="user@example.com", + comment="This is the create comment", + password="password", + )) + + self.assertEqual(status, 201) + self.assertIsInstance(content, dict) + self.assertIn('unique_id', content) + + def _make_configuration(self, test_logger, server_addr): + configuration = self.configuration + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration, logger=None, server_address=None): + def _on_instance(server): + server.server_app = _create_and_expose_server( + server, server.configuration) + + (host, port) = server_address + server_thread = make_wrapped_server(ThreadedApiHttpServer, + configuration, logger, host, port, on_instance=_on_instance) + return server_thread + + +class MigServerGrid_openid__existing_user(MigTestCase, HttpAssertMixin): + def before_each(self): + ensure_dirs_needed_by_create_user(self.configuration) + + user_dict = { + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + } + create_user(user_dict, self.configuration, + keyword_auto, default_renew=True) + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + + def _provide_configuration(self): + return 'testconfig' + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_openid_user_find(self): + self.server_thread.start_wait_until_ready() + + status, content = self._issue_GET(self.server_addr, '/user/find', { + 'email': 'user@example.com' + }) + + self.assertEqual(status, 200) + + self.assertIsInstance(content, dict) + self.assertIn('objects', content) + self.assertIsInstance(content['objects'], list) + + user = content['objects'][0] + # check we received the correct user + self.assertEqual(user['full_name'], 'Test User') + + def _make_configuration(self, test_logger, server_addr): + configuration = self.configuration + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration, logger=None, server_address=None): + def _on_instance(server): + server.server_app = _create_and_expose_server( + server, server.configuration) + + (host, port) = server_address + server_thread = make_wrapped_server(ThreadedApiHttpServer, + configuration, logger, host, port, on_instance=_on_instance) + return server_thread + + +if __name__ == '__main__': + testmain() diff --git a/tests/test_mig_lib_coresvc_payloads.py b/tests/test_mig_lib_coresvc_payloads.py new file mode 100644 index 000000000..dae5616d6 --- /dev/null +++ b/tests/test_mig_lib_coresvc_payloads.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import sys + +from tests.support import MigTestCase, testmain + +from mig.lib.coresvc.payloads import \ + Payload as ArgumentBundle, \ + PayloadDefinition as ArgumentBundleDefinition, \ + PayloadException + + +def _contains_a_thing(value): + assert 'thing' in value + + +def _upper_case_only(value): + """value must be upper case""" + assert value == value.upper(), _upper_case_only.__doc__ + + +class TestMigSharedArguments__bundles(MigTestCase): + ThingsBundle = ArgumentBundleDefinition('Things', [ + ('some_field', _contains_a_thing), + ('other_field', _contains_a_thing), + ]) + + def assertBundleOfKind(self, value, bundle_kind=None): + assert isinstance(bundle_kind, str) and bundle_kind + self.assertIsInstance(value, ArgumentBundle, "value is not an argument bundle") + self.assertEqual(value.name, bundle_kind, "expected %s bundle, got %s" % (bundle_kind, value.name)) + + def test_bundling_arguments_produces_a_bundle(self): + bundle = self.ThingsBundle('abcthing', 'thingdef') + + self.assertBundleOfKind(bundle, bundle_kind='Things') + + def test_raises_on_missing_positional_arguments(self): + with self.assertRaises(PayloadException) as raised: + self.ThingsBundle(['a']) + self.assertEqual(str(raised.exception), 'Error: too few arguments given (expected 2 got 1)') + + def test_ensuring_arguments_returns_a_bundle(self): + bundle = self.ThingsBundle.ensure_bundle(['abcthing', 'thingdef']) + + self.assertBundleOfKind(bundle, bundle_kind='Things') + + def test_ensuring_an_existing_bundle_returns_it_unchanged(self): + existing_bundle = self.ThingsBundle('abcthing', 'thingdef') + + bundle = self.ThingsBundle.ensure_bundle(existing_bundle) + + self.assertIs(bundle, existing_bundle) + + def test_ensuring_on_a_list_of_args_validates_them(self): + with self.assertRaises(Exception) as raised: + bundle = self.ThingsBundle.ensure_bundle(['abcthing', 'def']) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- other_field: required') + + def test_ensuring_on_invalid_args_produces_reports_with_errors(self): + UpperCaseValue = ArgumentBundle.define('UpperCaseValue', ['ustring'], { + 'ustring': _upper_case_only + }) + + with self.assertRaises(Exception) as raised: + bundle = UpperCaseValue.ensure_bundle(['lowerCHARS']) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- ustring: value must be upper case') + + def test_ensuring_on_invalid_args_containing_none_behaves_correctly(self): + UpperCaseValue = ArgumentBundle.define('UpperCaseValue', ['ustring'], { + 'ustring': _upper_case_only + }) + + with self.assertRaises(Exception) as raised: + bundle = UpperCaseValue.ensure_bundle([None]) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- ustring: missing') + + +if __name__ == '__main__': + testmain()