diff --git a/pyproject.toml b/pyproject.toml index c910aea..5845760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,3 +52,4 @@ ingest-flow = "datastation.ingest_flow:main" dv-dataverse-root-collect-storage-usage = "datastation.dv_dataverse_root_collect_storage_usage:main" dv-dataverse-root-collect-permission-overview = "datastation.dv_dataverse_root_collect_permission_overview:main" datastation-get-component-versions = "datastation.datastation_get_component_versions:main" +dv-dataverse-role-assignment = "datastation.dv_dataverse_role_assignment:main" diff --git a/src/datastation/common/common_batch_processing.py b/src/datastation/common/common_batch_processing.py new file mode 100644 index 0000000..7954c3f --- /dev/null +++ b/src/datastation/common/common_batch_processing.py @@ -0,0 +1,139 @@ +import logging +import os +import time + +from datastation.common.csv import CsvReport +from datastation.common.utils import plural + + +# Base class for batch processing of items +class CommonBatchProcessor: + def __init__(self, item_name="item", wait=0.1, fail_on_first_error=True): + self.item_name = item_name + self.wait = wait + self.fail_on_first_error = fail_on_first_error + + def process_items(self, items, callback): + if type(items) is list: + num_items = len(items) + logging.info(f"Start batch processing on {num_items} {plural(self.item_name)}") + else: + logging.info(f"Start batch processing on unknown number of {plural(self.item_name)}") + num_items = -1 + i = 0 + for item in items: + i += 1 + try: + if self.wait > 0 and i > 1: + logging.debug(f"Waiting {self.wait} seconds before processing next {self.item_name}") + time.sleep(self.wait) + logging.info(f"Processing {i} of {num_items}: {item}") + callback(item) + except Exception as e: + logging.exception("Exception occurred", exc_info=True) + if self.fail_on_first_error: + logging.error(f"Stop processing because of an exception: {e}") + break + logging.debug("fail_on_first_error is False, continuing...") + + +def get_provided_items_iterator(item_or_items_file, item_name="item"): + if item_or_items_file is None: + logging.debug(f"No {plural(item_name)} provided.") + return None + elif os.path.isfile(os.path.expanduser(item_or_items_file)): + items = [] + with open(os.path.expanduser(item_or_items_file)) as f: + for line in f: + items.append(line.strip()) + return items + else: + return [item_or_items_file] + + +def get_pids(pid_or_pids_file, search_api=None, query="*", subtree="root", object_type="dataset", dry_run=False): + """ + + Args: + pid_or_pids_file: The dataset pid, or a file with a list of pids. + search_api: must be provided if pid_or_pids_file is None + query: passed on to search_api().search + object_type: passed on to search_api().search + subtree (object): passed on to search_api().search + dry_run: Do not perform the action, but show what would be done. + Only applicable if pid_or_pids_file is None. + + Returns: an iterator with pids, + if pid_or_pids_file is not provided, it searches for all datasets + and extracts their pids, fetching the result pages lazy. + """ + if pid_or_pids_file is None: + result = search_api.search(query=query, subtree=subtree, object_type=object_type, dry_run=dry_run) + return map(lambda rec: rec['global_id'], result) + else: + return get_provided_items_iterator(pid_or_pids_file, "pid") + + +def get_aliases(alias_or_aliases_file, dry_run=False): + """ + + Args: + alias_or_aliases_file: The dataverse alias, or a file with a list of aliases. + dry_run: Do not perform the action, but show what would be done. + Only applicable if pid_or_pids_file is None. + + Returns: an iterator with aliases + """ + if alias_or_aliases_file is None: + # The tree of all (published) dataverses could be retrieved and aliases could recursively be extracted + # from the tree, but this is not implemented yet. + logging.warning(f"No aliases provided, nothing to do.") + return None + else: + return get_provided_items_iterator(alias_or_aliases_file, "alias") + + +class DatasetBatchProcessor(CommonBatchProcessor): + + def __init__(self, wait=0.1, fail_on_first_error=True): + super().__init__("pid", wait, fail_on_first_error) + + def process_pids(self, pids, callback): + super().process_items(pids, callback) + + +class DatasetBatchProcessorWithReport(DatasetBatchProcessor): + + def __init__(self, report_file=None, headers=None, wait=0.1, fail_on_first_error=True): + super().__init__(wait, fail_on_first_error) + if headers is None: + headers = ["DOI", "Modified", "Change"] + self.report_file = report_file + self.headers = headers + + def process_pids(self, pids, callback): + with CsvReport(os.path.expanduser(self.report_file), self.headers) as csv_report: + super().process_pids(pids, lambda pid: callback(pid, csv_report)) + + +class DataverseBatchProcessor(CommonBatchProcessor): + + def __init__(self, wait=0.1, fail_on_first_error=True): + super().__init__("alias", wait, fail_on_first_error) + + def process_aliases(self, aliases, callback): + super().process_items(aliases, callback) + + +class DataverseBatchProcessorWithReport(DataverseBatchProcessor): + + def __init__(self, report_file=None, headers=None, wait=0.1, fail_on_first_error=True): + super().__init__(wait, fail_on_first_error) + if headers is None: + headers = ["alias", "Modified", "Change"] + self.report_file = report_file + self.headers = headers + + def process_aliases(self, aliases, callback): + with CsvReport(os.path.expanduser(self.report_file), self.headers) as csv_report: + super().process_aliases(aliases, lambda alias: callback(alias, csv_report)) diff --git a/src/datastation/common/utils.py b/src/datastation/common/utils.py index b818e0f..d3c0b8c 100644 --- a/src/datastation/common/utils.py +++ b/src/datastation/common/utils.py @@ -110,3 +110,11 @@ def sizeof_fmt(num, suffix='B'): return "%3.1f%s%s" % (num, unit, suffix) num /= 1024.0 return "%.1f%s%s" % (num, 'Yi', suffix) + +def plural(word: str): + if word.endswith('s'): + return word + "es" + elif word.endswith('y'): + return word[:-1] + "ies" + else: + return word + "s" diff --git a/src/datastation/dataverse/dataverse_api.py b/src/datastation/dataverse/dataverse_api.py index 023873c..c5cf940 100644 --- a/src/datastation/dataverse/dataverse_api.py +++ b/src/datastation/dataverse/dataverse_api.py @@ -1,18 +1,23 @@ import requests +import json from datastation.common.utils import print_dry_run_message class DataverseApi: - def __init__(self, server_url, api_token): + def __init__(self, server_url, api_token, alias): self.server_url = server_url self.api_token = api_token + self.alias = alias # Methods should use this one if specified + + def get_alias(self): + return self.alias # get json data for a specific dataverses API endpoint using an API token - def get_resource_data(self, resource, alias="root", dry_run=False): + def get_resource_data(self, resource, dry_run=False): headers = {"X-Dataverse-key": self.api_token} - url = f"{self.server_url}/api/dataverses/{alias}/{resource}" + url = f"{self.server_url}/api/dataverses/{self.alias}/{resource}" if dry_run: print_dry_run_message(method="GET", url=url, headers=headers) @@ -24,21 +29,21 @@ def get_resource_data(self, resource, alias="root", dry_run=False): resp_data = dv_resp.json()["data"] return resp_data - def get_contents(self, alias="root", dry_run=False): - return self.get_resource_data("contents", alias, dry_run) + def get_contents(self, dry_run=False): + return self.get_resource_data("contents", dry_run) - def get_roles(self, alias="root", dry_run=False): - return self.get_resource_data("roles", alias, dry_run) + def get_roles(self, dry_run=False): + return self.get_resource_data("roles", dry_run) - def get_assignments(self, alias="root", dry_run=False): - return self.get_resource_data("assignments", alias, dry_run) + def get_role_assignments(self, dry_run=False): + return self.get_resource_data("assignments", dry_run) - def get_groups(self, alias="root", dry_run=False): - return self.get_resource_data("groups", alias, dry_run) + def get_groups(self, dry_run=False): + return self.get_resource_data("groups", dry_run) - def get_storage_size(self, alias="root", dry_run=False): + def get_storage_size(self, dry_run=False): """ Get dataverse storage size (bytes). """ - url = f'{self.server_url}/api/dataverses/{alias}/storagesize' + url = f'{self.server_url}/api/dataverses/{self.alias}/storagesize' headers = {'X-Dataverse-key': self.api_token} if dry_run: print_dry_run_message(method='GET', url=url, headers=headers) @@ -47,3 +52,27 @@ def get_storage_size(self, alias="root", dry_run=False): r = requests.get(url, headers=headers) r.raise_for_status() return r.json()['data']['message'] + + def add_role_assignment(self, assignee, role, dry_run=False): + url = f'{self.server_url}/api/dataverses/{self.alias}/assignments' + headers = {'X-Dataverse-key': self.api_token, 'Content-type': 'application/json'} + role_assignment = {"assignee": assignee, "role": role} + if dry_run: + print_dry_run_message(method='POST', url=url, headers=headers, + data=json.dumps(role_assignment)) + return None + else: + r = requests.post(url, headers=headers, json=role_assignment) + r.raise_for_status() + return r + + def remove_role_assignment(self, assignment_id, dry_run=False): + url = f'{self.server_url}/api/dataverses/{self.alias}/assignments/{assignment_id}' + headers = {'X-Dataverse-key': self.api_token, 'Content-type': 'application/json'} + if dry_run: + print_dry_run_message(method='DELETE', url=url, headers=headers) + return None + else: + r = requests.delete(url, headers=headers) + r.raise_for_status() + return r diff --git a/src/datastation/dataverse/dataverse_client.py b/src/datastation/dataverse/dataverse_client.py index f0355aa..39b5e67 100644 --- a/src/datastation/dataverse/dataverse_client.py +++ b/src/datastation/dataverse/dataverse_client.py @@ -27,8 +27,8 @@ def search_api(self): def dataset(self, pid): return DatasetApi(pid, self.server_url, self.api_token, self.unblock_key, self.safety_latch) - def dataverse(self): - return DataverseApi(self.server_url, self.api_token) + def dataverse(self, alias=None): + return DataverseApi(self.server_url, self.api_token, alias) def file(self, file_id): return FileApi(file_id, self.server_url, self.api_token, self.unblock_key, self.safety_latch) diff --git a/src/datastation/dataverse/metrics_collect.py b/src/datastation/dataverse/metrics_collect.py index 1fdb31b..30a9f28 100644 --- a/src/datastation/dataverse/metrics_collect.py +++ b/src/datastation/dataverse/metrics_collect.py @@ -46,7 +46,7 @@ def write_result_row(self, row): def get_result_row(self, parent_alias, child_alias, child_name, depth): logging.info(f'Retrieving size for dataverse: {parent_alias} / {child_alias} ...') - msg = self.dataverse_client.dataverse().get_storage_size(child_alias) + msg = self.dataverse_client.dataverse(child_alias).get_storage_size() storage_size = extract_size_str(msg) logging.info(f'size: {storage_size}') row = {'depth': depth, 'parentalias': parent_alias, 'alias': child_alias, 'name': child_name, diff --git a/src/datastation/dataverse/permissions_collect.py b/src/datastation/dataverse/permissions_collect.py index ed2e578..a1e7d73 100644 --- a/src/datastation/dataverse/permissions_collect.py +++ b/src/datastation/dataverse/permissions_collect.py @@ -43,7 +43,7 @@ def get_result_row(self, parent_alias, child_alias, child_name, id, vpath, depth return row def get_group_info(self, alias): - resp_data = self.dataverse_client.dataverse().get_groups(alias) + resp_data = self.dataverse_client.dataverse(alias).get_groups() # flatten and compact it... no list comprehension though result_list = [] for group in resp_data: @@ -52,7 +52,7 @@ def get_group_info(self, alias): return ', '.join(result_list) def get_role_info(self, alias): - resp_data = self.dataverse_client.dataverse().get_roles(alias) + resp_data = self.dataverse_client.dataverse(alias).get_roles() # flatten and compact it... no list comprehension though result_list = [] for role in resp_data: @@ -61,7 +61,7 @@ def get_role_info(self, alias): return ', '.join(result_list) def get_assignment_info(self, alias): - resp_data = self.dataverse_client.dataverse().get_assignments(alias) + resp_data = self.dataverse_client.dataverse(alias).get_role_assignments() # flatten and compact it... no list comprehension though result_list = [] for assignment in resp_data: diff --git a/src/datastation/dataverse/roles.py b/src/datastation/dataverse/roles.py new file mode 100644 index 0000000..d6c605a --- /dev/null +++ b/src/datastation/dataverse/roles.py @@ -0,0 +1,62 @@ +import rich +from datetime import datetime + +from datastation.common.common_batch_processing import DataverseBatchProcessorWithReport, get_aliases +from datastation.dataverse.dataverse_api import DataverseApi +from datastation.dataverse.dataverse_client import DataverseClient + + +class DataverseRole: + + def __init__(self, dataverse_client: DataverseClient, dry_run: bool = False): + self.dataverse_client = dataverse_client + self.dry_run = dry_run + + def list_role_assignments(self, alias): + r = self.dataverse_client.dataverse(alias).get_role_assignments() + if r is not None: + rich.print_json(data=r) + + def add_role_assignment(self, role_assignment, dataverse_api: DataverseApi, csv_report): + assignee = role_assignment.split('=')[0] + role = role_assignment.split('=')[1] + action = "None" + if self.in_current_assignments(assignee, role, dataverse_api): + print("{} is already {} for dataset {}".format(assignee, role, dataverse_api.get_alias())) + else: + print( + "Adding {} as {} for dataset {}".format(assignee, role, dataverse_api.get_alias())) + dataverse_api.add_role_assignment(assignee, role, dry_run=self.dry_run) + action = "Added" + csv_report.write( + {'alias': dataverse_api.get_alias(), 'Modified': datetime.now(), 'Assignee': assignee, 'Role': role, + 'Change': action}) + + def in_current_assignments(self, assignee, role, dataverse_api: DataverseApi): + current_assignments = dataverse_api.get_role_assignments() + found = False + for current_assignment in current_assignments: + if current_assignment.get('assignee') == assignee and current_assignment.get( + '_roleAlias') == role: + found = True + break + return found + + + def remove_role_assignment(self, role_assignment, dataverse_api: DataverseApi, csv_report): + assignee = role_assignment.split('=')[0] + role = role_assignment.split('=')[1] + action = "None" + if self.in_current_assignments(assignee, role, dataverse_api): + print("Removing {} as {} for dataverse {}".format(assignee, role, dataverse_api.get_alias())) + all_assignments = dataverse_api.get_role_assignments() + for assignment in all_assignments: + if assignment.get('assignee') == assignee and assignment.get('_roleAlias') == role: + dataverse_api.remove_role_assignment(assignment.get('id'), dry_run=self.dry_run) + action = "Removed" + break + else: + print("{} is not {} for dataverse {}".format(assignee, role, dataverse_api.get_alias())) + csv_report.write( + {'alias': dataverse_api.get_alias(), 'Modified': datetime.now(), 'Assignee': assignee, 'Role': role, + 'Change': action}) diff --git a/src/datastation/dv_dataset_destroy_migration_placeholder.py b/src/datastation/dv_dataset_destroy_migration_placeholder.py index 3a7a76e..1121e6b 100644 --- a/src/datastation/dv_dataset_destroy_migration_placeholder.py +++ b/src/datastation/dv_dataset_destroy_migration_placeholder.py @@ -28,8 +28,7 @@ def main(): batch_processor.process_pids(pids, callback=lambda pid, csv_report: destroy_placeholder_dataset(dataverse.dataset(pid), description_text_pattern, - csv_report, - dry_run=args.dry_run)) + csv_report)) if __name__ == '__main__': diff --git a/src/datastation/dv_dataverse_role_assignment.py b/src/datastation/dv_dataverse_role_assignment.py new file mode 100644 index 0000000..1a4ae96 --- /dev/null +++ b/src/datastation/dv_dataverse_role_assignment.py @@ -0,0 +1,84 @@ +import argparse + +from datastation.common.common_batch_processing import get_aliases, DataverseBatchProcessorWithReport +from datastation.common.config import init +from datastation.common.utils import add_batch_processor_args, add_dry_run_arg +from datastation.dataverse.dataverse_client import DataverseClient +from datastation.dataverse.roles import DataverseRole + + +def list_role_assignments(args, dataverse_client: DataverseClient): + role_assignment = DataverseRole(dataverse_client) + role_assignment.list_role_assignments(args.alias) + +def add_role_assignments(args, dataverse_client: DataverseClient): + role_assignment = DataverseRole(dataverse_client, args.dry_run) + aliases = get_aliases(args.alias_or_alias_file) + batch_processor = DataverseBatchProcessorWithReport(wait=args.wait, fail_on_first_error=args.fail_fast, + report_file=args.report_file, + headers=['alias', 'Modified', 'Assignee', 'Role', 'Change']) + batch_processor.process_aliases(aliases, + lambda alias, + csv_report: role_assignment.add_role_assignment(args.role_assignment, + dataverse_api= + dataverse_client.dataverse( + alias), + csv_report=csv_report)) + + +def remove_role_assignments(args, dataverse_client: DataverseClient): + role_assignment = DataverseRole(dataverse_client, args.dry_run) + aliases = get_aliases(args.alias_or_alias_file) + batch_processor = DataverseBatchProcessorWithReport(wait=args.wait, report_file=args.report_file, + headers=['alias', 'Modified', 'Assignee', 'Role', 'Change']) + batch_processor.process_aliases(aliases, + lambda alias, + csv_report: role_assignment.remove_role_assignment(args.role_assignment, + dataverse_api= + dataverse_client.dataverse( + alias), + csv_report=csv_report)) + + +def main(): + config = init() + dataverse_client = DataverseClient(config['dataverse']) + + # Create main parser and subparsers + parser = argparse.ArgumentParser(description='Manage role assignments on one or more datasets.') + subparsers = parser.add_subparsers(help='subcommands', dest='subcommand') + + # Add role assignment + parser_add = subparsers.add_parser('add', help='add role assignment to specified dataset(s)') + parser_add.add_argument('role_assignment', + help='role assignee and alias (example: @dataverseAdmin=contributor) to add') + parser_add.add_argument('alias_or_alias_file', + help='The dataverse alias or the input file with the dataverse aliases') + add_batch_processor_args(parser_add) + add_dry_run_arg(parser_add) + + parser_add.set_defaults(func=lambda _: add_role_assignments(_, dataverse_client)) + + # Remove role assignment + parser_remove = subparsers.add_parser('remove', help='remove role assignment from specified dataset(s)') + parser_remove.add_argument('role_assignment', + help='role assignee and alias (example: @dataverseAdmin=contributor)') + parser_remove.add_argument('alias_or_alias_file', + help='The dataverse alias or the input file with the dataverse aliases') + add_batch_processor_args(parser_remove) + add_dry_run_arg(parser_remove) + parser_remove.set_defaults(func=lambda _: remove_role_assignments(_, dataverse_client)) + + # List role assignments + parser_list = subparsers.add_parser('list', + help='list role assignments for specified dataverse (only one alias allowed)') + parser_list.add_argument('alias', help='the dataverse alias') + add_dry_run_arg(parser_list) + parser_list.set_defaults(func=lambda _: list_role_assignments(_, dataverse_client)) + + args = parser.parse_args() + args.func(args) + + +if __name__ == '__main__': + main() diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py index 1ffdb39..ecc8b6c 100644 --- a/src/tests/test_utils.py +++ b/src/tests/test_utils.py @@ -2,7 +2,8 @@ import argparse import unittest -from datastation.common.utils import is_sub_path_of, has_dirtree_pred, set_permissions, positive_int_argument_converter +from datastation.common.utils import is_sub_path_of, has_dirtree_pred, set_permissions, positive_int_argument_converter, \ + plural class TestIsSubPathOf: @@ -104,3 +105,11 @@ def test_positive_int_argument_converter(self): positive_int_argument_converter("-5") with self.assertRaises(argparse.ArgumentTypeError): positive_int_argument_converter("abc") + + +class TestPlural(unittest.TestCase): + def test_plural(self): + self.assertEqual(plural("pid"), "pids") + self.assertEqual(plural("alias"), "aliases") + self.assertEqual(plural(":-)lolly"), ":-)lollies") +