diff --git a/.gitignore b/.gitignore index 5d8ebbb3..f57a35f6 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env *.iml Callback* .*_history diff --git a/build.gradle b/build.gradle index c7f59275..b84a2ccb 100755 --- a/build.gradle +++ b/build.gradle @@ -15,7 +15,7 @@ apply plugin: 'io.spring.dependency-management' jar { baseName='orion-data' - version= '2.8.1' + version= '2.9.1' } repositories { diff --git a/data/stateDisplay.json b/data/stateDisplay.json index 40e4f6ab..aa150250 100644 --- a/data/stateDisplay.json +++ b/data/stateDisplay.json @@ -7,6 +7,9 @@ "dlu": { "text": "Finishing upload", "alertType": "success" + }, + "dmd": { + "text": "Finishing Upload" } } },{ @@ -18,6 +21,9 @@ "dlu": { "text": "Waiting for files...", "alertType": "primary" + }, + "dmd": { + "text": "Waiting for files" } } },{ @@ -29,6 +35,9 @@ "dlu": { "text": "Upload failed", "alertType": "danger" + }, + "dmd": { + "text": "Upload failed" } } },{ @@ -50,80 +59,52 @@ "apps": { "dlu": { "text": "DVC QC Pending", - "alertType": "warning", - "showDownload": true + "alertType": "warning" + }, + "dmd": { + "text": "Ready for review" } } },{ "_id": { "$oid": "5ecd777c2e72d5a73904dce4" }, - "state": "PACKAGE_LEVEL_REVIEW_STARTED", + "state": "PACKAGE_REVIEW_STARTED", "apps": { "dlu": { "text": "DVC QC Pending", - "alertType": "warning", - "showDownload": true + "alertType": "warning" + }, + "dmd": { + "text": "Package Review Started" } } },{ "_id": { "$oid": "5ecd778f2e72d5a73904dce5" }, - "state": "PACKAGE_LEVEL_REVIEW_SUCCEEDED", + "state": "PACKAGE_REVIEW_SUCCEEDED", "apps": { "dlu": { "text": "DVC QC Pending", - "alertType": "warning", - "showDownload": true + "alertType": "warning" + }, + "dmd": { + "text": "Package Review Succeeded" } } },{ "_id": { "$oid": "5ecd77a42e72d5a73904dce6" }, - "state": "PACKAGE_LEVEL_REVIEW_FAILED", + "state": "PACKAGE_REVIEW_FAILED", "apps": { "dlu": { "text": "DVC QC Rejected", - "alertType": "danger", - "showDownload": true - } - } -},{ - "_id": { - "$oid": "5ecd78882e72d5a73904dce7" - }, - "state": "PROPERTY_LEVEL_REVIEW_STARTED", - "apps": { - "dlu": { - "text": "DVC QC Pending", - "alertType": "warning", - "showDownload": true - } - } -},{ - "_id": { - "$oid": "5ecd78ab2e72d5a73904dce8" - }, - "state": "PROPERTY_LEVEL_REVIEW_FAILED", - "apps": { - "dlu": { - "text": "DVC QC Rejected", - "alertType": "danger", - "showDownload": true - } - } -},{ - "_id": { - "$oid": "5ecd78c32e72d5a73904dce9" - }, - "state": "PROPERTY_LEVEL_REVIEW_SUCCEEDED", - "apps": { - "dlu": { - "text": "DVC QC Approved", - "alertType": "success", - "showDownload": true + "alertType": "danger" + }, + "dmd": { + "text": "Package Review Failed" } } }] \ No newline at end of file diff --git a/data/tokens-template.json b/data/tokens-template.json new file mode 100644 index 00000000..69c6c9b5 --- /dev/null +++ b/data/tokens-template.json @@ -0,0 +1,7 @@ +{ + "_id":"5f3ce8525471f40049fd339f", + "tokenString":"bdVFkm4OBPnv6g7fRElFWexOSAZWBnwp4LPLOFkQJMt2", + "shibId":"kpmp-devs@umich.edu", + "expiration":"2021-08-19T08:52:34.539Z", + "active":true +} \ No newline at end of file diff --git a/scripts/dataPromotion/.env.example b/scripts/dataPromotion/.env.example index fc03cb5d..72606e9e 100644 --- a/scripts/dataPromotion/.env.example +++ b/scripts/dataPromotion/.env.example @@ -1,6 +1,7 @@ minio_access_key = minio_secret_key = destination_bucket = "kpmp-knowledge-environment" +source_bucket = "orion-data-uploads" datalake_dir = minio_host = "localhost:9000" mysql_user = diff --git 
a/scripts/dataPromotion/README.md b/scripts/dataPromotion/README.md
index ec2c34e0..731dab23 100644
--- a/scripts/dataPromotion/README.md
+++ b/scripts/dataPromotion/README.md
@@ -3,28 +3,14 @@ pip install -r requirements.txt
 # Moving files from DL to S3
-1. Add packageIDs,filenames to files_to_s3.txt one per line
-2. Execute 'python datalakeToS3.py'
+1. Requires a connection to the Staging DB MySQL (e.g. through a tunnel); files for the target release are read from the "file" table
+2. Execute 'python filesToS3.py -v <release_ver>'
-# Utility scripts for moving data into Knowledge Environment Database
-1. Uncomment what you need in indexToKE.py and execute.
-
-# Creating / Updating index records for Atlas portal
-## Option 1:
-1. Execute 'python datalakeToAtlasIndex.py'
-2. Follow prompts
-
-## Option 2:
-1. Create a comma-delimited file with the same headers as package_to_atlas_index.csv
-2. Execute 'python datalakeToAtlasIndex.py -f my_package_to_atlas_index.csv'
-3. Redirect output to file or copy/paste into your favorite POST client.
-
-## Option 3:
-### Generate using the Knowledge Environment database.
-Requirements: Make sure the knowledge_environment MYSQL database is available on 3306.
-### All records
-1. Execute python keDatabaseToAtlasIndex.py' without arguments
-### By release version
-1. Execute python keDatabaseToAtlasIndex.py -v '
-### Per file
-1. Execute python keDatabaseToAtlasIndex.py -f '
+# Moving files from the “file_pending” table to the “file” table in the Staging DB
+1. Requires a connection to the DLU Mongo and the Staging DB MySQL (e.g. through tunnels)
+2. Expression Matrix files are inserted with a file size of 0; the size is filled in later (e.g. by fileSizeFix.py) once the zip file exists
+3. Execute 'python filesToKE.py'
+# Adding clinical data and participants from a CSV file to the Staging Database
+1. Requires a connection to the Staging DB MySQL (e.g. through a tunnel)
+2. Edit the script to point to the clinical .csv file
+3. 
Execute 'python clinicalToKE.py' diff --git a/scripts/dataPromotion/datalakeToAtlasIndex.py b/scripts/dataPromotion/datalakeToAtlasIndex.py deleted file mode 100644 index 39f59ff4..00000000 --- a/scripts/dataPromotion/datalakeToAtlasIndex.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python2 -import pymongo -from collections import OrderedDict -import json -import copy -import csv -import sys -import uuid - -class MetadataType: - def __init__(self, experimental_strategy, data_type, data_category, data_format, platform, access, file_name_match_string, workflow_type): - self.experimental_strategy = experimental_strategy - self.data_type = data_type - self.data_category = data_category - self.data_format = data_format - self.platform = platform - self.access = access - self.file_name_match_string = file_name_match_string - self.workflow_type = workflow_type - -class IndexDoc: - def __init__(self, metadata_type, file_id, file_name, file_size, protocol, participant_id, package_id, cases): - self.access = metadata_type.access - self.platform = metadata_type.platform - self.experimental_strategy = metadata_type.experimental_strategy - self.data_category = metadata_type.data_category - self.workflow_type = metadata_type.workflow_type - self.file_id = file_id - self.file_name = file_name - self.data_format = metadata_type.data_format - self.file_size = file_size - self.data_type = metadata_type.data_type - self.protocol = protocol - self.participant_id = participant_id - self.package_id = package_id - self.cases = cases - -class CasesIndexDoc: - def __init__(self, tissue_source, samples, demographics): - self.tissue_source = tissue_source - self.samples = samples - self.demographics = demographics - - -def get_selector(text, array): - j = 1 - select_text = "" - for item in array: - select_text += str(j) + " : " + item + "\n" - j += 1 - array_index = int(raw_input(text + " \n" + select_text)) - 1 - return array[array_index] - -def print_index_update_json(id): - return '{"update":{"_index":"file_cases","_id":"' + id + '"}}' - -def print_index_doc_json(index_doc): - index_doc.cases = index_doc.cases.__dict__ - return '{"doc":' + json.dumps(index_doc.__dict__) + ',"doc_as_upsert":true}' - -input_file_name = './package_to_atlas_index.csv' - -m1_dt_wsi = MetadataType("", "Whole Slide Images", "Pathology", "svs", "", "open", ".svs", "") -m2_dt_single_nuc_rna = MetadataType("Single-nucleus RNA-Seq", "Transcriptomics", "Molecular", "bam", "10x Genomics", "controlled", ".bam", "") -m9_dt_sub_seg_trans = MetadataType("Sub-segmental Transcriptomics", "Transcriptomics", "Molecular", "bam", "LMD Transcriptomics", "controlled", ".bam", "") -m10_dt_clinical_data = MetadataType("", "Clinical Study Data", "Clinical", "csv", "", "open", ".csv", "") - -m3_dt_single_nuc_rna = copy.copy(m2_dt_single_nuc_rna) -m3_dt_single_nuc_rna.data_format = "fastq" -m3_dt_single_nuc_rna.file_name_match_string = ".fastq.gz" - -m4_dt_single_nuc_rna = copy.copy(m2_dt_single_nuc_rna) -m4_dt_single_nuc_rna.data_format = "tsv mtx" -m4_dt_single_nuc_rna.access = "open" -m4_dt_single_nuc_rna.workflow_type = "Expression Matrix" - -m5_dt_single_nuc_rna = copy.copy(m2_dt_single_nuc_rna) -m5_dt_single_nuc_rna.data_format = "tsv" -m5_dt_single_nuc_rna.access = "open" -m5_dt_single_nuc_rna.platform = "snDrop-seq" -m5_dt_single_nuc_rna.file_name_match_string = ".tsv" -m5_dt_single_nuc_rna.workflow_type = "Expression Matrix" - -m11_dt_snr_metadata = copy.copy(m2_dt_single_nuc_rna) -m11_dt_snr_metadata.data_format = "xlsx" 
-m11_dt_snr_metadata.platform = "10x Genomics" -m11_dt_snr_metadata.access = "open" -m11_dt_snr_metadata.file_name_match_string = ".xlsx" -m11_dt_snr_metadata.workflow_type = "Experimental Metadata" - -metadata_types = OrderedDict() -metadata_types["1"] = m1_dt_wsi -metadata_types["2"] = m2_dt_single_nuc_rna -metadata_types["3"] = m3_dt_single_nuc_rna -metadata_types["4"] = m4_dt_single_nuc_rna -metadata_types["5"] = m5_dt_single_nuc_rna -metadata_types["9"] = m9_dt_sub_seg_trans -metadata_types["10"] = m10_dt_clinical_data -metadata_types["11"] = m11_dt_snr_metadata - -data_type_select = "" -for metadata_num, metadata_type in metadata_types.items(): - metadata_type_name = metadata_type.data_type + ", " + metadata_type.experimental_strategy + ", " + metadata_type.data_format + ", " + metadata_type.access - data_type_select += metadata_num + " : " + metadata_type_name + "\n" - -if len(sys.argv) > 1 and sys.argv[1] == '-f': - using_file_answer = 'Y' - if sys.argv[2]: - input_file_name = sys.argv[2] -else: - using_file_answer = raw_input('Are you using the "package_to_atlas_index.csv" file?') - -def process_update_row(row): - selected_metadata_type = metadata_types[row['metadata_type_num']] - cases_doc = CasesIndexDoc([row['tissue_source']], {"participant_id":[row['participant_id']], "tissue_type":[row['tissue_type']], "sample_type":[row['sample_type']]},{"sex":[row['sex']], "age":[row['age']]}) - docs = [] - if selected_metadata_type.data_format == "tsv mtx": - file_name = row['package_id'] + "_" + "expression_matrix.zip" - index_doc = IndexDoc(selected_metadata_type, row['package_id'], file_name, row['file_size_exp_matrix_only'], row['protocol'], row['participant_id'], row['package_id'], cases_doc) - docs.append(print_index_update_json(row['package_id']) + "\n" + print_index_doc_json(index_doc)) - else: - mongo_client = pymongo.MongoClient("mongodb://localhost:27017/") - database = mongo_client["dataLake"] - packages = database["packages"] - - result = packages.find_one({ "_id": row['package_id'] }, {"files":1}) - if not result is None: - if result['files'] > 0: - found_a_file = False - for file in result['files']: - if file["fileName"].endswith(selected_metadata_type.file_name_match_string): - found_a_file = True - file_name = file["_id"] + "_" + file["fileName"] - index_doc = IndexDoc(selected_metadata_type, file["_id"], file_name, file["size"], row['protocol'], row['participant_id'], row['package_id'], cases_doc) - docs.append(print_index_update_json(file["_id"]) + "\n" + print_index_doc_json(index_doc)) - if not found_a_file: - print("No files found in package matching " + selected_metadata_type.file_name_match_string + " extension") - else: - print("No files found in package " + package_id) - else: - print("Could not find any packages for ID " + package_id) - return docs - -def process_clinical_row(row, index_doc, row_num): - if row_num == 2: - file_id = str(uuid.uuid1()) - file_name = file_id + "_" + row['file_name'] - cases_doc = CasesIndexDoc([row['tissue_source']], {"participant_id":[row['participant_id']], "tissue_type":[row['tissue_type']], "sample_type":[row['sample_type']]},{"sex":[row['sex']], "age":[row['age']]}) - index_doc = IndexDoc(selected_metadata_type, file_id, file_name, row['file_size_exp_matrix_only'], row['protocol'], row['participant_id'], file_id, cases_doc) - else: - index_doc.cases.tissue_source.append(row['tissue_source']) - index_doc.cases.samples["participant_id"].append(row['participant_id']) - 
index_doc.cases.samples["sample_type"].append(row['sample_type']) - index_doc.cases.samples["tissue_type"].append(row['tissue_type']) - index_doc.cases.demographics["age"].append(row['age']) - index_doc.cases.demographics["sex"].append(row['sex']) - - return index_doc - -output = [] -if using_file_answer not in ('Y', 'yes', 'Yes', 'y'): - protocol = get_selector("Select protocol: ", ["Pilot1", "KPMP Recruitment Site"]) - package_id = raw_input("Enter the package ID: ") - participant_id = raw_input("Enter the participant ID: ") - metadata_type_num = raw_input("Select a metadata scheme number: \n" + data_type_select) - sample_type = "" - tissue_type = "" - sex = "" - age = "" - output += process_update_row({"protocol":protocol,"package_id":package_id,"participant_id":participant_id,"tissue_type":tissue_type,"sample_type":sample_type,"sex":sex,"age":age,"metadata_type_num":metadata_type_num}) -else: - with open(input_file_name) as csv_file: - output = [] - csv_reader = csv.DictReader(csv_file) - no_rows = True - index_doc = {} - for row in csv_reader: - no_rows = False - selected_metadata_type = metadata_types[row['metadata_type_num']] - if selected_metadata_type.data_type == "Clinical Study Data": - index_doc = process_clinical_row(row, index_doc, csv_reader.line_num) - else: - output += process_update_row(row) - -if selected_metadata_type.data_type == "Clinical Study Data": - output.append(print_index_update_json(index_doc.file_id) + "\n" + print_index_doc_json(index_doc)) -print("\n".join(output)) diff --git a/scripts/dataPromotion/datalakeToS3.py b/scripts/dataPromotion/datalakeToS3.py deleted file mode 100644 index d816cbce..00000000 --- a/scripts/dataPromotion/datalakeToS3.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python2 -import pymongo -from minio import Minio -from minio.error import (ResponseError) -from dotenv import load_dotenv -import os -import csv - -load_dotenv() - -minio_access_key = os.environ.get('minio_access_key') -minio_secret_key = os.environ.get('minio_secret_key') -destination_bucket = os.environ.get('destination_bucket') -datalake_dir = os.environ.get('datalake_dir') -minio_host = os.environ.get('minio_host') - -mongo_client = pymongo.MongoClient("mongodb://localhost:27017/") -database = mongo_client["dataLake"] -packages = database["packages"] - -minio_client = Minio(minio_host, access_key=minio_access_key, secret_key=minio_secret_key, secure=False) - -with open('./files_to_s3.txt') as csv_file: - csv_reader = csv.DictReader(csv_file) - no_rows = True - for row in csv_reader: - no_rows = False - datalake_package_dir = datalake_dir + "/package_" + row['package_id'] + "/" - file_path = datalake_package_dir + row['filename'] - filename = None - if row['filename'].endswith('expression_matrix.zip'): - matrix_file_answer = raw_input("Found an expression matrix file: " + row['package_id'] + "," + "row['filename']" + ". 
Was this created manually?") - if matrix_file_answer in ('Y', 'yes', 'Yes', 'y'): - filename = row['package_id'] + "_" + "expression_matrix.zip" - else: - result = packages.find_one({ "_id": row['package_id'], "files.fileName": row['filename']}, {"_id": 0, "files.$": 1}) - if not result is None: - filename = result['files'][0]['_id'] + "_" + row['filename'] - else: - print("No files found for " + row['package_id'] + "," + row['filename']) - if filename: - object_name = row['package_id'] + "/" + filename - print("Moving " + object_name) - try: - minio_client.fput_object(destination_bucket, object_name, file_path) - except ResponseError as err: - print(err) - else: - print("Skipping " + row['package_id'] + "," + row['filename']) - -if no_rows: - print('Please add some entries to "files_to_s3.txt"') - diff --git a/scripts/dataPromotion/fileSizeFix.py b/scripts/dataPromotion/fileSizeFix.py new file mode 100644 index 00000000..63e81178 --- /dev/null +++ b/scripts/dataPromotion/fileSizeFix.py @@ -0,0 +1,46 @@ +import mysql.connector +from dotenv import load_dotenv +import os +import csv +import json + +load_dotenv() + +mysql_user = os.environ.get('mysql_user') +mysql_pwd = os.environ.get('mysql_pwd') + +datalake_dir = os.environ.get('datalake_dir') + +try: + mydb = mysql.connector.connect( + host="localhost", + user=mysql_user, + password=mysql_pwd, + database="knowledge_environment", + autocommit=True + ) + mydb.get_warnings = True + cursor1 = mydb.cursor(buffered=True) + cursor2 = mydb.cursor(buffered=True) +except: + print("Can't connect to MySQL") + print("Make sure you have tunnel open to the KE database, e.g.") + print("ssh ubuntu@qa-atlas.kpmp.org -i ~/.ssh/um-kpmp.pem -L 3306:localhost:3306") + os.sys.exit() + +query = ("SELECT file_id, file_name, package_id FROM file WHERE file_size = 0") +cursor1.execute(query) +update_count = 0 + +for (file_id, file_name, package_id) in cursor1: + datalake_package_dir = datalake_dir + "/package_" + package_id + "/" + original_file_name = file_name[37:] + file_path = datalake_package_dir + "expression_matrix.zip" + file_size = os.path.getsize(file_path) + values = (file_size, file_id) + update_sql = "UPDATE file SET file_size = %s WHERE file_id = %s" + print(update_sql % values) + cursor2.execute(update_sql, values) + + + diff --git a/scripts/dataPromotion/filesToKE.py b/scripts/dataPromotion/filesToKE.py new file mode 100644 index 00000000..90e6346e --- /dev/null +++ b/scripts/dataPromotion/filesToKE.py @@ -0,0 +1,72 @@ +import pymongo +import mysql.connector +from dotenv import load_dotenv +import os +import csv +import json + +load_dotenv() + +mysql_user = os.environ.get('mysql_user') +mysql_pwd = os.environ.get('mysql_pwd') +EXPRESSION_MATRIX_METADATA_TYPES = [4,21] +cursor = None + +try: + mydb = mysql.connector.connect( + host="localhost", + user=mysql_user, + password=mysql_pwd, + database="knowledge_environment" + ) + mydb.get_warnings = True + cursor1 = mydb.cursor(buffered=True) + cursor2 = mydb.cursor(buffered=True) +except: + print("Can't connect to MySQL") + print("Make sure you have tunnel open to the KE database, e.g.") + print("ssh ubuntu@qa-atlas.kpmp.org -i ~/.ssh/um-kpmp.pem -L 3306:localhost:3306") + os.sys.exit() + +try: + mongo_client = pymongo.MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000) + database = mongo_client["dataLake"] + packages = database["packages"] +except: + print("Can't connect to Mongo") + os.sys.exit() + +query = ("SELECT * FROM file_pending") +cursor1.execute(query) +update_count = 0 
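+# For each row in file_pending, build the matching record for the file table:
+# expression-matrix rows (metadata types 4 and 21) get a synthesized
+# <package_id>_expression_matrix.zip name and a placeholder file size of 0,
+# while all other rows are looked up in the DLU Mongo packages collection to
+# pick up the stored file id and size. Each file is also linked to its
+# participant via the file_participant table.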
+for (package_id, file_name, protocol, metadata_type_id, participant_id, release_ver) in cursor1:
+    insert_sql = "INSERT IGNORE INTO file (file_id, file_name, package_id, file_size, protocol, metadata_type_id, release_ver) VALUES (%s, %s, %s, %s, %s, %s, %s)"
+    if metadata_type_id in EXPRESSION_MATRIX_METADATA_TYPES:
+        new_file_name = package_id + "_expression_matrix.zip"
+        file_id = package_id
+        file_size = 0
+    else:
+        result = packages.find_one({ "_id": package_id, "files.fileName": file_name}, {"files.$":1})
+        new_file_name = result["files"][0]["_id"] + "_" + file_name
+        file_size = result["files"][0]["size"]
+        file_id = result["files"][0]["_id"]
+
+    val = (file_id, new_file_name, package_id, file_size, protocol, metadata_type_id, release_ver)
+    update_count = update_count + 1
+    print(insert_sql % val)
+    cursor2.execute(insert_sql, val)
+    print(cursor2.fetchwarnings())
+
+    sql2 = "INSERT IGNORE INTO file_participant (file_id, participant_id) VALUES (%s, %s)"
+    val2 = (file_id, participant_id)
+    print(sql2 % val2)
+    cursor2.execute(sql2, val2)
+    warning = cursor2.fetchwarnings()
+    if warning is not None:
+        print(warning)
+    mydb.commit()
+print(str(update_count) + " rows inserted")
+
+
+
diff --git a/scripts/dataPromotion/filesToS3.py b/scripts/dataPromotion/filesToS3.py
new file mode 100644
index 00000000..eb2da560
--- /dev/null
+++ b/scripts/dataPromotion/filesToS3.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python2
+from minio import Minio
+from minio.error import (ResponseError)
+from dotenv import load_dotenv
+import os
+from collections import OrderedDict
+import mysql.connector
+from argparse import ArgumentParser
+
+load_dotenv()
+
+minio_access_key = os.environ.get('minio_access_key')
+minio_secret_key = os.environ.get('minio_secret_key')
+destination_bucket = os.environ.get('destination_bucket')
+source_bucket = os.environ.get('source_bucket')
+datalake_dir = os.environ.get('datalake_dir')
+minio_host = os.environ.get('minio_host')
+
+minio_client = Minio(minio_host, access_key=minio_access_key, secret_key=minio_secret_key, secure=False)
+
+mysql_user = os.environ.get('mysql_user')
+mysql_pwd = os.environ.get('mysql_pwd')
+
+parser = ArgumentParser(description="Move files to S3")
+parser.add_argument("-v", "--release_ver",
+                    dest="release_ver",
+                    help="target release version",
+                    required=True)
+
+args = parser.parse_args()
+
+try:
+    mydb = mysql.connector.connect(
+        host="localhost",
+        user=mysql_user,
+        password=mysql_pwd,
+        database="knowledge_environment"
+    )
+    mydb.get_warnings = True
+    cursor = mydb.cursor(buffered=True)
+    cursor2 = mydb.cursor(buffered=True)
+except:
+    print("Can't connect to MySQL")
+    print("Make sure you have tunnel open to the KE database, e.g.")
+    print("ssh ubuntu@qa-atlas.kpmp.org -i ~/.ssh/um-kpmp.pem -L 3306:localhost:3306")
+    os.sys.exit()
+
+query = ("SELECT file_id, package_id, file_name, metadata_type_id FROM file WHERE release_ver = " + args.release_ver + " AND file_name NOT IN (SELECT file_name FROM moved_files)")
+cursor.execute(query)
+update_count = 0
+
+for (file_id, package_id, file_name, metadata_type_id) in cursor:
+    datalake_package_dir = datalake_dir + "/package_" + package_id + "/"
+    original_file_name = file_name[37:]
+    file_path = datalake_package_dir + original_file_name
+    expression_file_names = "barcodes.tsv.gz features.tsv.gz matrix.mtx.gz"
+    if file_name:
+        object_name = package_id + "/" + file_name
+        print("Looking for: " + file_path)
+        if file_name.endswith('expression_matrix.zip'):
+            if metadata_type_id == 21:
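+                # Metadata type 21 keeps its component file names (semicolon-separated)
+                # in file_pending; fetch them so expression_matrix.zip can be assembled
+                # from the objects in the source bucket.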
+ query2 = "SELECT file_name FROM file_pending WHERE package_id = %s AND metadata_type_id = %s" + cursor2.execute(query2, (package_id, metadata_type_id)) + expression_file_names = cursor2.fetchone()[0].replace(";", "") + print("Creating expression matrix zip file for: " + expression_file_names) + expression_file_names_arr = expression_file_names.split() + if not os.path.exists(datalake_package_dir + expression_file_names_arr[0]): + for expression_file_name in expression_file_names_arr: + source_object = source_bucket + "/package_" + package_id + "/" + expression_file_name + command_string = "aws s3 cp s3://" + source_object + " " + datalake_package_dir + expression_file_name + print(command_string) + os.system(command_string) + command_string = "cd " + datalake_package_dir + " && zip expression_matrix.zip " + expression_file_names + print(command_string) + os.system(command_string) + file_size = os.path.getsize(file_path) + values = (file_size, file_id) + update_sql = "UPDATE file SET file_size = %s WHERE file_id = %s" + print(update_sql % values) + cursor2.execute(update_sql, values) + if not os.path.exists(file_path): + source_object = source_bucket + "/package_" + package_id + "/" + original_file_name + print("File not found locally. Trying S3: " + source_object) + try: + command_string = "aws s3 cp s3://" + source_object + " s3://" + destination_bucket + "/" + object_name + print(command_string) + os.system(command_string) + update_count = update_count + 1 + except ResponseError as err: + print(err) + pass + else: + try: + print("Moving " + object_name) + minio_client.fput_object(destination_bucket, object_name, file_path) + update_count = update_count + 1 + insert_sql = "INSERT INTO moved_files (file_name) VALUES (%s)" + cursor2.execute(insert_sql, (file_name)) + except ResponseError as err: + print(err) + pass + + else: + print("No file name in record.") + print("\n") + +print(str(update_count) + " files moved") \ No newline at end of file diff --git a/scripts/dataPromotion/files_to_s3.txt b/scripts/dataPromotion/files_to_s3.txt deleted file mode 100644 index 94f31252..00000000 --- a/scripts/dataPromotion/files_to_s3.txt +++ /dev/null @@ -1,6 +0,0 @@ -package_id,filename -4e57665b-7f5d-4101-9ce1-2bc85e0f8084,4e57665b-7f5d-4101-9ce1-2bc85e0f8084_expression_matrix.zip -978a0320-8da5-4248-ae46-11cdbc8a49b9,978a0320-8da5-4248-ae46-11cdbc8a49b9_expression_matrix.zip -efed3c64-3086-4293-b2f8-07a6d6a82193,efed3c64-3086-4293-b2f8-07a6d6a82193_expression_matrix.zip -71316619-d3a7-4c18-bc11-e6e0eba9f0c7,71316619-d3a7-4c18-bc11-e6e0eba9f0c7_expression_matrix.zip -51d67192-8834-4bd7-9359-cbee202655f2,51d67192-8834-4bd7-9359-cbee202655f2_expression_matrix.zip diff --git a/scripts/dataPromotion/indexToKE.py b/scripts/dataPromotion/indexToKE.py deleted file mode 100644 index fbf999b9..00000000 --- a/scripts/dataPromotion/indexToKE.py +++ /dev/null @@ -1,102 +0,0 @@ -import pymongo -import mysql.connector -from dotenv import load_dotenv -import os -import csv -import json - -load_dotenv() - -mysql_user = os.environ.get('mysql_user') -mysql_pwd = os.environ.get('mysql_pwd') - -mydb = mysql.connector.connect( - host="localhost", - user=mysql_user, - password=mysql_pwd, - database="knowledge_environment" -) -mycursor = mydb.cursor(buffered=True) -mycursor2 = mydb.cursor(buffered=True) - -mongo_client = pymongo.MongoClient("mongodb://localhost:27017/") -database = mongo_client["dataLake"] -ke_files = database["keFiles"] -packages = database["packages"] - -# Move file and participant data from ES index 
to KE tables (file, file_participant) -# -# for file in ke_files.find(): -# sql = "INSERT INTO file (file_id, file_name, package_id, access, file_size, protocol) VALUES (%s, %s, %s, %s, %s, %s)" -# val = (file["file_id"], file["file_name"],file["package_id"], file["access"], file["file_size"],file["protocol"]) -# print(sql % val) -# mycursor.execute(sql, val) -# mydb.commit() -# for participant_id in file["cases"]["samples"]["participant_id"]: -# sql2 = "INSERT INTO file_participant (file_id, participant_id) VALUES (%s, %s)" -# val2 = (file["file_id"], participant_id) -# mycursor.execute(sql2, val2) -# mydb.commit() - -# Get package ID from dataLake.packages and update KE file table -# -# query = ("SELECT file_id FROM file WHERE package_id = ''") -# mycursor.execute(query) -# for (file_id,) in mycursor: -# result = packages.find_one({ "files._id": file_id}) -# print(file_id + " " + result["_id"]) -# sql = "UPDATE file SET package_id = %s WHERE file_id = %s" -# val = (result["_id"], file_id) -# mycursor2.execute(sql, val) -# mydb.commit() - -# Update KE file table with metadata type from spreadsheet -# -# with open('./atlas_files.csv') as csv_file: -# csv_reader = csv.DictReader(csv_file) -# for row in csv_reader: -# query = "SELECT file_id FROM file WHERE package_id = %s AND LOCATE(%s, file_name) > 0" -# if row["metadata_type_id"] in ("7", "4", "5"): -# file_suffix = "expression_matrix.zip" -# else: -# file_suffix = row["file_name"] -# val = (row["package_id"], file_suffix) -# mycursor.execute(query, val) -# for (file_id,) in mycursor: -# query2 = "UPDATE file SET metadata_type_id = %s WHERE file_id = %s" -# val2 = (row["metadata_type_id"], file_id) -# print(query2 % val2) -# mycursor2.execute(query2, val2) -# mydb.commit() - -# Adds the metadata types from a file -# -# with open('./metadata_types.csv') as csv_file: -# csv_reader = csv.DictReader(csv_file) -# for row in csv_reader: -# query = ("INSERT INTO metadata_type (metadata_type_id, experimental_strategy, data_type, data_category, data_format, platform, workflow_type, access, kpmp_data_type)" -# " VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)") -# val = (row["metadata_type_id"], row["experimental_strategy"], row["data_type"], row["data_category"], row["data_format"], row["platform"], row["workflow_type"], row["access"], row["kpmp_data_type"]) -# print(query % val) -# mycursor.execute(query, val) -# mydb.commit() - -# Adds the clinical data from file -# non_clinical_rows = ["Participant ID", "Age (Years) (Binned)", "Sex", "Tissue Source", "Protocol", "Sample Type", "Tissue Type"] -# with open('./d801f97e-d032-11ea-a504-a4c3f0f6c2ae_20200721_OpenAccessMainProtocolClinicalData.csv') as csv_file: -# csv_reader = csv.DictReader(csv_file) -# for row in csv_reader: -# clinical_json = {} -# for key in row.keys(): -# if key not in non_clinical_rows: -# clinical_json[key] = row[key] -# query = ("INSERT INTO participant (participant_id, age_binned, sex, tissue_source, protocol, sample_type, tissue_type, clinical_data)" -# "VALUES (%s, %s, %s, %s, %s, %s, %s, %s)") -# val = (row["Participant ID"], row["Age (Years) (Binned)"], row["Sex"], row["Tissue Source"], row["Protocol"], row["Sample Type"], row["Tissue Type"], json.dumps(clinical_json)) -# print(query % val) -# mycursor.execute(query, val) -# mydb.commit() - -mycursor.close() -mydb.close() - diff --git a/scripts/dataPromotion/keDatabaseToAtlasIndex.py b/scripts/dataPromotion/keDatabaseToAtlasIndex.py deleted file mode 100644 index 24c5dd1a..00000000 --- 
a/scripts/dataPromotion/keDatabaseToAtlasIndex.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python2 -from collections import OrderedDict -import json -import copy -import csv -import sys -import mysql.connector -import pprint -from dotenv import load_dotenv -import os -from argparse import ArgumentParser - -load_dotenv() - -mysql_user = os.environ.get('mysql_user') -mysql_pwd = os.environ.get('mysql_pwd') - -mydb = mysql.connector.connect( - host="localhost", - user=mysql_user, - password=mysql_pwd, - database="knowledge_environment" -) -mycursor = mydb.cursor(buffered=True,dictionary=True) - -class IndexDoc: - def __init__(self, access, platform, experimental_strategy, data_category, workflow_type, data_format, data_type, file_id, file_name, file_size, protocol, package_id, cases): - self.access = access - self.platform = platform - self.experimental_strategy = experimental_strategy - self.data_category = data_category - self.workflow_type = workflow_type - self.file_id = file_id - self.file_name = file_name - self.data_format = data_format - self.file_size = file_size - self.data_type = data_type - self.protocol = protocol - self.package_id = package_id - self.cases = cases - -class CasesIndexDoc: - def __init__(self, tissue_source, samples, demographics): - self.tissue_source = tissue_source - self.samples = samples - self.demographics = demographics - -def get_index_update_json(id): - return '{"update":{"_index":"file_cases","_id":"' + id + '"}}' - -def get_index_doc_json(index_doc): - index_doc.cases = index_doc.cases.__dict__ - return '{"doc":' + json.dumps(index_doc.__dict__) + ',"doc_as_upsert":true}' - -input_file_id = "" -release_ver = "" -where_clause = "" - -parser = ArgumentParser(description="Generate ES index updates. No arguments will create updates for all records.") -parser.add_argument("-f", "--file_id", dest="file_id", - help="file ID") -parser.add_argument("-v", "--release_ver", - dest="release_ver", - help="target release version") - -args = parser.parse_args() - -if args.file_id: - where_clause = " WHERE f.file_id = '" + args.file_id + "' " -elif args.release_ver: - where_clause = " WHERE f.release_ver = " + args.release_ver + " " - -query = ("SELECT f.*, fp.*, p.*, m.* FROM file f " - "JOIN file_participant fp on f.file_id = fp.file_id " - "JOIN participant p on fp.participant_id = p.participant_id " - "JOIN metadata_type m on f.metadata_type_id = m.metadata_type_id" + where_clause) - -mycursor.execute(query) -row_num = 1 -last_file_id = -1 -for row in mycursor: - if row["file_id"] != last_file_id: - if row_num != 1: - print(get_index_update_json(index_doc.file_id) + "\n" + - get_index_doc_json(index_doc)) - cases_doc = CasesIndexDoc([row['tissue_source']], {"participant_id":[row['participant_id']], "tissue_type":[row['tissue_type']], "sample_type":[row['sample_type']]},{"sex":[row['sex']], "age":[row['age_binned']]}) - index_doc = IndexDoc(row["access"], row["platform"], row["experimental_strategy"], row["data_category"], row["workflow_type"], row["data_format"], row["data_type"], row["file_id"], row["file_name"], row["file_size"], row["protocol"], row["package_id"], cases_doc) - else: - index_doc.cases.tissue_source.append(row['tissue_source']) - index_doc.cases.samples["participant_id"].append(row['participant_id']) - index_doc.cases.samples["sample_type"].append(row['sample_type']) - index_doc.cases.samples["tissue_type"].append(row['tissue_type']) - index_doc.cases.demographics["age"].append(row['age_binned']) - 
index_doc.cases.demographics["sex"].append(row['sex']) - row_num += 1 - last_file_id = row["file_id"] - -print(get_index_update_json(index_doc.file_id) + "\n" + - get_index_doc_json(index_doc)) \ No newline at end of file diff --git a/scripts/dataPromotion/loadClinical/.env.example b/scripts/dataPromotion/loadClinical/.env.example new file mode 100644 index 00000000..ac94fd90 --- /dev/null +++ b/scripts/dataPromotion/loadClinical/.env.example @@ -0,0 +1,4 @@ +mysql_user= +mysql_pwd= +pilot_clinical_filename= +release_ver= \ No newline at end of file diff --git a/scripts/dataPromotion/loadClinical/addPilotClinicalDataFile.py b/scripts/dataPromotion/loadClinical/addPilotClinicalDataFile.py new file mode 100644 index 00000000..a60cd6ad --- /dev/null +++ b/scripts/dataPromotion/loadClinical/addPilotClinicalDataFile.py @@ -0,0 +1,55 @@ +import mysql.connector +import os +import uuid +import csv + +from dotenv import load_dotenv +load_dotenv() + +mysql_user = os.environ.get('mysql_user') +mysql_pwd = os.environ.get('mysql_pwd') +clinical_file_name = os.environ.get('pilot_clinical_filename') +relase_ver = os.environ.get('release_ver') + +mydb = mysql.connector.connect( + host="localhost", + user=mysql_user, + password=mysql_pwd, + database="knowledge_environment" +) +mycursor = mydb.cursor(buffered=True) + +file_uuid = str(uuid.uuid4()) +new_file_name= file_uuid + "_" + clinical_file_name +file_size = os.path.getsize('./' + clinical_file_name) + +query = "INSERT INTO file (dl_file_id, file_name, package_id, file_size, protocol, release_ver) " + \ + "VALUES (%s, %s, %s, %s, %s, %s)"; +values = (file_uuid, new_file_name, file_uuid, file_size, "KPMP Pilot 1 Protocol", relase_ver); + +mycursor.execute(query, values); +mydb.commit(); +mycursor.close(); + +file_id = mycursor.lastrowid + + +# This next part is dumb...I couldn't get the code to do this insert correctly, so instead I am +# just printing out the insert statements, and then I execute them manually +with open('./' + clinical_file_name) as csv_file: + csv_reader = csv.DictReader(csv_file) + for row in csv_reader: + mycursor = mydb.cursor(buffered=True, dictionary=True) + particpant_id = row['Participant ID']; + + query = "INSERT INTO file_participant (file_id, participant_id) " + \ + "VALUES (%s, (SELECT participant_id FROM participant WHERE redcap_id='%s'));"; + values = (file_id, particpant_id); + + print(query % values) + # mycursor.execute(query, values); + + mycursor.close(); + + +mydb.close(); diff --git a/scripts/dataPromotion/loadClinical/clinicalToKE.py b/scripts/dataPromotion/loadClinical/clinicalToKE.py new file mode 100644 index 00000000..87154594 --- /dev/null +++ b/scripts/dataPromotion/loadClinical/clinicalToKE.py @@ -0,0 +1,59 @@ +import mysql.connector +import os +import csv +import json + +from dotenv import load_dotenv +load_dotenv() + +mysql_user = os.environ.get('mysql_user') +mysql_pwd = os.environ.get('mysql_pwd') + +mydb = mysql.connector.connect( + host="localhost", + user=mysql_user, + password=mysql_pwd, + database="knowledge_environment" +) +mycursor = mydb.cursor(buffered=True) + +# Adds the clinical data from file +# Replace the filename as appropriate +non_clinical_rows = ["Participant ID", "secondary_id", "Tissue Source", "Protocol", "Sample Type", "Tissue Type", "Sex", "Age (Years) (Binned)"] +with open('./20201203_OpenAccessMainProtocolClinicalData.csv') as csv_file: + csv_reader = csv.DictReader(csv_file) + inserted = []; + updated = []; + for row in csv_reader: + clinical_json = {} + for key in row.keys(): 
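+            # Columns other than the core demographic/sample fields listed in
+            # non_clinical_rows are collected into the clinical_data JSON blob.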
+ if key not in non_clinical_rows: + clinical_json[key] = row[key] + # Need to see if this participant is already in the db + # if it is, replace with new values, otherwise add new row + + query = "SELECT redcap_id FROM participant WHERE redcap_id = '" + row["Participant ID"] + "'"; + mycursor.execute(query); + + if not mycursor.rowcount: + inserted.append(row["Participant ID"]); + query = ("INSERT INTO participant (old_participant_id, redcap_id, age_binned, sex, tissue_source, protocol, sample_type, tissue_type, clinical_data)" + "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)") + val = (row["secondary_id"], row["Participant ID"], row["Age (Years) (Binned)"], row["Sex"], row["Tissue Source"], row["Protocol"], row["Sample Type"], row["Tissue Type"], json.dumps(clinical_json)) + print(query % val) + mycursor.execute(query, val) + mydb.commit() + else: + updated.append(row["Participant ID"]); + query = "UPDATE participant SET old_participant_id = %s, redcap_id = %s, age_binned = %s, sex = %s, tissue_source = %s, protocol = %s, sample_type=%s, tissue_type=%s, clinical_data = %s WHERE old_participant_id = %s"; + values = (row["secondary_id"], row["Participant ID"], row["Age (Years) (Binned)"], row["Sex"], row["Tissue Source"], row["Protocol"], row["Sample Type"], row["Tissue Type"], json.dumps(clinical_json), row["Participant ID"]); + print(query % values) + mycursor.execute(query, values) + mydb.commit() + +print ("Inserted " + str(len(inserted)) + " records"); +print(" , ".join(inserted)) +print ("Updated " + str(len(updated)) + " records"); +print(" , ".join(updated)) +mycursor.close() +mydb.close() \ No newline at end of file diff --git a/scripts/dataPromotion/loadClinical/requirements.txt b/scripts/dataPromotion/loadClinical/requirements.txt new file mode 100644 index 00000000..31654a6b --- /dev/null +++ b/scripts/dataPromotion/loadClinical/requirements.txt @@ -0,0 +1,3 @@ +python-dotenv==0.15.0 +mysql_connector_repackaged==0.3.1 +mysql-connector-python diff --git a/scripts/dataPromotion/requirements.txt b/scripts/dataPromotion/requirements.txt index d046db36..1cf100dd 100644 --- a/scripts/dataPromotion/requirements.txt +++ b/scripts/dataPromotion/requirements.txt @@ -1,3 +1,3 @@ minio python-dotenv -mysql-connector-python \ No newline at end of file +mysql-connector-python diff --git a/scripts/dataPromotion/sunsetFiles.py b/scripts/dataPromotion/sunsetFiles.py new file mode 100644 index 00000000..a71b9535 --- /dev/null +++ b/scripts/dataPromotion/sunsetFiles.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python2 +from dotenv import load_dotenv +import os +from collections import OrderedDict +import mysql.connector +from argparse import ArgumentParser + +load_dotenv() + +destination_bucket = "kpmp-knowledge-environment-sunsetted" +source_bucket = "kpmp-knowledge-environment" + +mysql_user = os.environ.get('mysql_user') +mysql_pwd = os.environ.get('mysql_pwd') + +parser = ArgumentParser(description="Sunset files") +parser.add_argument("-v", "--release_sunset", + dest="release_sunset", + help="target release sunset version", + required=True) + +args = parser.parse_args() + +try: + mydb = mysql.connector.connect( + host="localhost", + user=mysql_user, + password=mysql_pwd, + database="knowledge_environment" + ) + mydb.get_warnings = True + cursor = mydb.cursor(buffered=True) + cursor2 = mydb.cursor(buffered=True) +except: + print("Can't connect to MySQL") + print("Make sure you have tunnel open to the KE database, e.g.") + print("ssh ubuntu@qa-atlas.kpmp.org -i ~/.ssh/um-kpmp.pem -L 3306:localhost:3306") 
+ os.sys.exit() + +query = ("SELECT file_id, package_id, file_name FROM file WHERE release_sunset = " + args.release_sunset) +cursor.execute(query) +update_count = 0 + +for (file_id, package_id, file_name) in cursor: + if file_name: + object_name = package_id + "/" + file_name + source_object = source_bucket + "/" + object_name + try: + command_string = "aws s3 mv s3://" + source_object + " s3://" + destination_bucket + "/" + object_name + print(command_string) + os.system(command_string) + update_count = update_count + 1 + except ResponseError as err: + print(err) + pass + else: + print("No file name in record.") + print("\n") + +print(str(update_count) + " files moved") diff --git a/src/main/java/org/kpmp/apiTokens/Token.java b/src/main/java/org/kpmp/apiTokens/Token.java new file mode 100644 index 00000000..6033f2c3 --- /dev/null +++ b/src/main/java/org/kpmp/apiTokens/Token.java @@ -0,0 +1,49 @@ +package org.kpmp.apiTokens; + +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import org.springframework.data.mongodb.core.mapping.Document; + +import java.util.Date; + + +@JsonPropertyOrder({ "tokenString", "email", "expiration", "active" }) +@Document(collection = "tokens") +public class Token { + + private String tokenString; + private String shibId; + private Date expiration; + private Boolean active; + + public String getTokenString() { + return tokenString; + } + + public void setTokenString(String tokenString) { + this.tokenString = tokenString; + } + + public String getShibId() { + return shibId; + } + + public void setShibId(String shibId) { + this.shibId = shibId; + } + + public Date getExpiration() { + return expiration; + } + + public void setExpiration(Date expiration) { + this.expiration = expiration; + } + + public Boolean getActive() { + return active; + } + + public void setActive(Boolean active) { + this.active = active; + } +} diff --git a/src/main/java/org/kpmp/apiTokens/TokenController.java b/src/main/java/org/kpmp/apiTokens/TokenController.java new file mode 100644 index 00000000..725b42a0 --- /dev/null +++ b/src/main/java/org/kpmp/apiTokens/TokenController.java @@ -0,0 +1,38 @@ +package org.kpmp.apiTokens; + +import javax.servlet.http.HttpServletRequest; + +import org.kpmp.shibboleth.ShibbolethUserService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Controller; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.ResponseBody; + +@Controller +public class TokenController { + + private ShibbolethUserService userService; + private TokenService tokenService; + + @Autowired + public TokenController(ShibbolethUserService userService, TokenService tokenService) { + this.userService = userService; + this.tokenService = tokenService; + } + + @RequestMapping(value = "/v1/token", method = RequestMethod.GET) + public @ResponseBody TokenResponse getToken(HttpServletRequest request) { + String shibId = userService.getUser(request).getShibId(); + Token token = tokenService.getOrSetToken(shibId); + TokenResponse tokenResponse = new TokenResponse(); + if (!tokenService.checkToken(token)) { + tokenResponse.setMessage("Your token is inactive or expired. 
Please contact KPMP DLU support."); + } else { + tokenResponse.setMessage("Success!"); + } + tokenResponse.setToken(token); + return tokenResponse; + } + +} diff --git a/src/main/java/org/kpmp/apiTokens/TokenRepository.java b/src/main/java/org/kpmp/apiTokens/TokenRepository.java new file mode 100644 index 00000000..32f8bbb9 --- /dev/null +++ b/src/main/java/org/kpmp/apiTokens/TokenRepository.java @@ -0,0 +1,16 @@ +package org.kpmp.apiTokens; + +import org.springframework.data.mongodb.repository.MongoRepository; +import org.springframework.stereotype.Component; + +@Component +public interface TokenRepository extends MongoRepository { + + @SuppressWarnings("unchecked") + public Token save(Token token); + + public Token findByShibId(String shibId); + + public Token findByTokenString(String tokenString); + +} diff --git a/src/main/java/org/kpmp/apiTokens/TokenResponse.java b/src/main/java/org/kpmp/apiTokens/TokenResponse.java new file mode 100644 index 00000000..fd3d7e1b --- /dev/null +++ b/src/main/java/org/kpmp/apiTokens/TokenResponse.java @@ -0,0 +1,22 @@ +package org.kpmp.apiTokens; + +public class TokenResponse { + private String message; + private Token token; + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Token getToken() { + return token; + } + + public void setToken(Token token) { + this.token = token; + } +} diff --git a/src/main/java/org/kpmp/apiTokens/TokenService.java b/src/main/java/org/kpmp/apiTokens/TokenService.java new file mode 100644 index 00000000..1d762de1 --- /dev/null +++ b/src/main/java/org/kpmp/apiTokens/TokenService.java @@ -0,0 +1,70 @@ +package org.kpmp.apiTokens; + +import java.util.Calendar; +import java.util.Date; + +import org.apache.commons.lang3.RandomStringUtils; +import org.kpmp.shibboleth.ShibbolethUserService; +import org.springframework.stereotype.Service; + +@Service +public class TokenService { + + private TokenRepository tokenRepository; + private ShibbolethUserService userService; + + public TokenService(TokenRepository tokenRepository, ShibbolethUserService userService) { + this.tokenRepository = tokenRepository; + this.userService = userService; + } + + public Token getOrSetToken(String shibId) { + Token resultToken = tokenRepository.findByShibId(shibId); + if (resultToken != null) { + return resultToken; + } else { + Token token = generateToken(shibId); + tokenRepository.save(token); + return token; + } + } + + public Token generateToken(String shibId) { + Token token = new Token(); + token.setShibId(shibId); + token.setActive(true); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, 1); + Date nextYear = cal.getTime(); + token.setExpiration(nextYear); + int length = 44; + boolean useLetters = true; + boolean useNumbers = true; + String tokenString = RandomStringUtils.random(length, useLetters, useNumbers); + token.setTokenString(tokenString); + return token; + } + + public Boolean checkExpired(Token token) { + Calendar cal = Calendar.getInstance(); + Date today = cal.getTime(); + return today.compareTo(token.getExpiration()) > 0; + } + + public Boolean checkToken(Token token) { + return !checkExpired(token) && token.getActive(); + } + + public Token getTokenByTokenString(String tokenString) { + return tokenRepository.findByTokenString(tokenString); + } + + public Boolean checkAndValidate(String tokenString) { + Token token = tokenRepository.findByTokenString(tokenString); + if (token != null) { + return checkToken(token); + } else { + 
return false; + } + } +} diff --git a/src/main/java/org/kpmp/filters/AuthorizationFilter.java b/src/main/java/org/kpmp/filters/AuthorizationFilter.java index 1da15ba8..82681154 100644 --- a/src/main/java/org/kpmp/filters/AuthorizationFilter.java +++ b/src/main/java/org/kpmp/filters/AuthorizationFilter.java @@ -32,16 +32,17 @@ @Component public class AuthorizationFilter implements Filter { + private static final String FILE_PART_INDEX = "qqpartindex"; private static final String USER_NOT_PART_OF_KPMP = "User is not part of KPMP: "; private static final String USER_NO_DLU_ACCESS = "User does not have access to DLU: "; private static final String GROUPS_KEY = "groups"; private static final String USER_DOES_NOT_EXIST = "User does not exist in User Portal: "; private static final String CLIENT_ID_PROPERTY = "CLIENT_ID"; - private static final String COOKIE_NAME = "shibid"; private static final int SECONDS_IN_MINUTE = 60; private static final int MINUTES_IN_HOUR = 60; private static final int SESSION_TIMEOUT_HOURS = 8; private static final int SESSION_TIMEOUT_SECONDS = SECONDS_IN_MINUTE * MINUTES_IN_HOUR * SESSION_TIMEOUT_HOURS; + private static final String FILE_PART_UPLOAD_URI_MATCHER = "/v1/packages/(.*)/files"; private LoggingService logger; private ShibbolethUserService shibUserService; @@ -84,7 +85,9 @@ public void doFilter(ServletRequest incomingRequest, ServletResponse incomingRes Cookie[] cookies = request.getCookies(); User user = shibUserService.getUser(request); String shibId = user.getShibId(); - if (hasExistingSession(shibId, cookies, request) || allowedEndpoints.contains(request.getRequestURI())) { + if (hasExistingSession(user, shibId, cookies, request) || allowedEndpoints.contains(request.getRequestURI()) + || !isFirstFilePartUpload(request)) { + chain.doFilter(request, response); } else { String clientId = env.getProperty(CLIENT_ID_PROPERTY); @@ -99,9 +102,8 @@ public void doFilter(ServletRequest incomingRequest, ServletResponse incomingRes if (isAllowed(userGroups) && userJson.getBoolean("active")) { HttpSession session = request.getSession(true); session.setMaxInactiveInterval(SESSION_TIMEOUT_SECONDS); - Cookie message = new Cookie(COOKIE_NAME, shibId); session.setAttribute("roles", userGroups); - response.addCookie(message); + session.setAttribute("shibid", shibId); chain.doFilter(request, response); } else if (isKPMP(userGroups)) { handleError(USER_NO_DLU_ACCESS + userGroups, HttpStatus.FORBIDDEN, request, response); @@ -128,6 +130,18 @@ public void doFilter(ServletRequest incomingRequest, ServletResponse incomingRes } + private boolean isFirstFilePartUpload(HttpServletRequest request) { + String filePartIndex = request.getParameter(FILE_PART_INDEX); + if (filePartIndex != null && request.getRequestURI().matches(FILE_PART_UPLOAD_URI_MATCHER) + && Integer.parseInt(filePartIndex) > 0) { + logger.logInfoMessage(this.getClass(), null, null, + this.getClass().getSimpleName() + ".isFirstFilePartUpload", + "file upload: not first part, skipping user auth check"); + return false; + } + return true; + } + private boolean isAllowed(JSONArray userGroups) throws JSONException { for (int i = 0; i < userGroups.length(); i++) { String group = userGroups.getString(i); @@ -154,23 +168,22 @@ private void handleError(String errorMessage, HttpStatus status, HttpServletRequ response.setStatus(status.value()); } - private boolean hasExistingSession(String shibId, Cookie[] cookies, HttpServletRequest request) { + private boolean hasExistingSession(User user, String shibId, Cookie[] cookies, 
HttpServletRequest request) { HttpSession existingSession = request.getSession(false); if (existingSession != null) { - for (Cookie cookie : cookies) { - if (cookie.getName().equals("shibId")) { - if (cookie.getValue().equals(shibId)) { - return true; - } else { - logger.logInfoMessage(this.getClass(), null, - "MSG: Invalidating session. Cookie does not match shibId for user", request); - existingSession.invalidate(); - return false; - } - } + logger.logInfoMessage(this.getClass(), user, null, request.getRequestURI(), + "checking for existing session"); + if (existingSession.getAttribute("shibid") != null + && existingSession.getAttribute("shibid").equals(user.getShibId())) { + logger.logWarnMessage(this.getClass(), user, null, request.getRequestURI(), + "skipping filter, active session"); + return true; + } else { + return false; } } return false; + } @Override diff --git a/src/main/java/org/kpmp/ingest/redcap/REDCapIngestController.java b/src/main/java/org/kpmp/ingest/redcap/REDCapIngestController.java index 2bf0765e..469348b9 100644 --- a/src/main/java/org/kpmp/ingest/redcap/REDCapIngestController.java +++ b/src/main/java/org/kpmp/ingest/redcap/REDCapIngestController.java @@ -1,15 +1,21 @@ package org.kpmp.ingest.redcap; +import static org.springframework.http.HttpStatus.UNAUTHORIZED; + import javax.servlet.http.HttpServletRequest; import org.json.JSONException; +import org.kpmp.apiTokens.Token; +import org.kpmp.apiTokens.TokenService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.ResponseBody; @Controller @@ -18,24 +24,37 @@ public class REDCapIngestController { private static Logger log = LoggerFactory.getLogger(REDCapIngestController.class); private static final String LOG_MESSAGE_FORMAT = "URI: {} | MSG: {} "; private REDCapIngestService service; + private TokenService tokenService; @Autowired - public REDCapIngestController(REDCapIngestService service) { + public REDCapIngestController(REDCapIngestService service, TokenService tokenService) { this.service = service; + this.tokenService = tokenService; } + @SuppressWarnings("rawtypes") @RequestMapping(value = "/v1/redcap", method = RequestMethod.POST) - public @ResponseBody boolean ingestREDCapData(@RequestBody String dataDump, HttpServletRequest request) - throws JSONException { + public @ResponseBody ResponseEntity ingestREDCapData(@RequestBody String dataDump, + @RequestParam("token") String tokenString, HttpServletRequest request) throws JSONException { log.info(LOG_MESSAGE_FORMAT, request.getRequestURI(), "Receiving new REDCap data dump"); - try { - service.saveDataDump(dataDump); - } catch (JSONException error) { - log.error(LOG_MESSAGE_FORMAT, request.getRequestURI(), error.getMessage()); - throw error; + ResponseEntity responseEntity; + if (tokenService.checkAndValidate(tokenString)) { + Token token = tokenService.getTokenByTokenString(tokenString); + try { + service.saveDataDump(dataDump); + log.info(LOG_MESSAGE_FORMAT, request.getRequestURI(), "Received new REDCap data dump from shibId " + + token.getShibId() + " using token " + token.getTokenString()); + 
responseEntity = ResponseEntity.ok().body("Successfully ingested REDCap data"); + } catch (JSONException error) { + log.error(LOG_MESSAGE_FORMAT, request.getRequestURI(), error.getMessage()); + throw error; + } + } else { + log.error(LOG_MESSAGE_FORMAT, request.getRequestURI(), "Invalid token provided: " + tokenString); + responseEntity = ResponseEntity.status(UNAUTHORIZED).body("Invalid token."); } - return true; + return responseEntity; } } diff --git a/src/main/java/org/kpmp/packages/PackageController.java b/src/main/java/org/kpmp/packages/PackageController.java index 048fd14a..f3e17cfc 100755 --- a/src/main/java/org/kpmp/packages/PackageController.java +++ b/src/main/java/org/kpmp/packages/PackageController.java @@ -1,5 +1,7 @@ package org.kpmp.packages; +import static org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR; + import java.io.IOException; import java.text.MessageFormat; import java.util.List; @@ -28,8 +30,6 @@ import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.multipart.MultipartFile; -import static org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR; - @Controller public class PackageController { @@ -90,8 +90,8 @@ public PackageController(PackageService packageService, LoggingService logger, if ("true".equals(largeFilesChecked)) { packageResponse.setGlobusURL(globusService.createDirectory(packageId)); } - packageService.sendStateChangeEvent(packageId, metadataReceivedState, largeFilesChecked, packageResponse.getGlobusURL(), - cleanHostName); + packageService.sendStateChangeEvent(packageId, metadataReceivedState, largeFilesChecked, + packageResponse.getGlobusURL(), cleanHostName); } catch (Exception e) { logger.logErrorMessage(this.getClass(), packageId, e.getMessage(), request); packageService.sendStateChangeEvent(packageId, uploadFailedState, null, e.getMessage(), cleanHostName); @@ -142,20 +142,19 @@ public PackageController(PackageService packageService, LoggingService logger, .body(resource); } - @RequestMapping(value = "/v1/packages/{packageId}/files/move", method = RequestMethod.POST) - public @ResponseBody - ResponseEntity movePackageFiles(@PathVariable String packageId, - HttpServletRequest request) { - ResponseEntity responseEntity; - try { - packageService.movePackageFiles(packageId); - responseEntity = ResponseEntity.ok().body("Moving files for package " + packageId); - } catch (IOException | InterruptedException e) { + @SuppressWarnings("rawtypes") + @RequestMapping(value = "/v1/packages/{packageId}/files/move", method = RequestMethod.POST) + public @ResponseBody ResponseEntity movePackageFiles(@PathVariable String packageId, HttpServletRequest request) { + ResponseEntity responseEntity; + try { + packageService.movePackageFiles(packageId); + responseEntity = ResponseEntity.ok().body("Moving files for package " + packageId); + } catch (IOException | InterruptedException e) { logger.logErrorMessage(this.getClass(), packageId, e.getMessage(), request); - responseEntity = ResponseEntity.status(INTERNAL_SERVER_ERROR).body("There was a problem moving the files."); - } + responseEntity = ResponseEntity.status(INTERNAL_SERVER_ERROR).body("There was a problem moving the files."); + } return responseEntity; - } + } @RequestMapping(value = "/v1/packages/{packageId}/files/finish", method = RequestMethod.POST) public @ResponseBody FileUploadResponse finishUpload(@PathVariable("packageId") String packageId, diff --git a/src/test/java/org/kpmp/apiTokens/TokenControllerTest.java 
b/src/test/java/org/kpmp/apiTokens/TokenControllerTest.java new file mode 100644 index 00000000..d1265e63 --- /dev/null +++ b/src/test/java/org/kpmp/apiTokens/TokenControllerTest.java @@ -0,0 +1,50 @@ +package org.kpmp.apiTokens; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.kpmp.shibboleth.ShibbolethUserService; +import org.kpmp.users.User; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import javax.servlet.http.HttpServletRequest; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TokenControllerTest { + + @Mock + private ShibbolethUserService userService; + @Mock + private TokenService tokenService; + private TokenController tokenController; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + tokenController = new TokenController(userService, tokenService); + } + + @After + public void tearDown() throws Exception { + tokenController = null; + } + + @Test + public void testGetToken() { + TokenResponse tokenResponse = new TokenResponse(); + HttpServletRequest request = mock(HttpServletRequest.class); + Token token = new Token(); + User user = new User(); + user.setShibId("shibId"); + when(userService.getUser(request)).thenReturn(user); + tokenResponse.setToken(token); + tokenResponse.setMessage("This is the message"); + when(tokenService.getOrSetToken("shibId")).thenReturn(token); + assertEquals(token, tokenController.getToken(request).getToken()); + } + +} diff --git a/src/test/java/org/kpmp/apiTokens/TokenResponseTest.java b/src/test/java/org/kpmp/apiTokens/TokenResponseTest.java new file mode 100644 index 00000000..36a6757d --- /dev/null +++ b/src/test/java/org/kpmp/apiTokens/TokenResponseTest.java @@ -0,0 +1,36 @@ +package org.kpmp.apiTokens; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class TokenResponseTest { + + private TokenResponse tokenResponse; + + @Before + public void setUp() throws Exception { + tokenResponse = new TokenResponse(); + } + + @After + public void tearDown() throws Exception { + tokenResponse = null; + } + + @Test + public void testSetMessage() { + tokenResponse.setMessage("This is a token"); + assertEquals("This is a token", tokenResponse.getMessage()); + } + + @Test + public void testSetToken() { + Token token = new Token(); + tokenResponse.setToken(token); + assertEquals(token, tokenResponse.getToken()); + } + +} diff --git a/src/test/java/org/kpmp/apiTokens/TokenServiceTest.java b/src/test/java/org/kpmp/apiTokens/TokenServiceTest.java new file mode 100644 index 00000000..f633e7dc --- /dev/null +++ b/src/test/java/org/kpmp/apiTokens/TokenServiceTest.java @@ -0,0 +1,121 @@ +package org.kpmp.apiTokens; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.kpmp.shibboleth.ShibbolethUserService; +import org.kpmp.users.User; +import org.mockito.Mock; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; +import org.mockito.MockitoAnnotations; + +import java.util.Calendar; +import java.util.Date; + +public class TokenServiceTest { + + private TokenService tokenService; + @Mock + private TokenRepository tokenRepository; + @Mock + private ShibbolethUserService userService; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + tokenService = new TokenService(tokenRepository, 
userService); + } + + @After + public void tearDown() throws Exception { + tokenService = null; + } + + @Test + public void testCheckAndValidateGoodTokenString() { + Token token = new Token(); + token.setTokenString("ABCD"); + token.setActive(true); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, 1); + Date nextYear = cal.getTime(); + token.setExpiration(nextYear); + when(tokenRepository.findByTokenString("ABCD")).thenReturn(token); + assertEquals(true, tokenService.checkAndValidate("ABCD")); + } + + @Test + public void testCheckAndValidateBadTokenString() { + Token token = new Token(); + token.setTokenString("ABCD"); + token.setActive(true); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, 1); + Date nextYear = cal.getTime(); + token.setExpiration(nextYear); + when(tokenRepository.findByTokenString("ABCD")).thenReturn(null); + assertEquals(false, tokenService.checkAndValidate("ABCD")); + } + + @Test + public void testCheckTokenExpired() { + Token token = new Token(); + token.setTokenString("ABCD"); + token.setActive(true); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, -1); + Date lastYear = cal.getTime(); + token.setExpiration(lastYear); + assertEquals(false, tokenService.checkToken(token)); + } + + @Test + public void testCheckTokenInactive() { + Token token = new Token(); + token.setTokenString("ABCD"); + token.setActive(false); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, 1); + Date nextYear = cal.getTime(); + token.setExpiration(nextYear); + assertEquals(false, tokenService.checkToken(token)); + } + + @Test + public void testGenerateToken() { + User user = new User(); + user.setShibId("shibId"); + Token token = tokenService.generateToken("shibId"); + assertEquals(true, token.getActive()); + assertEquals("shibId", token.getShibId()); + assertEquals(44, token.getTokenString().length()); + } + + @Test + public void testGetOrSetTokenExists() { + Token token = new Token(); + token.setTokenString("ABCD"); + token.setShibId("shibId"); + token.setActive(true); + Calendar cal = Calendar.getInstance(); + cal.add(Calendar.YEAR, 1); + when(tokenRepository.findByShibId("shibId")).thenReturn(token); + assertEquals(token, tokenService.getOrSetToken("shibId")); + } + + @Test + public void testGetOrSetTokenDoesntExist() { + when(tokenRepository.findByShibId("shibId")).thenReturn(null); + assertEquals("shibId", tokenService.getOrSetToken("shibId").getShibId()); + } + + @Test + public void testGetTokenByTokenString() { + Token token = new Token(); + when(tokenRepository.findByTokenString("ABCD")).thenReturn(token); + assertEquals(token, tokenService.getTokenByTokenString("ABCD")); + } + +} diff --git a/src/test/java/org/kpmp/apiTokens/TokenTest.java b/src/test/java/org/kpmp/apiTokens/TokenTest.java new file mode 100644 index 00000000..fdcb3ef3 --- /dev/null +++ b/src/test/java/org/kpmp/apiTokens/TokenTest.java @@ -0,0 +1,50 @@ +package org.kpmp.apiTokens; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Date; + +import static org.junit.Assert.assertEquals; + +public class TokenTest { + + private Token token; + + @Before + public void setUp() throws Exception { + token = new Token(); + } + + @After + public void tearDown() throws Exception { + token = null; + } + + @Test + public void testSetTokenString() { + token.setTokenString("Token string"); + assertEquals("Token string", token.getTokenString()); + } + + @Test + public void testSetShibId() { + token.setShibId("shibby"); + 
assertEquals("shibby", token.getShibId()); + } + + @Test + public void testSetExpiration() { + Date date = new Date(); + token.setExpiration(date); + assertEquals(date, token.getExpiration()); + } + + @Test + public void setActive() { + token.setActive(false); + assertEquals(false, token.getActive()); + } + +} diff --git a/src/test/java/org/kpmp/filters/AuthorizationFilterTest.java b/src/test/java/org/kpmp/filters/AuthorizationFilterTest.java index 99332036..842935f1 100644 --- a/src/test/java/org/kpmp/filters/AuthorizationFilterTest.java +++ b/src/test/java/org/kpmp/filters/AuthorizationFilterTest.java @@ -1,6 +1,5 @@ package org.kpmp.filters; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -12,7 +11,6 @@ import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; -import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; @@ -23,7 +21,6 @@ import org.kpmp.logging.LoggingService; import org.kpmp.shibboleth.ShibbolethUserService; import org.kpmp.users.User; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.core.env.Environment; @@ -72,8 +69,9 @@ public void testInit() throws ServletException { } @Test - public void testDoFilter_userHasValidCookie() throws Exception { // eslint-disable-line no-eval + public void testDoFilter_skippableURI() throws Exception { // eslint-disable-line no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("uri1"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); FilterChain chain = mock(FilterChain.class); User user = mock(User.class); @@ -81,72 +79,84 @@ public void testDoFilter_userHasValidCookie() throws Exception { // eslint-disab when(shibUserService.getUser(incomingRequest)).thenReturn(user); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(false)).thenReturn(session); - Cookie goodCookie = mock(Cookie.class); - when(goodCookie.getName()).thenReturn("shibId"); - when(goodCookie.getValue()).thenReturn("shibboleth id"); - Cookie badCookie = mock(Cookie.class); - when(badCookie.getName()).thenReturn("Darth Vader"); - when(incomingRequest.getCookies()).thenReturn(new Cookie[] { badCookie, goodCookie }); filter.doFilter(incomingRequest, incomingResponse, chain); verify(chain).doFilter(incomingRequest, incomingResponse); } + @SuppressWarnings("unchecked") @Test - public void testDoFilter_skippableURI() throws Exception { // eslint-disable-line no-eval + public void testDoFilter_nonChunkedFileUpload() throws Exception { HttpServletRequest incomingRequest = mock(HttpServletRequest.class); - when(incomingRequest.getRequestURI()).thenReturn("uri1"); + when(incomingRequest.getRequestURI()).thenReturn("/v1/packages/123-3435-kljlkj/files"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); FilterChain chain = mock(FilterChain.class); User user = mock(User.class); when(user.getShibId()).thenReturn("shibboleth id"); when(shibUserService.getUser(incomingRequest)).thenReturn(user); HttpSession session = mock(HttpSession.class); - when(incomingRequest.getSession(false)).thenReturn(session); - when(incomingRequest.getCookies()).thenReturn(new Cookie[] {}); + 
when(incomingRequest.getSession(true)).thenReturn(session); + ResponseEntity response = mock(ResponseEntity.class); + when(response.getBody()).thenReturn("{groups: [ 'group1', 'another group'], active: true}"); + when(restTemplate.getForEntity(any(String.class), any(Class.class))).thenReturn(response); filter.doFilter(incomingRequest, incomingResponse, chain); - verify(chain).doFilter(incomingRequest, incomingResponse); + verify(incomingRequest, times(1)).getSession(true); } @SuppressWarnings("unchecked") @Test - public void testDoFilter_userHasSomeoneElsesCookieAndGotEmptyResponse() throws Exception { // eslint-disable-line - // no-eval + public void testDoFilter_firstChunkFileUpload() throws Exception { HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("/v1/packages/123-3435-kljlkj/files"); + when(incomingRequest.getParameter("qqpartindex")).thenReturn("0"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); FilterChain chain = mock(FilterChain.class); User user = mock(User.class); when(user.getShibId()).thenReturn("shibboleth id"); when(shibUserService.getUser(incomingRequest)).thenReturn(user); HttpSession session = mock(HttpSession.class); - when(incomingRequest.getSession(false)).thenReturn(session); - Cookie shibCookie = mock(Cookie.class); - when(shibCookie.getName()).thenReturn("shibId"); - when(shibCookie.getValue()).thenReturn("not your cookie"); - when(incomingRequest.getCookies()).thenReturn(new Cookie[] { shibCookie }); + when(incomingRequest.getSession(true)).thenReturn(session); ResponseEntity response = mock(ResponseEntity.class); - when(response.getBody()).thenReturn("{}"); + when(response.getBody()).thenReturn("{groups: [ 'group1', 'another group'], active: true}"); when(restTemplate.getForEntity(any(String.class), any(Class.class))).thenReturn(response); filter.doFilter(incomingRequest, incomingResponse, chain); - verify(chain, times(0)).doFilter(incomingRequest, incomingResponse); - verify(logger).logInfoMessage(AuthorizationFilter.class, null, - "MSG: Invalidating session. Cookie does not match shibId for user", incomingRequest); - verify(session).invalidate(); - verify(incomingResponse).setStatus(HttpStatus.FAILED_DEPENDENCY.value()); - verify(logger).logErrorMessage(AuthorizationFilter.class, null, - "Unable to parse response from User Portal, denying user shibboleth id access. 
Response: {}", - incomingRequest); + verify(incomingRequest, times(1)).getSession(true); + } + + @SuppressWarnings("unchecked") + @Test + public void testDoFilter_notFirstChunkFileUpload() throws Exception { + HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("/v1/packages/123-3435-kljlkj/files"); + when(incomingRequest.getParameter("qqpartindex")).thenReturn("3"); + HttpServletResponse incomingResponse = mock(HttpServletResponse.class); + FilterChain chain = mock(FilterChain.class); + User user = mock(User.class); + when(user.getShibId()).thenReturn("shibboleth id"); + when(shibUserService.getUser(incomingRequest)).thenReturn(user); + HttpSession session = mock(HttpSession.class); + when(incomingRequest.getSession(true)).thenReturn(session); + ResponseEntity response = mock(ResponseEntity.class); + when(response.getBody()).thenReturn("{groups: [ 'group1', 'another group'], active: true}"); + when(restTemplate.getForEntity(any(String.class), any(Class.class))).thenReturn(response); + + filter.doFilter(incomingRequest, incomingResponse, chain); + + verify(incomingRequest, times(0)).getSession(true); + verify(logger).logInfoMessage(AuthorizationFilter.class, null, null, + "AuthorizationFilter.isFirstFilePartUpload", "file upload: not first part, skipping user auth check"); } @SuppressWarnings("unchecked") @Test public void testDoFilter_whenNoSessionAndEmptyResponse() throws Exception { // eslint-disable-line no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); FilterChain chain = mock(FilterChain.class); User user = mock(User.class); @@ -170,6 +180,7 @@ public void testDoFilter_whenNoSessionAndEmptyResponse() throws Exception { // e @Test public void testDoFilter_noSessionHasAllowedGroup() throws Exception { // eslint-disable-line no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(true)).thenReturn(session); @@ -186,10 +197,6 @@ public void testDoFilter_noSessionHasAllowedGroup() throws Exception { // eslint verify(chain).doFilter(incomingRequest, incomingResponse); verify(session).setMaxInactiveInterval(8 * 60 * 60); - ArgumentCaptor cookieJar = ArgumentCaptor.forClass(Cookie.class); - verify(incomingResponse).addCookie(cookieJar.capture()); - assertEquals(cookieJar.getValue().getName(), "shibid"); - assertEquals(cookieJar.getValue().getValue(), "shibboleth id"); } @SuppressWarnings("unchecked") @@ -197,6 +204,7 @@ public void testDoFilter_noSessionHasAllowedGroup() throws Exception { // eslint public void testDoFilter_noSessionDoesNotHaveAllowedGroupHasKpmpGroup() throws Exception { // eslint-disable-line // no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(true)).thenReturn(session); @@ -223,6 +231,7 @@ public void testDoFilter_noSessionDoesNotHaveAllowedGroupNoKpmpGroup() throws Ex // no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); HttpServletResponse 
incomingResponse = mock(HttpServletResponse.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(true)).thenReturn(session); FilterChain chain = mock(FilterChain.class); @@ -247,6 +256,7 @@ public void testDoFilter_noSessionDoesNotHaveAllowedGroupNoKpmpGroup() throws Ex public void testDoFilter_userAuthReturned404() throws Exception { // eslint-disable-line no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(true)).thenReturn(session); FilterChain chain = mock(FilterChain.class); @@ -269,6 +279,7 @@ public void testDoFilter_userAuthReturned404() throws Exception { // eslint-disa @Test public void testDoFilter_userAuthReturnedAnotherErrorCode() throws Exception { // eslint-disable-line no-eval HttpServletRequest incomingRequest = mock(HttpServletRequest.class); + when(incomingRequest.getRequestURI()).thenReturn("anything"); HttpServletResponse incomingResponse = mock(HttpServletResponse.class); HttpSession session = mock(HttpSession.class); when(incomingRequest.getSession(true)).thenReturn(session); diff --git a/src/test/java/org/kpmp/ingest/redcap/REDCapIngestControllerTest.java b/src/test/java/org/kpmp/ingest/redcap/REDCapIngestControllerTest.java index 4e1793a4..b32a31d4 100644 --- a/src/test/java/org/kpmp/ingest/redcap/REDCapIngestControllerTest.java +++ b/src/test/java/org/kpmp/ingest/redcap/REDCapIngestControllerTest.java @@ -4,7 +4,9 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import javax.servlet.http.HttpServletRequest; @@ -12,20 +14,25 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.kpmp.apiTokens.Token; +import org.kpmp.apiTokens.TokenService; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.http.ResponseEntity; public class REDCapIngestControllerTest { @Mock private REDCapIngestService service; private REDCapIngestController controller; + @Mock + private TokenService tokenService; @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); - controller = new REDCapIngestController(service); + controller = new REDCapIngestController(service, tokenService); } @After @@ -34,23 +41,40 @@ public void tearDown() throws Exception { } @Test - public void testIngestREDCapData() throws JSONException { + public void testIngestREDCapDataGoodToken() throws JSONException { HttpServletRequest request = mock(HttpServletRequest.class); - - controller.ingestREDCapData("json dump", request); - + String tokenString = "ABCD"; + Token token = new Token(); + token.setTokenString(tokenString); + token.setShibId("shibId"); + when(tokenService.checkAndValidate("ABCD")).thenReturn(true); + when(tokenService.getTokenByTokenString("ABCD")).thenReturn(token); + controller.ingestREDCapData("json dump", tokenString, request); ArgumentCaptor dumpCaptor = ArgumentCaptor.forClass(String.class); verify(service).saveDataDump(dumpCaptor.capture()); assertEquals("json dump", dumpCaptor.getValue()); } + @SuppressWarnings("rawtypes") + 
@Test + public void testIngestREDCapDataBadToken() throws JSONException { + HttpServletRequest request = mock(HttpServletRequest.class); + String token = "ABCD"; + when(tokenService.checkAndValidate("ABCD")).thenReturn(false); + ResponseEntity response = controller.ingestREDCapData("json dump", token, request); + ArgumentCaptor dumpCaptor = ArgumentCaptor.forClass(String.class); + verify(service, never()).saveDataDump(dumpCaptor.capture()); + assertEquals(org.springframework.http.HttpStatus.UNAUTHORIZED, response.getStatusCode()); + } + @Test public void testIngestREDCapData_throwsException() throws Exception { // eslint-disable-line no-eval HttpServletRequest request = mock(HttpServletRequest.class); + String token = "ABCD"; doThrow(new JSONException("oopsies")).when(service).saveDataDump(any(String.class)); try { - controller.ingestREDCapData("json dump", request); + controller.ingestREDCapData("json dump", token, request); } catch (JSONException expected) { assertEquals("oopsies", expected.getMessage()); } diff --git a/src/test/java/org/kpmp/packages/PackageControllerTest.java b/src/test/java/org/kpmp/packages/PackageControllerTest.java index 346fe78b..15b08f3a 100755 --- a/src/test/java/org/kpmp/packages/PackageControllerTest.java +++ b/src/test/java/org/kpmp/packages/PackageControllerTest.java @@ -137,7 +137,7 @@ public void testPostPackageInformation() throws Exception { verify(logger).logInfoMessage(PackageController.class, "universalId", "Posting package info: {\"packageType\":\"blah\"}", request); verify(packageService).sendStateChangeEvent("universalId", "UPLOAD_STARTED", null, "hostname"); - verify(packageService).sendStateChangeEvent("universalId", "METADATA_RECEIVED", "false",null, "hostname"); + verify(packageService).sendStateChangeEvent("universalId", "METADATA_RECEIVED", "false", null, "hostname"); } @Test @@ -167,7 +167,8 @@ public void testPostPackageInformationLargeFile() throws Exception { verify(logger).logInfoMessage(PackageController.class, "universalId", "Posting package info: {\"largeFilesChecked\":true,\"packageType\":\"blah\"}", request); verify(packageService).sendStateChangeEvent("universalId", "UPLOAD_STARTED", null, "hostname"); - verify(packageService).sendStateChangeEvent("universalId", "METADATA_RECEIVED", "true", "theWholeURL", "hostname"); + verify(packageService).sendStateChangeEvent("universalId", "METADATA_RECEIVED", "true", "theWholeURL", + "hostname"); } @Test @@ -246,8 +247,8 @@ public void testFinishUpload_whenCreateZipThrows() throws Exception { verify(logger).logErrorMessage(PackageController.class, "3545", "error getting metadata for package id: 3545", request); verify(packageService).sendStateChangeEvent("3545", "FILES_RECEIVED", null, "origin"); - verify(packageService).sendStateChangeEvent("3545", "UPLOAD_FAILED", - null, "error getting metadata for package id: 3545", "origin"); + verify(packageService).sendStateChangeEvent("3545", "UPLOAD_FAILED", null, + "error getting metadata for package id: 3545", "origin"); } @Test @@ -265,8 +266,8 @@ public void testFinishUpload_whenMismatchedFiles() throws Exception { verify(logger).logErrorMessage(PackageController.class, "3545", "Unable to zip package with package id: 3545", request); verify(packageService).sendStateChangeEvent("3545", "FILES_RECEIVED", null, "origin"); - verify(packageService).sendStateChangeEvent("3545", "UPLOAD_FAILED", - null, "Unable to zip package with package id: 3545", "origin"); + verify(packageService).sendStateChangeEvent("3545", "UPLOAD_FAILED", null, + "Unable to zip 
package with package id: 3545", "origin"); } @Test @@ -303,6 +304,7 @@ public void testDownloadPackage_serviceException() throws Exception { } } + @SuppressWarnings("rawtypes") @Test public void testMovePackageFiles() throws Exception { HttpServletRequest request = mock(HttpServletRequest.class); diff --git a/src/test/java/org/kpmp/packages/PackageNotificationInfoTest.java b/src/test/java/org/kpmp/packages/PackageNotificationInfoTest.java index 1d14cf62..3bd8285e 100644 --- a/src/test/java/org/kpmp/packages/PackageNotificationInfoTest.java +++ b/src/test/java/org/kpmp/packages/PackageNotificationInfoTest.java @@ -7,7 +7,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.kpmp.packages.PackageNotificationInfo; public class PackageNotificationInfoTest { diff --git a/src/test/java/org/kpmp/packages/StateTest.java b/src/test/java/org/kpmp/packages/StateTest.java index 9c348093..a84d839e 100644 --- a/src/test/java/org/kpmp/packages/StateTest.java +++ b/src/test/java/org/kpmp/packages/StateTest.java @@ -7,7 +7,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.kpmp.packages.State; public class StateTest { diff --git a/src/test/java/org/kpmp/packages/UniversalIdGeneratorTest.java b/src/test/java/org/kpmp/packages/UniversalIdGeneratorTest.java index 0277203d..3513b66e 100644 --- a/src/test/java/org/kpmp/packages/UniversalIdGeneratorTest.java +++ b/src/test/java/org/kpmp/packages/UniversalIdGeneratorTest.java @@ -8,7 +8,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.kpmp.packages.UniversalIdGenerator; public class UniversalIdGeneratorTest { @@ -28,6 +27,7 @@ public void tearDown() throws Exception { public void testGenerateUniversalId() throws Exception { String uuid = generator.generateUniversalId(); + System.err.println(uuid); assertNotNull(uuid); assertEquals(4, UUID.fromString(uuid).version()); }
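
The production `TokenService` class itself is not part of this diff; only its tests (`TokenServiceTest`) and the template record in `data/tokens-template.json` are shown. The following is a minimal sketch, inferred from the behavior those tests exercise, of what the validation and generation logic could look like. Only the type and method names (`Token`, `TokenRepository`, `checkAndValidate`, `checkToken`, `generateToken`, `getOrSetToken`) come from the diff; the bodies, the `save` call, and the one-year expiration window are assumptions, and the real constructor also takes a `ShibbolethUserService`, omitted here for brevity.

```java
// Sketch only: validation/generation behavior implied by TokenServiceTest.
// Bodies, persistence, and the expiration window are assumptions.
import java.security.SecureRandom;
import java.util.Base64;
import java.util.Calendar;
import java.util.Date;

public class TokenServiceSketch {

    private final TokenRepository tokenRepository;

    public TokenServiceSketch(TokenRepository tokenRepository) {
        this.tokenRepository = tokenRepository;
    }

    // True only when the token exists, is active, and has not expired.
    public boolean checkAndValidate(String tokenString) {
        Token token = tokenRepository.findByTokenString(tokenString);
        return token != null && checkToken(token);
    }

    public boolean checkToken(Token token) {
        return token.getActive() && token.getExpiration().after(new Date());
    }

    // Creates an active token tied to a shibId; the test expects a
    // 44-character token string (33 random bytes -> 44 base64 characters).
    public Token generateToken(String shibId) {
        byte[] randomBytes = new byte[33];
        new SecureRandom().nextBytes(randomBytes);
        Token token = new Token();
        token.setShibId(shibId);
        token.setTokenString(Base64.getUrlEncoder().withoutPadding().encodeToString(randomBytes));
        token.setActive(true);
        Calendar cal = Calendar.getInstance();
        cal.add(Calendar.YEAR, 1); // assumed one-year lifetime
        token.setExpiration(cal.getTime());
        return token;
    }

    // Reuses an existing token for the shibId, otherwise generates and stores a new one.
    public Token getOrSetToken(String shibId) {
        Token existing = tokenRepository.findByShibId(shibId);
        if (existing != null) {
            return existing;
        }
        Token token = generateToken(shibId);
        tokenRepository.save(token); // persistence assumed; not verified by the tests
        return token;
    }
}
```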
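
For reference, a hypothetical client call against the new token-protected `/v1/redcap` endpoint is sketched below. Only the path and the `token` request parameter come from the diff; the host, port, payload, and token value are placeholders, and the Java 11 `HttpClient` is used purely for brevity of the example.

```java
// Hypothetical client for POST /v1/redcap?token=... (host/port/token are placeholders).
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class REDCapIngestClientExample {

    public static void main(String[] args) throws Exception {
        String token = "REPLACE_WITH_ISSUED_TOKEN";
        String dataDump = "{\"records\": []}"; // REDCap JSON payload

        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/v1/redcap?token=" + token))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(dataDump))
                .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());

        // Per the controller change: 200 with a success message for a valid token,
        // 401 "Invalid token." otherwise.
        System.out.println(response.statusCode() + ": " + response.body());
    }
}
```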