diff --git a/python/ppc_common/db_models/computing_node.py b/python/ppc_common/db_models/computing_node.py
new file mode 100644
index 00000000..bc0b988d
--- /dev/null
+++ b/python/ppc_common/db_models/computing_node.py
@@ -0,0 +1,10 @@
+
+from ppc_common.db_models import db
+
+
+class ComputingNodeRecord(db.Model):
+    __tablename__ = 't_computing_node'
+    id = db.Column(db.String(255), primary_key=True)
+    url = db.Column(db.String(255))
+    type = db.Column(db.String(255))
+    loading = db.Column(db.Integer)
diff --git a/python/ppc_common/db_models/config.sql b/python/ppc_common/db_models/config.sql
new file mode 100644
index 00000000..295f5820
--- /dev/null
+++ b/python/ppc_common/db_models/config.sql
@@ -0,0 +1,26 @@
+
+CREATE TABLE t_job_worker (
+    worker_id VARCHAR(100) PRIMARY KEY,
+    job_id VARCHAR(255), INDEX idx_job_id (job_id),
+    type VARCHAR(255),
+    status VARCHAR(255),
+    upstreams TEXT,
+    inputs_statement TEXT,
+    outputs TEXT,
+    create_time BIGINT,
+    update_time BIGINT
+)ENGINE=InnoDB default charset=utf8mb4 default collate=utf8mb4_unicode_ci;
+
+CREATE TABLE t_computing_node (
+    id VARCHAR(255) PRIMARY KEY,
+    url VARCHAR(255),
+    type VARCHAR(255),
+    loading INT
+)ENGINE=InnoDB default charset=utf8mb4 default collate=utf8mb4_unicode_ci;
+
+
+INSERT INTO t_computing_node (id, url, type, loading)
+VALUES
+    ("001", '127.0.0.1:10200', 'PSI', 0),
+    ("002", '127.0.0.1:10201', 'MPC', 0),
+    ("003", '127.0.0.1:10202', 'MODEL', 0);
diff --git a/python/ppc_common/db_models/job_unit_record.py b/python/ppc_common/db_models/job_unit_record.py
deleted file mode 100644
index f953ecbb..00000000
--- a/python/ppc_common/db_models/job_unit_record.py
+++ /dev/null
@@ -1,14 +0,0 @@
-
-from ppc_common.db_models import db
-
-
-class JobUnitRecord(db.Model):
-    __tablename__ = 't_job_unit'
-    unit_id = db.Column(db.String(100), primary_key=True)
-    job_id = db.Column(db.String(255), index=True)
-    type = db.Column(db.String(255))
-    status = db.Column(db.String(255), index=True)
-    upstream_units = db.Column(db.Text)
-    inputs_statement = db.Column(db.Text)
-    outputs = db.Column(db.Text)
-    update_time = db.Column(db.BigInteger)
diff --git a/python/ppc_common/db_models/job_worker_record.py b/python/ppc_common/db_models/job_worker_record.py
new file mode 100644
index 00000000..29f94f61
--- /dev/null
+++ b/python/ppc_common/db_models/job_worker_record.py
@@ -0,0 +1,15 @@
+
+from ppc_common.db_models import db
+
+
+class JobWorkerRecord(db.Model):
+    __tablename__ = 't_job_worker'
+    worker_id = db.Column(db.String(100), primary_key=True)
+    job_id = db.Column(db.String(255), index=True)
+    type = db.Column(db.String(255))
+    status = db.Column(db.String(255))
+    upstreams = db.Column(db.Text)
+    inputs_statement = db.Column(db.Text)
+    outputs = db.Column(db.Text)
+    create_time = db.Column(db.BigInteger)
+    update_time = db.Column(db.BigInteger)
diff --git a/python/ppc_common/ppc_async_executor/async_subprocess_executor.py b/python/ppc_common/ppc_async_executor/async_subprocess_executor.py
index 87083ac8..afff9950 100644
--- a/python/ppc_common/ppc_async_executor/async_subprocess_executor.py
+++ b/python/ppc_common/ppc_async_executor/async_subprocess_executor.py
@@ -15,30 +15,30 @@ def __init__(self, logger):
         self._cleanup_thread.daemon = True
         self._cleanup_thread.start()
 
-    def execute(self, task_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None],
+    def execute(self, target_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None],
                 args=()):
         process = multiprocessing.Process(target=target,
args=args) process.start() with self.lock: - self.processes[task_id] = process + self.processes[target_id] = process - def kill(self, task_id: str): + def kill(self, target_id: str): with self.lock: - if task_id not in self.processes: + if target_id not in self.processes: return False else: - process = self.processes[task_id] + process = self.processes[target_id] process.terminate() - self.logger.info(f"Task {task_id} has been terminated!") + self.logger.info(f"Target {target_id} has been terminated!") return True def kill_all(self): with self.lock: keys = self.processes.keys() - for task_id in keys: - self.kill(task_id) + for target_id in keys: + self.kill(target_id) def _loop_cleanup(self): while True: @@ -48,13 +48,13 @@ def _loop_cleanup(self): def _cleanup_finished_processes(self): with self.lock: finished_processes = [ - (task_id, proc) for task_id, proc in self.processes.items() if not proc.is_alive()] + (target_id, proc) for target_id, proc in self.processes.items() if not proc.is_alive()] - for task_id, process in finished_processes: + for target_id, process in finished_processes: with self.lock: process.join() # 确保进程资源释放 - del self.processes[task_id] - self.logger.info(f"Cleanup finished task process {task_id}") + del self.processes[target_id] + self.logger.info(f"Cleanup finished process {target_id}") def __del__(self): self.kill_all() diff --git a/python/ppc_common/ppc_async_executor/async_thread_executor.py b/python/ppc_common/ppc_async_executor/async_thread_executor.py index 0bc65b91..15b1f5ca 100644 --- a/python/ppc_common/ppc_async_executor/async_thread_executor.py +++ b/python/ppc_common/ppc_async_executor/async_thread_executor.py @@ -17,44 +17,44 @@ def __init__(self, event_manager: ThreadEventManager, logger): self._cleanup_thread.daemon = True self._cleanup_thread.start() - def execute(self, task_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None], + def execute(self, target_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None], args=()): def thread_target(logger, on_finish, *args): try: target(*args) - on_finish(task_id, True) + on_finish(target_id, True) except Exception as e: logger.warn(traceback.format_exc()) - on_finish(task_id, False, e) + on_finish(target_id, False, e) thread = threading.Thread(target=thread_target, args=( self.logger, on_target_finish,) + args) thread.start() with self.lock: - self.threads[task_id] = thread + self.threads[target_id] = thread stop_event = threading.Event() - self.event_manager.add_event(task_id, stop_event) + self.event_manager.add_event(target_id, stop_event) - def kill(self, task_id: str): + def kill(self, target_id: str): with self.lock: - if task_id not in self.threads: + if target_id not in self.threads: return False else: - thread = self.threads[task_id] + thread = self.threads[target_id] - self.event_manager.set_event(task_id) + self.event_manager.set_event(target_id) thread.join() - self.logger.info(f"Task {task_id} has been stopped!") + self.logger.info(f"Target {target_id} has been stopped!") return True def kill_all(self): with self.lock: keys = self.threads.keys() - for task_id in keys: - self.kill(task_id) + for target_id in keys: + self.kill(target_id) def _loop_cleanup(self): while True: @@ -64,12 +64,12 @@ def _loop_cleanup(self): def _cleanup_finished_threads(self): with self.lock: finished_threads = [ - task_id for task_id, thread in self.threads.items() if not thread.is_alive()] + target_id for target_id, thread in self.threads.items() if not 
thread.is_alive()] - for task_id in finished_threads: + for target_id in finished_threads: with self.lock: - del self.threads[task_id] - self.logger.info(f"Cleanup finished task thread {task_id}") + del self.threads[target_id] + self.logger.info(f"Cleanup finished thread {target_id}") def __del__(self): self.kill_all() diff --git a/python/ppc_common/ppc_async_executor/thread_event_manager.py b/python/ppc_common/ppc_async_executor/thread_event_manager.py index f4311b69..b78a7883 100644 --- a/python/ppc_common/ppc_async_executor/thread_event_manager.py +++ b/python/ppc_common/ppc_async_executor/thread_event_manager.py @@ -10,25 +10,25 @@ def __init__(self): self.events: Dict[str, threading.Event] = {} self.rw_lock = rwlock.RWLockWrite() - def add_event(self, task_id: str, event: threading.Event) -> None: + def add_event(self, target_id: str, event: threading.Event) -> None: with self.rw_lock.gen_wlock(): - self.events[task_id] = event + self.events[target_id] = event - def remove_event(self, task_id: str): + def remove_event(self, target_id: str): with self.rw_lock.gen_wlock(): - if task_id in self.events: - del self.events[task_id] + if target_id in self.events: + del self.events[target_id] - def set_event(self, task_id: str): + def set_event(self, target_id: str): with self.rw_lock.gen_wlock(): - if task_id in self.events: - self.events[task_id].set() + if target_id in self.events: + self.events[target_id].set() else: - raise KeyError(f"Task ID {task_id} not found") + raise KeyError(f"Target id {target_id} not found") - def event_status(self, task_id: str) -> bool: + def event_status(self, target_id: str) -> bool: with self.rw_lock.gen_rlock(): - if task_id in self.events: - return self.events[task_id].is_set() + if target_id in self.events: + return self.events[target_id].is_set() else: return False diff --git a/python/ppc_common/ppc_initialize/dataset_handler_initialize.py b/python/ppc_common/ppc_initialize/dataset_handler_initialize.py index 4b2e64fe..3abe855e 100644 --- a/python/ppc_common/ppc_initialize/dataset_handler_initialize.py +++ b/python/ppc_common/ppc_initialize/dataset_handler_initialize.py @@ -10,7 +10,7 @@ class DataSetHandlerInitialize: def __init__(self, config, logger): self._config = config self._logger = logger - self._init_sql_storage() + # self._init_sql_storage() self._init_remote_storage() self._init_dataset_factory() diff --git a/python/ppc_common/ppc_protos/generated/ppc_pb2.py b/python/ppc_common/ppc_protos/generated/ppc_pb2.py index 5c8200ef..623c0434 100644 --- a/python/ppc_common/ppc_protos/generated/ppc_pb2.py +++ b/python/ppc_common/ppc_protos/generated/ppc_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
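
The executor changes above only rename the identifier from task_id to target_id; the control flow is unchanged: execute() registers a stop event with ThreadEventManager, the target polls event_status(), and kill() sets the event and joins the thread. A minimal usage sketch, assuming the thread-based executor class is named AsyncThreadExecutor (the class name itself is not visible in these hunks):

```python
import logging
import time

from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from ppc_common.ppc_async_executor.async_thread_executor import AsyncThreadExecutor  # class name assumed

logger = logging.getLogger(__name__)
event_manager = ThreadEventManager()
executor = AsyncThreadExecutor(event_manager, logger)


def long_running_target(target_id):
    # cooperative cancellation: poll the per-target stop event that kill() sets
    while not event_manager.event_status(target_id):
        time.sleep(1)


def on_finish(target_id, is_succeeded, error=None):
    # invoked with two arguments on success and three on failure (see thread_target in the diff)
    logger.info(f"target {target_id} finished, succeeded={is_succeeded}, error={error}")


executor.execute("j-123-psi", long_running_target, on_finish, args=("j-123-psi",))
executor.kill("j-123-psi")  # sets the stop event and joins the worker thread
```
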
# source: ppc.proto -# Protobuf Python Version: 4.25.1 +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -12,53 +12,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tppc.proto\x12\tppc.proto\"\xe5\x01\n\x0f\x41lgorithmDetail\x12\x1d\n\x15\x61lgorithm_description\x18\x01 \x01(\t\x12\x16\n\x0e\x61lgorithm_type\x18\x02 \x01(\t\x12\x13\n\x0b\x63reate_time\x18\x03 \x01(\x03\x12\x13\n\x0bupdate_time\x18\x04 \x01(\x03\x12\x17\n\x0fowner_user_name\x18\x05 \x01(\t\x12\x17\n\x0fowner_agency_id\x18\x06 \x01(\t\x12\x19\n\x11owner_agency_name\x18\x07 \x01(\t\x12$\n\x08ppc_flow\x18\x08 \x01(\x0b\x32\x12.ppc.proto.PpcFlow\"\xc1\x01\n\x07PpcFlow\x12\x13\n\x0binput_count\x18\x01 \x01(\t\x12\x12\n\nsql_module\x18\x02 \x01(\t\x12\x12\n\nmpc_module\x18\x03 \x01(\t\x12\x18\n\x10mpc_model_module\x18\x04 \x01(\t\x12\x14\n\x0cmatch_module\x18\x05 \x01(\t\x12\x19\n\x11\x61lgorithm_subtype\x18\x06 \x01(\t\x12\x1a\n\x12participant_agency\x18\x07 \x01(\t\x12\x12\n\nmodel_type\x18\x08 \x01(\t\"\xee\x01\n\rDatasetDetail\x12\x15\n\rdataset_title\x18\x01 \x01(\t\x12\x1b\n\x13\x64\x61taset_description\x18\x02 \x01(\t\x12\x14\n\x0c\x64\x61taset_hash\x18\x03 \x01(\t\x12\x14\n\x0c\x63olumn_count\x18\x04 \x01(\x03\x12\x11\n\trow_count\x18\x05 \x01(\x03\x12\x13\n\x0b\x63reate_time\x18\x06 \x01(\x03\x12\x11\n\tuser_name\x18\x07 \x01(\t\x12\x17\n\x0f\x64\x61ta_field_list\x18\x08 \x01(\t\x12\x14\n\x0c\x64\x61taset_size\x18\t \x01(\x03\x12\x13\n\x0bupdate_time\x18\n \x01(\x03\"K\n\x0f\x44\x61tasetAuthData\x12\x0f\n\x07is_auth\x18\x01 \x01(\x08\x12\x13\n\x0b\x63reate_time\x18\x02 \x01(\x03\x12\x12\n\nvalid_time\x18\x03 \x01(\x03\"\xde\x01\n\x12jobDatasetProvider\x12\x11\n\tagency_id\x18\x01 \x01(\t\x12\x13\n\x0b\x61gency_name\x18\x02 \x01(\t\x12\x12\n\ndataset_id\x18\x03 \x01(\t\x12\x15\n\rdataset_title\x18\x04 \x01(\t\x12\x15\n\rloading_index\x18\x05 \x01(\x03\x12\x1b\n\x13\x64\x61taset_description\x18\x06 \x01(\t\x12\x17\n\x0fowner_user_name\x18\x07 \x01(\t\x12\x14\n\x0c\x63olumns_size\x18\x08 \x01(\x03\x12\x12\n\nis_instant\x18\t \x01(\x03\"\x94\x01\n\x15\x64\x61tasourceInstantInfo\x12\r\n\x05\x64\x62_ip\x18\x01 \x01(\t\x12\x0f\n\x07\x64\x62_name\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x62_password\x18\x03 \x01(\t\x12\x0f\n\x07\x64\x62_port\x18\x04 \x01(\x03\x12\x0e\n\x06\x64\x62_sql\x18\x05 \x01(\t\x12\x14\n\x0c\x64\x62_user_name\x18\x06 \x01(\t\x12\x0f\n\x07\x64\x62_type\x18\x07 \x01(\t\"I\n\x16jobDatasetProviderList\x12/\n\x08provider\x18\x01 \x03(\x0b\x32\x1d.ppc.proto.jobDatasetProvider\"m\n\x16jobComputationProvider\x12\x11\n\tagency_id\x18\x01 \x01(\t\x12\x13\n\x0b\x61gency_name\x18\x02 \x01(\t\x12\x12\n\nagency_url\x18\x03 \x01(\t\x12\x17\n\x0f\x63omputing_index\x18\x04 \x01(\x03\"Q\n\x1ajobComputationProviderList\x12\x33\n\x08provider\x18\x01 \x03(\x0b\x32!.ppc.proto.jobComputationProvider\")\n\x15jobResultReceiverList\x12\x10\n\x08receiver\x18\x01 \x03(\t\"\x8f\x05\n\x08JobEvent\x12\x0e\n\x06job_id\x18\x01 \x01(\t\x12\x11\n\tjob_title\x18\x02 \x01(\t\x12\x17\n\x0fjob_description\x18\x03 \x01(\t\x12\x14\n\x0cjob_priority\x18\x04 \x01(\x03\x12\x13\n\x0bjob_creator\x18\x05 \x01(\t\x12\x1b\n\x13initiator_agency_id\x18\x06 \x01(\t\x12\x1d\n\x15initiator_agency_name\x18\x07 \x01(\t\x12\x1b\n\x13initiator_signature\x18\x08 \x01(\t\x12\x18\n\x10job_algorithm_id\x18\t \x01(\t\x12\x1b\n\x13job_algorithm_title\x18\n 
\x01(\t\x12\x1d\n\x15job_algorithm_version\x18\x0b \x01(\t\x12\x1a\n\x12job_algorithm_type\x18\x0c \x01(\t\x12$\n\x1cjob_dataset_provider_list_pb\x18\r \x01(\t\x12#\n\x1bjob_result_receiver_list_pb\x18\x0e \x01(\t\x12\x13\n\x0b\x63reate_time\x18\x0f \x01(\x03\x12\x13\n\x0bupdate_time\x18\x10 \x01(\x03\x12\x12\n\npsi_fields\x18\x11 \x01(\t\x12\x1d\n\x15job_algorithm_subtype\x18\x12 \x01(\t\x12\x14\n\x0cmatch_fields\x18\x13 \x01(\t\x12\x16\n\x0eis_cem_encrypt\x18\x14 \x01(\x08\x12\x17\n\x0f\x64\x61taset_id_list\x18\x15 \x03(\t\x12\x30\n\x0e\x64\x61taset_detail\x18\x16 \x01(\x0b\x32\x18.ppc.proto.DatasetDetail\x12\x1e\n\x16tag_provider_agency_id\x18\x17 \x01(\t\x12\x10\n\x08job_type\x18\x18 \x01(\t\"/\n\tAuditItem\x12\x13\n\x0b\x64\x65scription\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"5\n\tAuditData\x12(\n\naudit_item\x18\x01 \x03(\x0b\x32\x14.ppc.proto.AuditItem\"E\n\x10JobOutputPreview\x12\x0e\n\x06header\x18\x01 \x03(\t\x12!\n\x04line\x18\x02 \x03(\x0b\x32\x13.ppc.proto.DataLine\"\x19\n\x08\x44\x61taLine\x12\r\n\x05value\x18\x01 \x03(\t\"=\n\x0eInputStatement\x12\x15\n\rupstream_unit\x18\x01 \x01(\t\x12\x14\n\x0coutput_index\x18\x02 \x01(\x03\"M\n\x16JobUnitInputsStatement\x12\x33\n\x10inputs_statement\x18\x01 \x03(\x0b\x32\x19.ppc.proto.InputStatement\"!\n\x0eJobUnitOutputs\x12\x0f\n\x07outputs\x18\x01 \x03(\t\")\n\x0fJobUnitUpstream\x12\x16\n\x0eupstream_units\x18\x01 \x03(\t\"<\n\tAlgorithm\x12\x14\n\x0c\x61lgorithm_id\x18\x01 \x01(\t\x12\x19\n\x11\x61lgorithm_version\x18\x02 \x01(\t\"6\n\nAlgorithms\x12(\n\nalgorithms\x18\x01 \x03(\x0b\x32\x14.ppc.proto.Algorithmb\x06proto3') + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tppc.proto\x12\tppc.proto\"8\n\x0eInputStatement\x12\x10\n\x08upstream\x18\x01 \x01(\t\x12\x14\n\x0coutput_index\x18\x02 \x01(\x03\"O\n\x18JobWorkerInputsStatement\x12\x33\n\x10inputs_statement\x18\x01 \x03(\x0b\x32\x19.ppc.proto.InputStatement\"\'\n\x12JobWorkerUpstreams\x12\x11\n\tupstreams\x18\x01 \x03(\t\"#\n\x10JobWorkerOutputs\x12\x0f\n\x07outputs\x18\x01 \x03(\tb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ppc_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_ALGORITHMDETAIL']._serialized_start = 25 - _globals['_ALGORITHMDETAIL']._serialized_end = 254 - _globals['_PPCFLOW']._serialized_start = 257 - _globals['_PPCFLOW']._serialized_end = 450 - _globals['_DATASETDETAIL']._serialized_start = 453 - _globals['_DATASETDETAIL']._serialized_end = 691 - _globals['_DATASETAUTHDATA']._serialized_start = 693 - _globals['_DATASETAUTHDATA']._serialized_end = 768 - _globals['_JOBDATASETPROVIDER']._serialized_start = 771 - _globals['_JOBDATASETPROVIDER']._serialized_end = 993 - _globals['_DATASOURCEINSTANTINFO']._serialized_start = 996 - _globals['_DATASOURCEINSTANTINFO']._serialized_end = 1144 - _globals['_JOBDATASETPROVIDERLIST']._serialized_start = 1146 - _globals['_JOBDATASETPROVIDERLIST']._serialized_end = 1219 - _globals['_JOBCOMPUTATIONPROVIDER']._serialized_start = 1221 - _globals['_JOBCOMPUTATIONPROVIDER']._serialized_end = 1330 - _globals['_JOBCOMPUTATIONPROVIDERLIST']._serialized_start = 1332 - _globals['_JOBCOMPUTATIONPROVIDERLIST']._serialized_end = 1413 - _globals['_JOBRESULTRECEIVERLIST']._serialized_start = 1415 - _globals['_JOBRESULTRECEIVERLIST']._serialized_end = 1456 - _globals['_JOBEVENT']._serialized_start = 1459 - _globals['_JOBEVENT']._serialized_end = 2114 - 
_globals['_AUDITITEM']._serialized_start = 2116 - _globals['_AUDITITEM']._serialized_end = 2163 - _globals['_AUDITDATA']._serialized_start = 2165 - _globals['_AUDITDATA']._serialized_end = 2218 - _globals['_JOBOUTPUTPREVIEW']._serialized_start = 2220 - _globals['_JOBOUTPUTPREVIEW']._serialized_end = 2289 - _globals['_DATALINE']._serialized_start = 2291 - _globals['_DATALINE']._serialized_end = 2316 - _globals['_INPUTSTATEMENT']._serialized_start = 2318 - _globals['_INPUTSTATEMENT']._serialized_end = 2379 - _globals['_JOBUNITINPUTSSTATEMENT']._serialized_start = 2381 - _globals['_JOBUNITINPUTSSTATEMENT']._serialized_end = 2458 - _globals['_JOBUNITOUTPUTS']._serialized_start = 2460 - _globals['_JOBUNITOUTPUTS']._serialized_end = 2493 - _globals['_JOBUNITUPSTREAM']._serialized_start = 2495 - _globals['_JOBUNITUPSTREAM']._serialized_end = 2536 - _globals['_ALGORITHM']._serialized_start = 2538 - _globals['_ALGORITHM']._serialized_end = 2598 - _globals['_ALGORITHMS']._serialized_start = 2600 - _globals['_ALGORITHMS']._serialized_end = 2654 + DESCRIPTOR._options = None + _globals['_INPUTSTATEMENT']._serialized_start=24 + _globals['_INPUTSTATEMENT']._serialized_end=80 + _globals['_JOBWORKERINPUTSSTATEMENT']._serialized_start=82 + _globals['_JOBWORKERINPUTSSTATEMENT']._serialized_end=161 + _globals['_JOBWORKERUPSTREAMS']._serialized_start=163 + _globals['_JOBWORKERUPSTREAMS']._serialized_end=202 + _globals['_JOBWORKEROUTPUTS']._serialized_start=204 + _globals['_JOBWORKEROUTPUTS']._serialized_end=239 # @@protoc_insertion_point(module_scope) diff --git a/python/ppc_common/ppc_protos/ppc.proto b/python/ppc_common/ppc_protos/ppc.proto index a9571c0f..ca9a3a5b 100644 --- a/python/ppc_common/ppc_protos/ppc.proto +++ b/python/ppc_common/ppc_protos/ppc.proto @@ -2,159 +2,20 @@ syntax = "proto3"; package ppc.proto; -message AlgorithmDetail { - string algorithm_description = 1; - string algorithm_type = 2; - int64 create_time = 3; - int64 update_time = 4; - string owner_user_name = 5; - string owner_agency_id = 6; - string owner_agency_name = 7; - PpcFlow ppc_flow = 8; -} - -message PpcFlow { - // input_count need present '3+' - string input_count = 1; - string sql_module = 2; - string mpc_module = 3; - string mpc_model_module = 4; - string match_module = 5; - string algorithm_subtype = 6; - string participant_agency = 7; - string model_type = 8; -} - -message DatasetDetail { - string dataset_title = 1; - string dataset_description = 2; - string dataset_hash = 3; - int64 column_count = 4; - int64 row_count = 5; - int64 create_time = 6; - string user_name = 7; - string data_field_list = 8; - int64 dataset_size = 9; - int64 update_time = 10; -} - -message DatasetAuthData { - bool is_auth = 1; - int64 create_time = 2; - int64 valid_time = 3; -} - - -message jobDatasetProvider { - string agency_id = 1; - string agency_name = 2; - string dataset_id = 3; - string dataset_title = 4; - int64 loading_index = 5; - string dataset_description = 6; - string owner_user_name = 7; - int64 columns_size = 8; - int64 is_instant = 9; -} - -message datasourceInstantInfo { - string db_ip = 1; - string db_name = 2; - string db_password = 3; - int64 db_port = 4; - string db_sql = 5; - string db_user_name = 6; - string db_type = 7; -} - -message jobDatasetProviderList { - repeated jobDatasetProvider provider = 1; -} - -message jobComputationProvider { - string agency_id = 1; - string agency_name = 2; - string agency_url = 3; - int64 computing_index = 4; -} - -message jobComputationProviderList { - repeated jobComputationProvider 
provider = 1; -} - -message jobResultReceiverList { - repeated string receiver = 1; -} - - -message JobEvent { - string job_id = 1; - string job_title = 2; - string job_description = 3; - int64 job_priority = 4; - string job_creator = 5; - string initiator_agency_id = 6; - string initiator_agency_name = 7; - string initiator_signature = 8; - string job_algorithm_id = 9; - string job_algorithm_title = 10; - string job_algorithm_version = 11; - string job_algorithm_type = 12; - string job_dataset_provider_list_pb = 13; // jobDatasetProviderList - string job_result_receiver_list_pb = 14; // jobResultReceiverList - int64 create_time = 15; - int64 update_time = 16; - string psi_fields = 17; - string job_algorithm_subtype = 18; - string match_fields = 19; - bool is_cem_encrypt = 20; - repeated string dataset_id_list = 21; // cem encrypt dataset id without username prefix - DatasetDetail dataset_detail = 22; // cem encrypt dataset transfer - string tag_provider_agency_id = 23; - string job_type = 24; -} - -message AuditItem { - string description = 1; - string value = 2; -} - -message AuditData { - repeated AuditItem audit_item = 1; -} - -message JobOutputPreview { - repeated string header = 1; - repeated DataLine line = 2; -} - -message DataLine { - repeated string value = 1; -} message InputStatement { - string upstream_unit = 1; + string upstream = 1; int64 output_index = 2; } -message JobUnitInputsStatement { +message JobWorkerInputsStatement { repeated InputStatement inputs_statement = 1; } -message JobUnitOutputs { - repeated string outputs = 1; -} - -message JobUnitUpstream { - repeated string upstream_units = 1; +message JobWorkerUpstreams { + repeated string upstreams = 1; } -message Algorithm { - string algorithm_id = 1; - string algorithm_version = 2; -} - -message Algorithms { - repeated Algorithm algorithms = 1; +message JobWorkerOutputs { + repeated string outputs = 1; } - diff --git a/python/ppc_common/ppc_utils/anonymous_search.py b/python/ppc_common/ppc_utils/anonymous_search.py deleted file mode 100644 index 80c3df61..00000000 --- a/python/ppc_common/ppc_utils/anonymous_search.py +++ /dev/null @@ -1,420 +0,0 @@ -import hashlib -import json -import logging -import os -import random -import string -import unittest -import uuid -from io import BytesIO - -import pandas as pd - -from ppc_common.ppc_utils import utils, http_utils -from ppc_common.ppc_crypto import crypto_utils -from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode - -log = logging.getLogger(__name__) - - -def make_hash(data): - m = hashlib.sha3_256() - m.update(data) - return m.hexdigest() - - -def requester_gen_ot_cipher(id_list, obfuscation_order): - blinding_a = crypto_utils.get_random_int() - blinding_b = crypto_utils.get_random_int() - x_value = crypto_utils.ot_base_pown(blinding_a) - y_value = crypto_utils.ot_base_pown(blinding_b) - c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) - id_index_list = [] - send_hash_vec = [] - z_value_list = [] - for id_hash in id_list: - obs_list = id_obfuscation(obfuscation_order, None) - id_index = random.randint(0, obfuscation_order) - obs_list[id_index] = id_hash - id_index_list.append(id_index) - send_hash_vec.append(obs_list) - z_value = crypto_utils.ot_base_pown(c_blinding - id_index) - z_value_list.append(str(z_value)) - return id_index_list, blinding_b, x_value, y_value, send_hash_vec, z_value_list - - -def provider_gen_ot_cipher(x_value, y_value, send_hash_vec, z_value_list, data_map, is_contain_result=True): - """ - data_map = hashmap[id1: 
"message1", - id2: "message2"] - """ - if isinstance(x_value, str): - x_value = int(x_value) - if isinstance(y_value, str): - y_value = int(y_value) - if len(send_hash_vec) != len(z_value_list): - raise PpcException(PpcErrorCode.AYS_LENGTH_ERROR.get_code(), - PpcErrorCode.AYS_LENGTH_ERROR.get_msg()) - message_cipher_vec = [] - for idx, z_value in enumerate(z_value_list): - if isinstance(z_value, str): - z_value = int(z_value) - message_cipher_list = [] - message_int_len = 0 - for send_hash in send_hash_vec[idx]: - blinding_r = crypto_utils.get_random_int() - blinding_s = crypto_utils.get_random_int() - w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), - crypto_utils.ot_base_pown(blinding_r)) - key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), - crypto_utils.ot_pown(y_value, blinding_r)) - aes_key_bytes = os.urandom(16) - aes_default = utils.AESCipher(aes_key_bytes) - aes_key_base64str = utils.bytes_to_base64str(aes_key_bytes) - message_int, message_int_len = crypto_utils.ot_str_to_int( - aes_key_base64str) - - if send_hash not in data_map: - letters = string.ascii_lowercase - # random string - ot_message_str = 'message not found' - cipher = aes_default.encrypt(ot_message_str) - cipher_str = utils.bytes_to_base64str(cipher) - else: - if is_contain_result: - ot_message_str = json.dumps(data_map[send_hash]) - cipher = aes_default.encrypt(ot_message_str) - cipher_str = utils.bytes_to_base64str(cipher) - else: - ot_message_str = 'True' - cipher = aes_default.encrypt(ot_message_str) - cipher_str = utils.bytes_to_base64str(cipher) - message_cipher = key_value ^ message_int - z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) - message_cipher_list.append({ - "w": str(w_value), - "e": str(message_cipher), - "len": message_int_len, - 'aesCipher': cipher_str, - }) - message_cipher_vec.append(message_cipher_list) - return message_cipher_vec - - -def requester_ot_recover_result(id_index_list, blinding_b, message_cipher_vec): - if len(id_index_list) != len(message_cipher_vec): - raise PpcException(PpcErrorCode.AYS_LENGTH_ERROR.get_code(), - PpcErrorCode.AYS_LENGTH_ERROR.get_msg()) - result_list = [] - for idx, id_index in enumerate(id_index_list): - w_value = message_cipher_vec[idx][id_index]['w'] - e_value = message_cipher_vec[idx][id_index]['e'] - message_len = message_cipher_vec[idx][id_index]['len'] - cipher_str = message_cipher_vec[idx][id_index]['aesCipher'] - - if isinstance(w_value, str): - w_value = int(w_value) - if isinstance(e_value, str): - e_value = int(e_value) - w_1 = crypto_utils.ot_pown(w_value, blinding_b) - message_recover = w_1 ^ e_value - try: - aes_key = crypto_utils.ot_int_to_str(message_recover, message_len) - key_recover = utils.base64str_to_bytes(aes_key) - aes_recover = utils.AESCipher(key_recover) - cipher_recover = utils.base64str_to_bytes(cipher_str) - message_result = aes_recover.decrypt(cipher_recover) - result_list.append(message_result) - except Exception as be: - result_list.append(None) - return result_list - - -def id_obfuscation(obfuscation_order, rule=None): - # use rule extend different id order, such as driver card or name: - if rule is not None: - print('obfuscation is work in progress') - obs_list = [] - for i in range(0, obfuscation_order+1): - obs_list.append(make_hash(bytes(str(uuid.uuid4()), 'utf8'))) - return obs_list - - -def prepare_dataset_with_matrix(search_id_matrix, tmp_file_path, prepare_dataset_tmp_file_path): - search_reg = 'ppc-normal-prefix' - for search_list in search_id_matrix: 
- for search_id in search_list: - search_reg = '{}|{}'.format(search_reg, search_id) - exec_command = 'head -n 1 {} >> {}'.format( - tmp_file_path, prepare_dataset_tmp_file_path) - (status, result) = utils.getstatusoutput(exec_command) - if status != 0: - log.error( - f'[OnError]prepare_dataset_with_matrix! status is {status}, output is {result}') - else: - log.info( - f'prepare_dataset_with_matrix success! status is {status}, output is {result}') - exec_command = 'grep -E \'{}\' {} >> {}'.format( - search_reg, tmp_file_path, prepare_dataset_tmp_file_path) - (status, result) = utils.getstatusoutput(exec_command) - if status != 0: - log.error( - f'[OnError]prepare_dataset_with_matrix! status is {status}, output is {result}') - else: - log.info( - f'prepare_dataset_with_matrix success! status is {status}, output is {result}') - - -class TestOtMethods(unittest.TestCase): - - def test_choice_all_flow(self): - file_path = '/Users/asher/Downloads/test_file.csv' - # data_pd = pd.read_csv(file_path, index_col=0, header=0) - data_pd = pd.read_csv(file_path) - data_map = data_pd.set_index(data_pd.columns[0]).T.to_dict('list') - - choice = ['bob', '小鸡', '美丽'] - obs_order = 10 - _id_index_list, _blinding_b, _x_value, _y_value, _send_hash_vec, _z_value_list = requester_gen_ot_cipher( - choice, obs_order) - _message_cipher_vec = provider_gen_ot_cipher( - _x_value, _y_value, _send_hash_vec, _z_value_list, data_map) - result = requester_ot_recover_result( - _id_index_list, _blinding_b, _message_cipher_vec) - for idx, id_num in enumerate(choice): - print(f"found {id_num} value is {result[idx]}") - - # def test_choice_ot(self): - # choice_list = [1, 4] - # blinding_a = crypto_utils.get_random_int() - # blinding_b = crypto_utils.get_random_int() - # x_value = crypto_utils.ot_base_pown(blinding_a) - # y_value = crypto_utils.ot_base_pown(blinding_b) - # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) - # # c_value = crypto_utils.ot_base_pown(c_blinding) - # z_value_list = [] - # for choice in choice_list: - # z_value = crypto_utils.ot_base_pown(c_blinding - choice) - # z_value_list.append(z_value) - # # send x_value, y_value, z_value - # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] - # message_list = [] - # for message_str in message_str_list: - # message_int, message_int_len = crypto_utils.ot_str_to_int(message_str) - # message_list.append(message_int) - # # message_list = [111111, 222222, 333333, 444444, 555555] - # cipher_vec = [] - # for z_value in z_value_list: - # cipher_list = [] - # for message in message_list: - # blinding_r = crypto_utils.get_random_int() - # blinding_s = crypto_utils.get_random_int() - # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), - # crypto_utils.ot_base_pown(blinding_r)) - # key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), - # crypto_utils.ot_pown(y_value, blinding_r)) - # z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) - # e_cipher = key_value ^ message - # cipher_list.append({ - # "w": w_value, - # "e": e_cipher - # }) - # cipher_vec.append(cipher_list) - # - # for cipher_each in cipher_vec: - # for cipher in cipher_each: - # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) - # message_recover = w_1 ^ cipher['e'] - # print(message_recover) - # for idx, cipher_list in enumerate(cipher_vec): - # w_1 = crypto_utils.ot_pown(cipher_list[choice_list[idx]]['w'], blinding_b) - # message_recover = w_1 ^ cipher_list[choice_list[idx]]['e'] - # s = 
crypto_utils.ot_int_to_str(message_recover) - # print(s) - - # def test_id_ot(self): - # print("test_id_ot") - # choice_id_list = [crypto_utils.ot_str_to_int('小明'), crypto_utils.ot_str_to_int('张三')] - # blinding_a = crypto_utils.get_random_int() - # blinding_b = crypto_utils.get_random_int() - # x_value = crypto_utils.ot_base_pown(blinding_a) - # y_value = crypto_utils.ot_base_pown(blinding_b) - # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) - # # c_value = crypto_utils.ot_base_pown(c_blinding) - # z_value_list = [] - # for choice in choice_id_list: - # z_value = crypto_utils.ot_base_pown(c_blinding - choice) - # z_value_list.append(z_value) - # # z_value = crypto_utils.ot_base_pown(c_blinding - choice_id) - # # send x_value, y_value, z_value - # id_str_list = ['小往', '小明', 'asher', 'set', '张三'] - # id_list = [] - # for id_str in id_str_list: - # id_list.append(crypto_utils.ot_str_to_int(id_str)) - # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] - # message_list = [] - # for message_str in message_str_list: - # message_list.append(crypto_utils.ot_str_to_int(message_str)) - # # message_list = [111111, 222222, 333333, 444444, 555555] - # cipher_vec = [] - # for z_value in z_value_list: - # cipher_list = [] - # for idx, message in enumerate(message_list): - # blinding_r = crypto_utils.get_random_int() - # blinding_s = crypto_utils.get_random_int() - # z_value_use = crypto_utils.ot_mul_n(z_value, crypto_utils.ot_base_pown(id_list[idx])) - # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), - # crypto_utils.ot_base_pown(blinding_r)) - # key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value_use, blinding_s), - # crypto_utils.ot_pown(y_value, blinding_r)) - # e_cipher = key_value ^ message - # cipher_list.append({ - # "w": w_value, - # "e": e_cipher - # }) - # cipher_vec.append(cipher_list) - # - # for idx, cipher_list in enumerate(cipher_vec): - # for now_idx, cipher in enumerate(cipher_list): - # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) - # message_recover = w_1 ^ cipher['e'] - # # print(message_recover) - # # print(idx) - # # print(now_idx) - # if (idx == 0 and now_idx == 1) or (idx == 1 and now_idx == 4): - # s = crypto_utils.ot_int_to_str(message_recover) - # print(s) - - def test_pd_with_multi_index(self): - print(True) - df = pd.DataFrame( - [[21, 'Amol', 72, 67], - [23, 'Lini', 78, 69], - [32, 'Kiku', 74, 56], - [52, 'Ajit', 54, 76], - [53, 'Ajit', 55, 78] - ], - columns=['rollno', 'name', 'physics', 'botony']) - - print('DataFrame with default index\n', df) - # set multiple columns as index - # df_map = df.set_index('name').T.to_dict('list') - # df_map = df.set_index('name').groupby(level=0).apply(lambda x: x.to_dict('r')).to_dict() - df_map = df.set_index('name').groupby(level=0).apply( - lambda x: x.to_dict('r')).to_dict() - print(df_map) - print(json.dumps(df_map['Ajit'])) - print(type(json.dumps(df_map['Kiku']))) - - def test_prepare_dataset_with_matrix(self): - tmp_file_path = "/Users/asher/Desktop/数据集2021/8_1_100w.csv" - prepare_dataset_tmp_file_path = "/Users/asher/Desktop/数据集2021/pre-test_100.csv" - search_id_matrix = [['645515750175253924', '779808531920530393'], [ - '399352968694137676', '399352968694137676222', '399352968694137']] - prepare_dataset_with_matrix( - search_id_matrix, tmp_file_path, prepare_dataset_tmp_file_path) - - # - # def test_choice_ot_multi(self): - # choice_list = [1, 2, 4] - # blinding_a = crypto_utils.get_random_int() - # blinding_b = crypto_utils.get_random_int() - # x_value = 
crypto_utils.ot_base_pown(blinding_a) - # y_value = crypto_utils.ot_base_pown(blinding_b) - # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) - # # c_value = crypto_utils.ot_base_pown(c_blinding) - # choice_final = 0 - # for choice in choice_list: - # choice_final = choice_final + choice - # z_value = crypto_utils.ot_base_pown(c_blinding - choice_final) - # # send x_value, y_value, z_value - # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] - # message_list = [] - # for message_str in message_str_list: - # message_list.append(crypto_utils.ot_str_to_int(message_str)) - # # message_list = [111111, 222222, 333333, 444444, 555555] - # cipher_list = [] - # for message in message_list: - # blinding_r = crypto_utils.get_random_int() - # blinding_s = crypto_utils.get_random_int() - # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), - # crypto_utils.ot_base_pown(blinding_r)) - # key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), - # crypto_utils.ot_pown(y_value, blinding_r)) - # z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) - # e_cipher = key_value ^ message - # cipher_list.append({ - # "w": w_value, - # "e": e_cipher - # }) - # - # # for cipher in cipher_list: - # # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) - # # message_recover = w_1 ^ cipher['e'] - # # print(message_recover) - # - # for choice in choice_list: - # w_1 = crypto_utils.ot_pown(cipher_list[choice]['w'], blinding_b) - # for item in choice_list: - # if choice == item: - # continue - # else: - # base_g = crypto_utils.ot_base_pown(-item) - # w_1 = crypto_utils.ot_mul_n(w_1, base_g) - # message_recover = w_1 ^ cipher_list[choice]['e'] - # print(message_recover) - # # s = crypto_utils.ot_int_to_str(message_recover) - # # print(s) - # - - -# if __name__ == '__main__': - - # print(True) - # df = pd.DataFrame( - # [[21, 'Amol', 72, 67], - # [23, 'Lini', 78, 69], - # [32, 'Kiku', 74, 56], - # [52, 'Ajit', 54, 76]], - # columns=['rollno', 'name', 'physics', 'botony']) - # - # print('DataFrame with default index\n', df) - # # set multiple columns as index - # df = df.set_index(['rollno', 'name']) - # - # print('\nDataFrame with MultiIndex\n', df) - # point1 = point.base() - # print(point.base(scalar1).hex()) - - # json_response = "{\"id\": {\"0\": \"67b176705b46206614219f47a05aee7ae6a3edbe850bbbe214c536b989aea4d2\", \"1\": \"b1b1bd1ed240b1496c81ccf19ceccf2af6fd24fac10ae42023628abbe2687310\"}, \"x0\": {\"0\": 10, \"1\": 20}, \"x1\": {\"0\": 11, \"1\": 22}}" - # json_dict = json.loads(json_response) - # # json_pd = pd.json_normalize(json_dict) - # print(json_dict) - - # hex_str = utils.make_hash(bytes(str(796443), 'utf8'), CryptoType.ECDSA, HashType.HEXSTR) - # print(hex_str) - # csv_path1 = '/Users/asher/Downloads/UseCase120/usecase120_party1.csv' - # csv_path2 = '/Users/asher/Downloads/UseCase120/usecase120_party2.csv' - # data = pd.read_csv(csv_path1) - # if 'id' in data.columns.values: - # duplicated_list = data.duplicated('id', False).tolist() - # if True in duplicated_list: - # log.error(f"[OnError]id duplicated, check csv file") - # raise PpcException(PpcErrorCode.DATASET_CSV_ERROR.get_code(), - # PpcErrorCode.DATASET_CSV_ERROR.get_msg()) - # start = time.time() - # get_pd_file_with_hash_requester(csv_path1, f'{csv_path1}-pre', 2) - # start2 = time.time() - # print(f"requester prepare time is {start2 - start}s") - # - # get_pd_file_with_hash_data_provider(csv_path2, f'{csv_path2}-pre') - # start3 = time.time() - # print(f"provider 
prepare time is {start3 - start2}s") - # provider_output_path = '/Users/asher/Downloads/UseCase120/output.csv' - # get_anonymous_data(f'{csv_path1}-pre-requester', f'{csv_path2}-pre', provider_output_path) - # start4 = time.time() - # print(f"provider compute time is {start4 - start3}s") - # result_path = '/Users/asher/Downloads/UseCase120/result.csv' - # recover_result_data(f'{csv_path1}-pre', '/Users/asher/Downloads/UseCase120/output.csv', result_path) - # end = time.time() - # print(f"requester get result time is {end - start4}s") diff --git a/python/ppc_common/ppc_utils/audit_utils.py b/python/ppc_common/ppc_utils/audit_utils.py deleted file mode 100644 index 913a3934..00000000 --- a/python/ppc_common/ppc_utils/audit_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -import getopt -import sys - -from ppc_common.ppc_utils import utils -from ppc_common.ppc_utils.utils import CryptoType - -AUDIT_KEYS = ["agency_name", "input_dataset_hash", - "psi_input_hash", "psi_output_hash", - "mpc_result_hash"] - - -def parse_parameter(argv): - file_path = 0 - data_hash_value = 0 - crypto_type = None - try: - opts, args = getopt.getopt( - argv, "hf:v:c:", ["file_path=", "data_hash_value="]) - except getopt.GetoptError: - usage() - sys.exit(2) - if len(opts) == 0: - usage() - sys.exit(2) - for opt, arg in opts: - if opt == '-h': - usage() - sys.exit(0) - elif opt in ("-f", "--file_path"): - file_path = arg - elif opt in ("-v", "--data_hash_value"): - data_hash_value = arg - elif opt in ("-c", "--crypto_type"): - crypto_type = arg - else: - usage() - sys.exit(2) - return file_path, data_hash_value, crypto_type - - -def usage(): - print('audit.py -f -v -c crypto_type') - print('usage:') - print(' -f Notice:The file will be hashed to audit.') - print(' -v Notice:The dataset hash value will be compared to audit.') - print( - ' -c Notice:The crypto type[ECDSA or GM] will be used.') - - -if __name__ == '__main__': - file_path, data_hash_value, crypto_type = parse_parameter(sys.argv[1:]) - file_data_hash = utils.make_hash_from_file_path( - file_path, CryptoType[crypto_type]) - print(f'The file hash is:{file_data_hash}') - print(f'Audit result is:{file_data_hash == data_hash_value}') diff --git a/python/ppc_common/ppc_utils/cem_utils.py b/python/ppc_common/ppc_utils/cem_utils.py deleted file mode 100644 index 6e9d9d5f..00000000 --- a/python/ppc_common/ppc_utils/cem_utils.py +++ /dev/null @@ -1,160 +0,0 @@ -import json -import unittest - -LEGAL_OPERATOR_LIST = ['>', '<', '>=', '<=', '==', '!='] -LEGAL_QUOTA_LIST = ['and', 'or'] - - -def get_rule_detail(match_module_dict): - ruleset = match_module_dict['ruleset'] - operator_set = set() - count_set = set() - quota_set = set() - for ruleset_item in ruleset: - sub_rule_set = ruleset_item['set'] - for sub_rule_item in sub_rule_set: - rule = sub_rule_item['rule'] - operator_set.add(rule['operator']) - count_set.add(rule['count']) - quota_set.add(rule['quota']) - return operator_set, count_set, quota_set - - -def check_dataset_id_has_duplicated(match_module_dict): - ruleset = match_module_dict['ruleset'] - for ruleset_item in ruleset: - sub_rule_set = ruleset_item['set'] - for sub_rule_item in sub_rule_set: - dataset_id_set = set() - dataset_id_list = sub_rule_item['dataset'] - for dataset_id in dataset_id_list: - dataset_id_set.add(dataset_id) - if len(dataset_id_set) != len(dataset_id_list): - return True - return False - - -def get_dataset_id_set(match_module_dict): - dataset_id_set = set() - ruleset = match_module_dict['ruleset'] - for ruleset_item 
in ruleset: - sub_rule_set = ruleset_item['set'] - for sub_rule_item in sub_rule_set: - dataset_id_list = sub_rule_item['dataset'] - for dataset_id in dataset_id_list: - dataset_id_set.add(dataset_id) - return dataset_id_set - - -# get field_dataset_map: {'x1':['d1', 'd2', 'd3'], 'x2':['d1', 'd2', 'd3']} -def parse_field_dataset_map(match_module_dict): - field_dataset_map = {} - ruleset = match_module_dict['ruleset'] - for ruleset_item in ruleset: - field = ruleset_item['field'] - if field in field_dataset_map: - field_dataset_id_set = field_dataset_map[field] - else: - field_dataset_id_set = set() - field_dataset_map[field] = field_dataset_id_set - sub_rule_set = ruleset_item['set'] - for sub_rule_item in sub_rule_set: - dataset_id_list = sub_rule_item['dataset'] - for dataset_id in dataset_id_list: - field_dataset_id_set.add(dataset_id) - return field_dataset_map - - -# get dataset_field_map: {'d1':['x1', 'x2'], 'd2':['x1', 'x2'], 'd3':['x1', 'x2']} -def parse_dataset_field_map(dataset_id_set, field_dataset_map): - dataset_field_map = {} - for dataset_id in dataset_id_set: - for field, field_dataset_id_set in field_dataset_map.items(): - if dataset_id in field_dataset_id_set: - if dataset_id in dataset_field_map: - dataset_field_set = dataset_field_map[dataset_id] - else: - dataset_field_set = set() - dataset_field_map[dataset_id] = dataset_field_set - dataset_field_set.add(field) - return dataset_field_map - - -def parse_match_param(match_fields, match_module_dict): - # step1 get field_dataset_map: {'x1':['d1', 'd2', 'd3'], 'x2':['d1', 'd2', 'd3']} - field_dataset_map = parse_field_dataset_map(match_module_dict) - # step2 get all dataset_id {'d1', 'd2', 'd3'} - dataset_id_set = set() - for field, field_dataset_id_set in field_dataset_map.items(): - dataset_id_set.update(field_dataset_id_set) - # step3 get dataset_field_map: {'d1':['x1', 'x2'], 'd2':['x1', 'x2'], 'd3':['x1', 'x2']} - dataset_field_map = parse_dataset_field_map( - dataset_id_set, field_dataset_map) - # step4 get match_param_list from dataset_field_map and match_field: - # [ - # {'dataset_id':'d1', 'match_field':{'x1':'xxx, 'x2':'xxx}, - # {'dataset_id':'d2', 'match_field':{'x1':'xxx, 'x2':'xxx}, - # {'dataset_id':'d3', 'match_field':{'x1':'xxx, 'x2':'xxx}, - # ] - match_param_list = parse_match_param_list(dataset_field_map, match_fields) - return dataset_id_set, match_param_list - - -# get match_param_list from dataset_field_map and match_field: -# [ -# {'dataset_id':'d1', 'match_field':{'x1':'xxx, 'x2':'xxx}, -# {'dataset_id':'d2', 'match_field':{'x1':'xxx, 'x2':'xxx}, -# {'dataset_id':'d3', 'match_field':{'x1':'xxx, 'x2':'xxx}, -# ] -def parse_match_param_list(dataset_field_map, match_fields): - match_param_list = [] - match_fields = match_fields.replace("'", '"') - match_fields_object = json.loads(match_fields) - for dataset_id, field_set in dataset_field_map.items(): - match_param = {'dataset_id': dataset_id} - field_value_map = {} - for field in field_set: - # allow some part field match - if field in match_fields_object.keys(): - field_value_map[field] = match_fields_object[field] - match_param['match_field'] = field_value_map - match_param_list.append(match_param) - return match_param_list - - -class TestCemUtils(unittest.TestCase): - def test_cem_match_algorithm_load(self): - match_module = '{"ruleset":[' \ - '{"field":"x1",' \ - '"set":[' \ - '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ - 
'{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ - '{"field":"x2","set":[' \ - '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d2-encrypted","d3-encrypted"]}]}]}' - match_module_dict = json.loads(match_module) - # match_module_dict = utils.json_loads(match_module) - print(match_module_dict) - - def test_check_dataset_id_has_duplicated(self): - match_module = '{"ruleset":[' \ - '{"field":"x1",' \ - '"set":[' \ - '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ - '{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ - '{"field":"x2","set":[' \ - '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d2-encrypted","d3-encrypted"]}]}]}' - match_module_dict = json.loads(match_module) - has_duplicated = check_dataset_id_has_duplicated(match_module_dict) - assert has_duplicated == False - - def test_check_dataset_id_has_duplicated(self): - match_module = '{"ruleset":[' \ - '{"field":"x1",' \ - '"set":[' \ - '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ - '{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ - '{"field":"x2","set":[' \ - '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d1-encrypted","d3-encrypted"]}]}]}' - match_module_dict = json.loads(match_module) - has_duplicated = check_dataset_id_has_duplicated(match_module_dict) - assert has_duplicated == True diff --git a/python/ppc_common/ppc_utils/permission.py b/python/ppc_common/ppc_utils/permission.py deleted file mode 100644 index 7eb0c619..00000000 --- a/python/ppc_common/ppc_utils/permission.py +++ /dev/null @@ -1,79 +0,0 @@ -from enum import Enum - -from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode -from ppc_common.ppc_utils.utils import JobRole, JobStatus - - -def check_job_status(job_status): - if job_status in JobStatus.__members__: - return True - return False - - -def check_job_role(job_role): - if job_role in JobRole.__members__: - return True - return False - - -ADMIN_PERMISSIONS = 'ADMIN_PERMISSIONS' - - -class UserRole(Enum): - ADMIN = 1 - DATA_PROVIDER = 2 - ALGO_PROVIDER = 3 - DATA_CONSUMER = 4 - - -class PermissionGroup(Enum): - AGENCY_GROUP = 1 - DATA_GROUP = 2 - ALGO_GROUP = 3 - JOB_GROUP = 4 - AUDIT_GROUP = 5 - - -class AgencyGroup(Enum): - LIST_AGENCY = 1 - WRITE_AGENCY = 2 - - -class DataGroup(Enum): - LIST_DATA = 1 - READ_DATA_PUBLIC_INFO = 2 - READ_DATA_PRIVATE_INFO = 3 - WRITE_DATA = 4 - - -class AlgoGroup(Enum): - LIST_ALGO = 1 - READ_ALGO_PUBLIC_INFO = 2 - READ_ALGO_PRIVATE_INFO = 3 - WRITE_ALGO = 4 - - -class JobGroup(Enum): - LIST_JOB = 1 - READ_JOB_PUBLIC_INFO = 2 - READ_JOB_PRIVATE_INFO = 3 - WRITE_JOB = 4 - - -class AuditGroup(Enum): - READ_AUDIT = 1 - WRITE_AUDIT = 2 - - -# permissions formed as permission_a|permission_b|permission_a_group|... 
-def check_permission(permissions, needed_permission_group, *needed_permissions): - permission_list = permissions.split('|') - if ADMIN_PERMISSIONS in permission_list: - return 0 - if needed_permission_group in permission_list: - return 1 - for needed_permission in needed_permissions: - if needed_permission in permission_list: - return 1 - raise PpcException(PpcErrorCode.INSUFFICIENT_AUTHORITY.get_code( - ), PpcErrorCode.INSUFFICIENT_AUTHORITY.get_msg()) diff --git a/python/ppc_common/ppc_utils/utils.py b/python/ppc_common/ppc_utils/utils.py index d2564cce..51605969 100644 --- a/python/ppc_common/ppc_utils/utils.py +++ b/python/ppc_common/ppc_utils/utils.py @@ -1,29 +1,22 @@ import base64 -import datetime import hashlib import io import json import logging import os -import random import re import shutil -import string import subprocess import time -import uuid -from concurrent.futures.thread import ThreadPoolExecutor -from enum import Enum, IntEnum, unique +from enum import Enum, unique import jwt -import pandas as pd from ecdsa import SigningKey, SECP256k1, VerifyingKey from gmssl import func, sm2, sm3 from google.protobuf.descriptor import FieldDescriptor from jsoncomment import JsonComment from pysmx.SM3 import SM3 -from ppc_common.ppc_protos.generated.ppc_pb2 import DatasetDetail from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode log = logging.getLogger(__name__) @@ -118,80 +111,17 @@ class CryptoType(Enum): ECDSA = 1 GM = 2 - @unique class HashType(Enum): BYTES = 1 HEXSTR = 2 -class DeployMode(IntEnum): - PRIVATE_MODE = 1 - SAAS_MODE = 2 - PROXY_MODE = 3 - - -class JobRole(Enum): - ALGORITHM_PROVIDER = 1 - COMPUTATION_PROVIDER = 2 - DATASET_PROVIDER = 3 - DATA_CONSUMER = 4 - - -class JobStatus(Enum): - SUCCEED = 1 - FAILED = 2 - RUNNING = 3 - WAITING = 4 - NONE = 5 - RETRY = 6 - - class AlgorithmType(Enum): Train = "Train", Predict = "Predict" -class AlgorithmSubtype(Enum): - HeteroLR = 1 - HomoLR = 2 - HeteroNN = 3 - HomoNN = 4 - HeteroXGB = 5 - - -class DataAlgorithmType(Enum): - PIR = 'pir' - CEM = 'cem' - ALL = 'all' - - -class IdPrefixEnum(Enum): - DATASET = "d-" - ALGORITHM = "a-" - JOB = "j-" - - -class JobCemResultSuffixEnum(Enum): - JOB_ALGORITHM_AGENCY = "-1" - INITIAL_JOB_AGENCY = "-2" - - -class OriginAlgorithm(Enum): - PPC_AYS = ["a-1001", "匿踪查询", "1", "1.0"] - PPC_PSI = ["a-1002", "隐私求交", "2+", "1.0"] - - -executor = ThreadPoolExecutor(max_workers=10) - - -def async_fn(func): - def wrapper(*args, **kwargs): - executor.submit(func, *args, **kwargs) - - return wrapper - - def json_loads(json_config): try: json_comment = JsonComment(json) @@ -498,392 +428,6 @@ def load_credential_from_file(filepath): return f.read() -def getstatusoutput(cmd): - """replace commands.getstatusoutput - - Arguments: - cmd {[string]} - """ - - get_cmd = subprocess.Popen( - cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - ret = get_cmd.communicate() - out = ret[0] - err = ret[1] - output = '' - if not out is None: - output = output + out.decode('utf-8') - if not err is None: - output = output + err.decode('utf-8') - - log.debug(' cmd is %s, status is %s, output is %s', - cmd, str(get_cmd.returncode), output) - - return (get_cmd.returncode, output) - - -def make_dataset_metadata(data_field_list, dataset_description, dataset_size, row_count, column_count, - version_hash, user_name): - dataset_metadata = DatasetDetail() - dataset_metadata.data_field_list = data_field_list - dataset_metadata.dataset_description = dataset_description - dataset_metadata.dataset_size 
= dataset_size - dataset_metadata.row_count = row_count - dataset_metadata.column_count = column_count - dataset_metadata.dataset_hash = version_hash - dataset_metadata.user_name = user_name - - date_time = make_timestamp() - dataset_metadata.create_time = date_time - dataset_metadata.update_time = date_time - return pb_to_str(dataset_metadata) - - -def default_model_module(subtype): - defaultModelModule = [ - { - 'label': 'use_psi', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否进行隐私求交,只能为0或1' - }, - { - 'label': 'fillna', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否进行缺失值填充,只能为0或1' - }, - { - 'label': 'na_select', - 'type': 'float', - 'value': 1, - 'min_value': 0, - 'max_value': 1, - 'description': '缺失值筛选阈值,取值范围为0~1之间(0表示只要有缺失值就移除,1表示移除全为缺失值的列)' - }, - { - 'label': 'filloutlier', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否进行异常值处理,只能为0或1' - }, - { - 'label': 'normalized', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否归一化,只能为0或1' - }, - { - 'label': 'standardized', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否标准化,只能为0或1' - }, - { - 'label': 'categorical', - 'type': 'string', - 'value': '', - 'description': '标记所有分类特征字段,格式:"x1,x12" (空代表无分类特征)' - }, - { - 'label': 'psi_select_col', - 'type': 'string', - 'value': '', - 'description': 'PSI稳定性筛选时间列名(空代表不进行PSI筛选)' - }, - { - 'label': 'psi_select_base', - 'type': 'string', - 'value': '', - 'description': 'PSI稳定性筛选的基期(空代表不进行PSI筛选)' - }, - { - 'label': 'psi_select_thresh', - 'type': 'float', - 'value': 0.3, - 'min_value': 0, - 'max_value': 1, - 'description': 'PSI筛选阈值,取值范围为0~1之间' - }, - { - 'label': 'psi_select_bins', - 'type': 'int', - 'value': 4, - 'min_value': 3, - 'max_value': 100, - 'description': '计算PSI时分箱数, 取值范围为3~100之间' - }, - { - 'label': 'corr_select', - 'type': 'float', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '特征相关性筛选阈值,取值范围为0~1之间(值为0时不进行相关性筛选)' - }, - { - 'label': 'use_iv', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否使用iv进行特征筛选,只能为0或1' - }, - { - 'label': 'group_num', - 'type': 'int', - 'value': 4, - 'min_value': 3, - 'max_value': 100, - 'description': 'woe计算分箱数,取值范围为3~100之间的整数' - }, - { - 'label': 'iv_thresh', - 'type': 'float', - 'value': 0.1, - 'min_value': 0.01, - 'max_value': 1, - 'description': 'iv特征筛选的阈值,取值范围为0.01~1之间' - }] - - if subtype == 'HeteroXGB': - defaultModelModule.extend([ - { - 'label': 'use_goss', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 'description': '是否采用goss抽样加速训练, 只能为0或1' - }, - { - 'label': 'test_dataset_percentage', - 'type': 'float', - 'value': 0.3, - 'min_value': 0.1, - 'max_value': 0.5, - 'description': '测试集比例, 取值范围为0.1~0.5之间' - }, - { - 'label': 'learning_rate', - 'type': 'float', - 'value': 0.1, - 'min_value': 0.01, - 'max_value': 1, - 'description': '学习率, 取值范围为0.01~1之间' - }, - { - 'label': 'num_trees', - 'type': 'int', - 'value': 6, - 'min_value': 1, - 'max_value': 300, - 'description': 'XGBoost迭代树棵树, 取值范围为1~300之间的整数' - }, - { - 'label': 'max_depth', - 'type': 'int', - 'value': 3, - 'min_value': 1, - 'max_value': 6, - 'description': 'XGBoost树深度, 取值范围为1~6之间的整数' - }, - { - 'label': 'max_bin', - 'type': 'int', - 'value': 4, - 'min_value': 3, - 'max_value': 100, - 'description': '特征分箱数, 取值范围为3~100之间的整数' - }, - { - 'label': 'silent', - 'type': 'bool', - 'value': 0, - 'min_value': 0, - 'max_value': 1, - 
'description': '值为1时不打印运行信息,只能为0或1' - }, - { - 'label': 'subsample', - 'type': 'float', - 'value': 1, - 'min_value': 0.1, - 'max_value': 1, - 'description': '训练每棵树使用的样本比例,取值范围为0.1~1之间' - }, - { - 'label': 'colsample_bytree', - 'type': 'float', - 'value': 1, - 'min_value': 0.1, - 'max_value': 1, - 'description': '训练每棵树使用的特征比例,取值范围为0.1~1之间' - }, - { - 'label': 'colsample_bylevel', - 'type': 'float', - 'value': 1, - 'min_value': 0.1, - 'max_value': 1, - 'description': '训练每一层使用的特征比例,取值范围为0.1~1之间' - }, - { - 'label': 'reg_alpha', - 'type': 'float', - 'value': 0, - 'min_value': 0, - 'description': 'L1正则化项,用于控制模型复杂度,取值范围为大于等于0的数值' - }, - { - 'label': 'reg_lambda', - 'type': 'float', - 'value': 1, - 'min_value': 0, - 'description': 'L2正则化项,用于控制模型复杂度,取值范围为大于等于0的数值' - }, - { - 'label': 'gamma', - 'type': 'float', - 'value': 0, - 'min_value': 0, - 'description': '最优分割点所需的最小损失函数下降值,取值范围为大于等于0的数值' - }, - { - 'label': 'min_child_weight', - 'type': 'float', - 'value': 0, - 'min_value': 0, - 'description': '最优分割点所需的最小叶子节点权重,取值范围为大于等于0的数值' - }, - { - 'label': 'min_child_samples', - 'type': 'int', - 'value': 10, - 'min_value': 1, - 'max_value': 1000, - 'description': '最优分割点所需的最小叶子节点样本数量,取值范围为1~1000之间的整数' - }, - { - 'label': 'seed', - 'type': 'int', - 'value': 2024, - 'min_value': 0, - 'max_value': 10000, - 'description': '分割训练集/测试集时随机数种子,取值范围为0~10000之间的整数' - }, - { - 'label': 'early_stopping_rounds', - 'type': 'int', - 'value': 5, - 'min_value': 0, - 'max_value': 100, - 'description': '指定迭代多少次没有提升则停止训练, 值为0时不执行, 取值范围为0~100之间的整数' - }, - { - 'label': 'eval_metric', - 'type': 'string', - 'value': 'auc', - 'description': '早停的评估指标,支持:auc, acc, recall, precision' - }, - { - 'label': 'verbose_eval', - 'type': 'int', - 'value': 1, - 'min_value': 0, - 'max_value': 100, - 'description': '按传入的间隔输出训练过程中的评估信息,0表示不打印' - }, - { - 'label': 'eval_set_column', - 'type': 'string', - 'value': '', - 'description': '指定训练集测试集标记字段名称' - }, - { - 'label': 'train_set_value', - 'type': 'string', - 'value': '', - 'description': '指定训练集标记值' - }, - { - 'label': 'eval_set_value', - 'type': 'string', - 'value': '', - 'description': '指定测试集标记值' - }, - { - 'label': 'train_features', - 'type': 'string', - 'value': '', - 'description': '指定入模特征' - } # , - # { - # 'label': 'threads', - # 'type': 'int', - # 'value': 8, - # 'min_value': 1, - # 'max_value': 8, - # 'description': '取值范围为1~8之间的整数' - # } - ]) - else: - defaultModelModule.extend([ - { - 'label': 'test_dataset_percentage', - 'type': 'float', - 'value': 0.3, - 'min_value': 0.1, - 'max_value': 0.5, - 'description': '取值范围为0.1~0.5之间' - }, - { - 'label': 'learning_rate', - 'type': 'float', - 'value': 0.3, - 'min_value': 0.001, - 'max_value': 0.999, - 'description': '取值范围为0.001~0.999之间' - }, - { - 'label': 'epochs', - 'type': 'int', - 'value': 3, - 'min_value': 1, - 'max_value': 5, - 'description': '取值范围为1~5之间的整数' - }, - { - 'label': 'batch_size', - 'type': 'int', - 'value': 16, - 'min_value': 1, - 'max_value': 100, - 'description': '取值范围为1~100之间的整数' - }, - { - 'label': 'threads', - 'type': 'int', - 'value': 8, - 'min_value': 1, - 'max_value': 8, - 'description': '取值范围为1~8之间的整数' - } - ]) - - return defaultModelModule - - def merge_files(file_list, output_file): try: with open(output_file, 'wb') as outfile: diff --git a/python/ppc_model/common/initializer.py b/python/ppc_model/common/initializer.py index 829e0182..17b79082 100644 --- a/python/ppc_model/common/initializer.py +++ b/python/ppc_model/common/initializer.py @@ -6,7 +6,6 @@ import yaml from ppc_common.deps_services import storage_loader 
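
The trimmed ppc.proto now carries only the worker-level messages (InputStatement, JobWorkerInputsStatement, JobWorkerUpstreams, JobWorkerOutputs), which line up with the upstreams, inputs_statement and outputs columns of the new t_job_worker table. A small round-trip sketch with the regenerated ppc_pb2 module; encoding the serialized bytes as base64 before writing them into the TEXT columns is only an assumption for illustration, not something this diff prescribes:

```python
import base64

from ppc_common.ppc_protos.generated.ppc_pb2 import JobWorkerInputsStatement, JobWorkerUpstreams

# describe where a worker's input comes from: an upstream worker id plus an output slot
upstreams = JobWorkerUpstreams(upstreams=["j-123-psi"])
inputs = JobWorkerInputsStatement()
stmt = inputs.inputs_statement.add()
stmt.upstream = "j-123-psi"
stmt.output_index = 0

# serialize for the TEXT columns of t_job_worker (base64 wrapping assumed, not mandated)
upstreams_text = base64.b64encode(upstreams.SerializeToString()).decode()
inputs_text = base64.b64encode(inputs.SerializeToString()).decode()

# parse back when the worker record is loaded again
restored = JobWorkerUpstreams()
restored.ParseFromString(base64.b64decode(upstreams_text))
assert list(restored.upstreams) == ["j-123-psi"]
```
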
-from ppc_common.deps_services.storage_api import StorageType from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager from ppc_common.ppc_utils import common_func from ppc_model.network.grpc.grpc_client import GrpcClient diff --git a/python/ppc_model/preprocessing/local_processing/psi_select.py b/python/ppc_model/preprocessing/local_processing/psi_select.py index 5bd1223e..3a3db8c7 100644 --- a/python/ppc_model/preprocessing/local_processing/psi_select.py +++ b/python/ppc_model/preprocessing/local_processing/psi_select.py @@ -34,7 +34,7 @@ def calculate_psi(expected, actual, buckettype='bins', buckets=10, axis=0): axis: axis by which variables are defined, 0 for vertical, 1 for horizontal Returns: - psi_values: ndarray of psi values for each variable + psi_values: ndarray of engine values for each variable Author: Matthew Burke diff --git a/python/ppc_scheduler/__init__.py b/python/ppc_scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/common/__init__.py b/python/ppc_scheduler/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/common/global_context.py b/python/ppc_scheduler/common/global_context.py new file mode 100644 index 00000000..7a090c83 --- /dev/null +++ b/python/ppc_scheduler/common/global_context.py @@ -0,0 +1,8 @@ +import os + +from ppc_scheduler.common.initializer import Initializer + +dirName, _ = os.path.split(os.path.abspath(__file__)) +config_path = "application.yml" + +components = Initializer(log_config_path='logging.conf', config_path=config_path) diff --git a/python/ppc_scheduler/common/initializer.py b/python/ppc_scheduler/common/initializer.py new file mode 100644 index 00000000..6ea92305 --- /dev/null +++ b/python/ppc_scheduler/common/initializer.py @@ -0,0 +1,91 @@ +import logging +import logging.config +import os +from contextlib import contextmanager + +import yaml +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from ppc_common.deps_services import storage_loader +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager +from ppc_common.ppc_initialize.dataset_handler_initialize import DataSetHandlerInitialize +from ppc_common.ppc_utils import common_func +from ppc_scheduler.job.job_manager import JobManager + + +class Initializer: + def __init__(self, log_config_path, config_path): + self.job_cache_dir = None + self.log_config_path = log_config_path + self.config_path = config_path + self.config_data = None + self.job_manager = None + self.thread_event_manager = None + self.sql_session = None + self.sql_engine = None + self.storage_client = None + # 只用于测试 + self.mock_logger = None + self.dataset_handler_initializer = None + + def init_all(self): + self.init_log() + self.init_cache() + self.init_config() + self.init_job_manager() + self.init_sql_client() + self.init_storage_client() + self.init_others() + + def init_log(self): + logging.config.fileConfig(self.log_config_path) + + def init_cache(self): + self.job_cache_dir = common_func.get_config_value( + "JOB_TEMP_DIR", "/tmp", self.config_data, False) + if not os.path.exists(self.job_cache_dir): + os.makedirs(self.job_cache_dir) + + def init_config(self): + with open(self.config_path, 'rb') as f: + self.config_data = yaml.safe_load(f.read()) + + def init_job_manager(self): + self.thread_event_manager = ThreadEventManager() + self.job_manager = JobManager( + logger=self.logger(), + thread_event_manager=self.thread_event_manager, + 
workspace=self.config_data['WORKSPACE'], + job_timeout_h=self.config_data['JOB_TIMEOUT_H'] + ) + + def init_sql_client(self): + self.sql_engine = create_engine(self.config_data['SQLALCHEMY_DATABASE_URI'], pool_pre_ping=True) + self.sql_session = sessionmaker(bind=self.sql_engine) + + @contextmanager + def create_sql_session(self): + session = self.sql_session() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + def init_storage_client(self): + self.storage_client = storage_loader.load( + self.config_data, self.logger()) + + def init_others(self): + self.dataset_handler_initializer = DataSetHandlerInitialize( + self.config_data, self.logger()) + + def logger(self, name=None): + if self.mock_logger is None: + return logging.getLogger(name) + else: + return self.mock_logger diff --git a/python/ppc_scheduler/common/log_utils.py b/python/ppc_scheduler/common/log_utils.py new file mode 100644 index 00000000..5e675d2f --- /dev/null +++ b/python/ppc_scheduler/common/log_utils.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +import os + +from ppc_common.ppc_utils import utils, path + + +def job_start_log_info(job_id): + return f"=====================start_{job_id}=====================" + + +def job_end_log_info(job_id): + return f"======================end_{job_id}======================" + + +def upload_job_log(storage_client, job_id, extra=None): + job_log_path = None + try: + file_path = path.get_path() + job_log_path = utils.get_log_temp_file_path(file_path, job_id) + origin_log_path = utils.get_log_file_path(file_path) + filter_job_log(job_id, origin_log_path, job_log_path) + if extra is not None: + job_log = open(job_log_path, 'a+') + job_log.write('\n' * 3) + job_log.write(extra) + job_log.close() + storage_client.upload_file(job_log_path, job_id + os.sep + utils.LOG_NAME) + finally: + os.remove(job_log_path) + + +def read_line_inverse(file): + file.seek(0, 2) + current_position = file.tell() + position = 0 + while position + current_position >= 0: + while True: + position -= 1 + try: + file.seek(position, 2) + if file.read(1) == b'\n': + break + except: + # point at file header + file.seek(0, 0) + break + line = file.readline() + yield line + + +def filter_job_log(job_id, origin_log_path, job_log_path): + origin_log_file = open(origin_log_path, 'rb') + line_list = [] + need_record = False + + # search job log + for line in read_line_inverse(origin_log_file): + if need_record: + line_list.append(line) + if not need_record and str(line).__contains__(job_end_log_info(job_id)): + need_record = True + line_list.append(line) + elif str(line).__contains__(job_start_log_info(job_id)): + break + origin_log_file.close() + + # save log lines into temp file + job_log_file = open(job_log_path, 'wb+') + job_log_file.writelines(list(reversed(line_list))) + job_log_file.close() diff --git a/python/ppc_scheduler/conf/application-sample.yml b/python/ppc_scheduler/conf/application-sample.yml new file mode 100644 index 00000000..0be7408c --- /dev/null +++ b/python/ppc_scheduler/conf/application-sample.yml @@ -0,0 +1,20 @@ +HOST: "0.0.0.0" +HTTP_PORT: 43471 + +JOB_TIMEOUT_H: 1800 + +AGENCY_ID: "1001" +PPCS_RPC_TOKEN: "ppcs_psi_apikey" + +WORKSPACE: "/data/app/files/job" +HDFS_ENDPOINT: "http://127.0.0.1:50070" + +# mysql or dm +DB_TYPE: "mysql" +SQLALCHEMY_DATABASE_URI: "mysql://[*user_ppcsmodeladm]:[*pass_ppcsmodeladm]@[@4346-TDSQL_VIP]:[@4346-TDSQL_PORT]/ppcsmodeladm?autocommit=true&charset=utf8mb4" +# SQLALCHEMY_DATABASE_URI: 
"dm+dmPython://ppcv16:ppc12345678@127.0.0.1:5236" + +SQLALCHEMY_TRACK_MODIFICATIONS: False + +MPC_NODE_DIRECT_PORT: 5899 +IS_MALICIOUS: False \ No newline at end of file diff --git a/python/ppc_scheduler/conf/logging.conf b/python/ppc_scheduler/conf/logging.conf new file mode 100644 index 00000000..bff66b16 --- /dev/null +++ b/python/ppc_scheduler/conf/logging.conf @@ -0,0 +1,40 @@ +[loggers] +keys=root,wsgi + +[logger_root] +level=INFO +handlers=consoleHandler,fileHandler + +[logger_wsgi] +level = INFO +handlers = accessHandler +qualname = wsgi +propagate = 0 + +[handlers] +keys=fileHandler,consoleHandler,accessHandler + +[handler_accessHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/wedpr/appmonitor.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_fileHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/wedpr/scheduler.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_consoleHandler] +class=StreamHandler +args=(sys.stdout,) +level=ERROR +formatter=simpleFormatter + +[formatters] +keys=simpleFormatter + +[formatter_simpleFormatter] +format=[%(levelname)s][%(asctime)s %(msecs)03d][%(process)d][%(filename)s:%(lineno)d] %(message)s +datefmt=%Y-%m-%d %H:%M:%S diff --git a/python/ppc_scheduler/database/__init__.py b/python/ppc_scheduler/database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/database/computing_node_mapper.py b/python/ppc_scheduler/database/computing_node_mapper.py new file mode 100644 index 00000000..0a95ad79 --- /dev/null +++ b/python/ppc_scheduler/database/computing_node_mapper.py @@ -0,0 +1,69 @@ +from sqlalchemy import update, and_, select, delete + +from ppc_common.db_models.computing_node import ComputingNodeRecord + + +def insert_computing_node(session, node_id: str, url: str, node_type: str, loading: int): + new_node = ComputingNodeRecord( + id=node_id, + url=url, + type=node_type, + loading=loading + ) + + session.add(new_node) + + +def delete_computing_node(session, url: str, node_type: str): + stmt = ( + delete(ComputingNodeRecord).where( + and_( + ComputingNodeRecord.url == url, + ComputingNodeRecord.type == node_type + ) + ) + ) + + result = session.execute(stmt) + + return result.rowcount > 0 + + +def get_and_update_min_loading_url(session, node_type): + subquery = ( + select([ComputingNodeRecord.id]).where( + and_( + ComputingNodeRecord.type == node_type + ) + ).order_by(ComputingNodeRecord.loading.asc()).limit(1) + ).scalar_subquery() + + stmt = ( + update(ComputingNodeRecord).where( + and_( + ComputingNodeRecord.id == subquery + ) + ).values( + loading=ComputingNodeRecord.loading + 1 + ).returning(ComputingNodeRecord.url) + ) + + result = session.execute(stmt) + return result.scalar() + + +def release_loading(session, url: str, node_type: str): + stmt = ( + update(ComputingNodeRecord).where( + and_( + ComputingNodeRecord.url == url, + ComputingNodeRecord.type == node_type, + ComputingNodeRecord.loading > 0 + ) + ).values( + loading=ComputingNodeRecord.loading - 1 + ) + ) + result = session.execute(stmt) + + return result.rowcount > 0 diff --git a/python/ppc_scheduler/database/job_worker_mapper.py b/python/ppc_scheduler/database/job_worker_mapper.py new file mode 100644 index 00000000..a4ee3699 --- /dev/null +++ b/python/ppc_scheduler/database/job_worker_mapper.py @@ -0,0 +1,32 @@ +from sqlalchemy import and_, update +from sqlalchemy.exc import NoResultFound + +from ppc_common.db_models.job_worker_record import JobWorkerRecord 
+from ppc_common.ppc_utils import utils +from ppc_scheduler.workflow.common import codec + + +def query_job_worker(session, job_id, worker_id): + try: + return session.query(JobWorkerRecord).filter( + and_(JobWorkerRecord.worker_id == worker_id, + JobWorkerRecord.job_id == job_id)).one() + except NoResultFound: + return None + + +def update_job_worker(session, job_id, worker_id, status, outputs): + stmt = ( + update(JobWorkerRecord).where( + and_( + JobWorkerRecord.job_id == job_id, + JobWorkerRecord.worker_id == worker_id + ) + ).values( + status=status, + outputs=codec.serialize_worker_outputs_for_db(outputs), + update_time=utils.make_timestamp() + ) + ) + result = session.execute(stmt) + return result.rowcount > 0 diff --git a/python/ppc_scheduler/endpoints/__init__.py b/python/ppc_scheduler/endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/endpoints/body_schema.py b/python/ppc_scheduler/endpoints/body_schema.py new file mode 100644 index 00000000..01ad24b8 --- /dev/null +++ b/python/ppc_scheduler/endpoints/body_schema.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +import json + +from flask_restx import fields + +from ppc_model.network.http.restx import api + +response_base = api.model('Response base info', { + 'errorCode': fields.Integer(description='return code'), + 'message': fields.String(description='return message') +}) + +response_job_status = api.inherit('Task status', response_base, { + 'data': fields.Raw(description='Task status data as key-value dictionary', example={ + 'status': 'RUNNING', + 'time_costs': '30s' + }) +}) diff --git a/python/ppc_scheduler/endpoints/job_controller.py b/python/ppc_scheduler/endpoints/job_controller.py new file mode 100644 index 00000000..883fbca8 --- /dev/null +++ b/python/ppc_scheduler/endpoints/job_controller.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +from flask import request +from flask_restx import Resource + +from ppc_common.ppc_utils import utils +from ppc_scheduler.common.global_context import components +from ppc_scheduler.endpoints.body_schema import response_job_status, response_base +from ppc_scheduler.endpoints.restx import api + +ns = api.namespace('ppc-scheduler/job', + description='Operations related to run job') + + +@ns.route('/') +class JobCollection(Resource): + + @api.response(201, 'Job started successfully.', response_base) + def post(self, job_id): + """ + Run a specific job by job_id. + """ + args = request.get_json() + components.logger().info(f"run job request, job_id: {job_id}, args: {args}") + components.job_manager.run_task(job_id, (args,)) + return utils.BASE_RESPONSE + + @api.response(200, 'Job status retrieved successfully.', response_job_status) + def get(self, job_id): + """ + Get the status of a specific job by job_id. + """ + response = utils.BASE_RESPONSE + status, time_costs = components.job_manager.status(job_id) + response['data'] = { + 'status': status, + 'time_costs': time_costs + } + return response + + @api.response(200, 'Job killed successfully.', response_base) + def delete(self, job_id): + """ + Kill a specific job by job_id. 
+ """ + components.logger().info(f"kill request, job_id: {job_id}") + components.job_manager.kill_job(job_id) + return utils.BASE_RESPONSE diff --git a/python/ppc_scheduler/endpoints/restx.py b/python/ppc_scheduler/endpoints/restx.py new file mode 100644 index 00000000..28f9829f --- /dev/null +++ b/python/ppc_scheduler/endpoints/restx.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +from flask_restx import Api + +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_model.common.global_context import components + +authorizations = { + 'apikey': { + 'type': 'apiKey', + 'in': 'header', + 'name': 'Authorization' + } +} + +api = Api(version='1.0', title='Ppc Scheduler Service', + authorizations=authorizations, security='apikey') + + +@api.errorhandler(PpcException) +def default_error_handler(e): + components.logger().exception(e) + info = e.to_dict() + response = {'errorCode': info['code'], 'message': info['message']} + components.logger().error(f"OnError: code: {info['code']}, message: {info['message']}") + return response, 500 + + +@api.errorhandler(BaseException) +def default_error_handler(e): + components.logger().exception(e) + message = 'unknown error.' + response = {'errorCode': PpcErrorCode.INTERNAL_ERROR, 'message': message} + components.logger().error(f"OnError: unknown error") + return response, 500 diff --git a/python/ppc_scheduler/job/__init__.py b/python/ppc_scheduler/job/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/job/job_manager.py b/python/ppc_scheduler/job/job_manager.py new file mode 100644 index 00000000..95afa15a --- /dev/null +++ b/python/ppc_scheduler/job/job_manager.py @@ -0,0 +1,123 @@ +import datetime +import threading +import time +from typing import Union + +from readerwriterlock import rwlock + +from ppc_common.ppc_async_executor.async_thread_executor import AsyncThreadExecutor +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_scheduler.common import log_utils +from ppc_scheduler.job.job_status import JobStatus +from ppc_scheduler.workflow.scheduler import Scheduler + + +class JobManager: + def __init__(self, logger, + thread_event_manager: ThreadEventManager, + workspace, + job_timeout_h: Union[int, float]): + self.logger = logger + self._thread_event_manager = thread_event_manager + self._workspace = workspace + self._job_timeout_s = job_timeout_h * 3600 + self._rw_lock = rwlock.RWLockWrite() + self._jobs: dict[str, list] = {} + self._async_executor = AsyncThreadExecutor( + event_manager=self._thread_event_manager, logger=logger) + self._cleanup_thread = threading.Thread(target=self._loop_cleanup) + self._cleanup_thread.daemon = True + self._cleanup_thread.start() + self.scheduler = Scheduler(self._workspace) + + def run_task(self, job_id, args=()): + """ + 发起任务 + param args: 任务参数 + """ + with self._rw_lock.gen_wlock(): + if job_id in self._jobs: + self.logger.info(f"Task already exists, job_id: {job_id}, status: {self._jobs[job_id][0]}") + return + self._jobs[job_id] = [JobStatus.RUNNING, datetime.datetime.now(), 0] + self.logger.info(log_utils.job_start_log_info(job_id)) + self._async_executor.execute(job_id, self._run_job_flow, self._on_task_finish, args) + + def _run_job_flow(self, args): + """ + 运行任务流 + + """ + self.scheduler.schedule_job_flow(args) + + def kill_job(self, job_id: str): + with self._rw_lock.gen_rlock(): + if job_id not in self._jobs or self._jobs[job_id][0] 
!= JobStatus.RUNNING: + return + + self.logger.info(f"Kill job, job_id: {job_id}") + self._async_executor.kill(job_id) + + with self._rw_lock.gen_wlock(): + self._jobs[job_id][0] = JobStatus.FAILURE + + def status(self, job_id: str) -> [str, float]: + """ + 返回: 任务状态, 执行耗时(s) + """ + with self._rw_lock.gen_rlock(): + if job_id not in self._jobs: + raise PpcException( + PpcErrorCode.JOB_NOT_FOUND.get_code(), + PpcErrorCode.JOB_NOT_FOUND.get_msg()) + status = self._jobs[job_id][0] + time_costs = self._jobs[job_id][2] + return status, time_costs + + def _on_task_finish(self, job_id: str, is_succeeded: bool, e: Exception = None): + with self._rw_lock.gen_wlock(): + time_costs = (datetime.datetime.now() - + self._jobs[job_id][1]).total_seconds() + self._jobs[job_id][2] = time_costs + if is_succeeded: + self._jobs[job_id][0] = JobStatus.SUCCESS + self.logger.info(f"Job {job_id} completed, time_costs: {time_costs}s") + else: + self._jobs[job_id][0] = JobStatus.FAILURE + self.logger.warn(f"Job {job_id} failed, time_costs: {time_costs}s, error: {e}") + self.logger.info(log_utils.job_end_log_info(job_id)) + + def _loop_cleanup(self): + while True: + self._terminate_timeout_jobs() + self._cleanup_finished_jobs() + time.sleep(5) + + def _terminate_timeout_jobs(self): + jobs_to_kill = [] + with self._rw_lock.gen_rlock(): + for job_id, value in self._jobs.items(): + alive_time = (datetime.datetime.now() - + value[1]).total_seconds() + if alive_time >= self._job_timeout_s and value[0] == JobStatus.RUNNING: + jobs_to_kill.append(job_id) + + for job_id in jobs_to_kill: + self.logger.warn(f"Job is timeout, job_id: {job_id}") + self.kill_job(job_id) + + def _cleanup_finished_jobs(self): + jobs_to_cleanup = [] + with self._rw_lock.gen_rlock(): + for job_id, value in self._jobs.items(): + alive_time = (datetime.datetime.now() - + value[1]).total_seconds() + if alive_time >= self._job_timeout_s + 3600: + jobs_to_cleanup.append((job_id, value[3])) + with self._rw_lock.gen_wlock(): + for job_id, job_id in jobs_to_cleanup: + if job_id in self._jobs: + del self._jobs[job_id] + self._thread_event_manager.remove_event(job_id) + self.logger.info(f"Cleanup job cache, job_id: {job_id}") diff --git a/python/ppc_scheduler/job/job_status.py b/python/ppc_scheduler/job/job_status.py new file mode 100644 index 00000000..a0a530ad --- /dev/null +++ b/python/ppc_scheduler/job/job_status.py @@ -0,0 +1,6 @@ +class JobStatus: + RUNNING = 'RUNNING' + SUCCESS = 'SUCCESS' + FAILURE = 'FAILURE' + TIMEOUT = 'TIMEOUT' + KILLED = 'KILLED' diff --git a/python/ppc_scheduler/job/job_type.py b/python/ppc_scheduler/job/job_type.py new file mode 100644 index 00000000..d734639f --- /dev/null +++ b/python/ppc_scheduler/job/job_type.py @@ -0,0 +1,8 @@ + +class JobType: + PSI = "PSI" + MPC = "MPC" + PREPROCESSING = 'PREPROCESSING' + FEATURE_ENGINEERING = 'FEATURE_ENGINEERING' + TRAINING = 'TRAINING' + PREDICTION = 'PREDICTION' diff --git a/python/ppc_scheduler/mpc_generator/__init__.py b/python/ppc_scheduler/mpc_generator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/mpc_generator/generator.py b/python/ppc_scheduler/mpc_generator/generator.py new file mode 100644 index 00000000..86466192 --- /dev/null +++ b/python/ppc_scheduler/mpc_generator/generator.py @@ -0,0 +1,468 @@ +from enum import Enum + +import sqlparse +import sqlvalidator +from sqlparse.exceptions import SQLParseError +from sqlparse.sql import Comparison, Identifier, Function +from sqlparse.tokens import Punctuation, Operator, Name, Token + 
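[Editor's note] A minimal lifecycle sketch (not part of this patch) showing how the JobManager above is driven, mirroring what job_controller.py does over HTTP; the synchronous polling helper and its interval are illustrative assumptions, and components.init_all() is assumed to have run:

import time

from ppc_scheduler.common.global_context import components
from ppc_scheduler.job.job_status import JobStatus


def run_and_wait(job_id, job_args, poll_interval_s=5):
    # run_task returns immediately; the job flow executes on the async thread executor
    components.job_manager.run_task(job_id, (job_args,))
    while True:
        status, time_costs = components.job_manager.status(job_id)
        if status in (JobStatus.SUCCESS, JobStatus.FAILURE):
            return status, time_costs
        time.sleep(poll_interval_s)

A killed or timed-out job also ends up in FAILURE (kill_job sets it explicitly), so the two terminal states above cover every outcome of run_task.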
+from ppc_scheduler.mpc_generator import mpc_func_str +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_common.ppc_utils.utils import PPC_RESULT_FIELDS_FLAG, PPC_RESULT_VALUES_FLAG + + +class SqlPattern(Enum): + BASIC_ARITH_OPE = 1 + AGGR_FUNC_ONLY = 2 + AGGR_FUNC_WITH_GROUP_BY = 3 + + +INDENT = " " + +SUPPORTED_KEYWORDS = [ + 'SELECT', + 'FROM', + 'WHERE', + 'JOIN', + 'INNER JOIN', + 'ON', + 'AS', + 'GROUP BY', + 'COUNT', + 'SUM', + 'AVG', + 'MAX', + 'MIN', +] + +VALUE_TYPE = 'pfix' + +GROUP_BY_COLUMN_CODE = 'group_indexes_key[i]' + +DISPLAY_FIELDS_FUNC = 'set_display_field_names' + +DISPLAY_RESULT_FUNC = 'display_data' + + +class CodeGenerator: + + def __init__(self, sql_str): + self.sql_str = sql_str + + # three patterns supported + self.sql_pattern = SqlPattern.BASIC_ARITH_OPE + + # based on ID only + self.need_run_psi = False + + # record dataset sources + self.table_set = set() + + # 0: table number, 1: column index + self.group_by_column = [] + + # filter conditions + self.condition_list = [] + + self.result_fields = [] + + self.max_column_index = [] + + def sql_to_mpc_code(self): + operators = self.pre_parsing() + format_sql_str = sqlparse.format(self.sql_str, reindent=True, keyword_case='upper') + mpc_str = self.generate_common_code(format_sql_str) + mpc_str = self.generate_function_code(mpc_str) + mpc_str = self.generate_result_calculation_code(operators, mpc_str) + mpc_str = self.generate_result_print_code(mpc_str) + mpc_str = self.generate_mpc_execution_code(mpc_str) + mpc_str = self.replace_max_filed(mpc_str) + return mpc_str + + def pre_parsing(self): + try: + if not sqlvalidator.parse(self.sql_str).is_valid(): + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "bad sql statement") + + # format sql str + format_sql_str = sqlparse.format(self.sql_str, reindent=True, keyword_case='upper') + + tokens = sqlparse.parse(format_sql_str)[0].tokens + + # warning unsupported keywords + self.check_sql_tokens(tokens) + + # parse table aliases + aliases = self.find_table_alias(tokens, {}, False) + + # recover table aliases + new_sql_str = self.recover_table_name(tokens, aliases, '') + format_new_sql_str = sqlparse.format(new_sql_str, reindent=True, keyword_case='upper') + tokens = sqlparse.parse(format_new_sql_str)[0].tokens + + # parse filters (only 'id' based column alignment is supported currently) + self.find_table_and_condition(tokens, False) + + # check table suffix and number of participants + self.check_table() + + # ensure that all tables participate in alignment + self.check_table_alignment(self.need_run_psi, len(self.table_set)) + + self.max_column_index = [0 for _ in range(len(self.table_set))] + + self.check_sql_pattern(tokens) + + operators = self.extract_operators(format_new_sql_str) + + return operators + except SQLParseError: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "bad sql statement") + + def check_sql_tokens(self, tokens): + for token in tokens: + if token.is_keyword and token.value not in SUPPORTED_KEYWORDS: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), f"keyword '{token.value}' not supported") + if hasattr(token, 'tokens'): + self.check_sql_tokens(token.tokens) + + def find_table_alias(self, tokens, aliases, after_from): + end_current_round = False + for i in range(len(tokens)): + if after_from and tokens[i].value == 'AS': + # find a alias + end_current_round = True + aliases[tokens[i + 2].value] = tokens[i - 2].value + for i in range(len(tokens)): + if tokens[i].value == 'FROM': + 
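+                # aliases are declared with 'AS' inside the FROM/JOIN clauses, so only start the recursive scan once FROM has been seen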
after_from = True + if after_from and not end_current_round and hasattr(tokens[i], 'tokens'): + self.find_table_alias(tokens[i].tokens, aliases, after_from) + return aliases + + def recover_table_name(self, tokens, aliases, format_sql_str): + for i in range(len(tokens)): + if tokens[i].value == 'AS' and tokens[i + 2].value in aliases: + break + elif not tokens[i].is_group: + if tokens[i].value in aliases: + format_sql_str += aliases[tokens[i].value] + else: + format_sql_str += tokens[i].value + elif hasattr(tokens[i], 'tokens'): + format_sql_str = self.recover_table_name(tokens[i].tokens, aliases, format_sql_str) + return format_sql_str + + def find_table_and_condition(self, tokens, after_from): + for token in tokens: + if token.value == 'FROM': + after_from = True + if after_from: + if type(token) == Comparison: + self.check_equal_comparison(token.tokens) + self.condition_list.append(token.value) + if type(token) == Identifier and '.' not in token.value: + self.table_set.add(token.value) + elif hasattr(token, 'tokens'): + self.find_table_and_condition(token.tokens, after_from) + + def check_equal_comparison(self, tokens): + for i in range(len(tokens)): + if tokens[i].value == '=': + self.need_run_psi = True + elif tokens[i].value == '.' and tokens[i + 1].value != 'id': + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), + f"only 'id' based column alignment is supported currently") + elif hasattr(tokens[i], 'tokens'): + self.check_equal_comparison(tokens[i].tokens) + + def check_table_alignment(self, has_equal_comparison, table_count): + if has_equal_comparison: + column = self.condition_list[0].split('=')[0].strip() + table = column[0:column.find('.')] + + # all tables must be aligned + self.equal_comparison_dfs(table, [0 for _ in range(len(self.condition_list))]) + if len(self.table_set) != table_count: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "all tables must be aligned") + + def equal_comparison_dfs(self, table, flag_array): + for i in range(len(self.condition_list)): + if flag_array[i] == 0 and table in self.condition_list[i]: + flag_array[i] = 1 + columns = self.condition_list[i].split('=') + for column in columns: + table = column[0:column.find('.')].strip() + self.table_set.add(table) + self.equal_comparison_dfs(table, flag_array) + + def check_table(self): + table_count = len(self.table_set) + if table_count > 5 or table_count < 2: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "source must be 2 to 5 parties") + suffixes = set() + for table in self.table_set: + suffixes.add(table[-1]) + for i in range(table_count): + if str(i) not in suffixes: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "invalid suffix of table name") + + def update_max_field(self, table_number, field_number): + if field_number > self.max_column_index[table_number]: + self.max_column_index[table_number] = field_number + + def replace_max_filed(self, mpc_str): + for i in range(len(self.max_column_index)): + mpc_str = mpc_str.replace(f'$(source{i}_column_count)', str(self.max_column_index[i] + 1)) + return mpc_str + + def check_sql_pattern(self, tokens): + for i in range(len(tokens)): + if tokens[i].value == 'GROUP BY': + self.sql_pattern = SqlPattern.AGGR_FUNC_WITH_GROUP_BY + items = tokens[i + 2].value.split('.') + self.group_by_column = [int(items[0][-1]), get_column_number(items[1])] + elif type(tokens[i]) == Function: + self.sql_pattern = SqlPattern.AGGR_FUNC_ONLY + elif hasattr(tokens[i], 'tokens'): + 
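+                # descend into grouped sub-tokens (parenthesised expressions, identifier lists) to detect nested aggregate functions and GROUP BY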
self.check_sql_pattern(tokens[i].tokens) + + def extract_operators(self, format_sql_str): + start = format_sql_str.find("SELECT") + end = format_sql_str.find("FROM") + operators = format_sql_str[start + 6:end].split(',') + for i in range(len(operators)): + if ' AS ' in operators[i]: + index = operators[i].find(' AS ') + self.result_fields.append(operators[i][index + 4:].strip().strip('\n').strip('\'').strip('\"').strip()) + operators[i] = operators[i][:index].strip() + else: + self.result_fields.append(f"result{i}") + operators[i] = operators[i].strip() + if ' ' in self.result_fields[-1]: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "result field cannot contain space") + return operators + + def generate_common_code(self, format_sql_str): + table_count = len(self.table_set) + result_column_count = len(self.result_fields) + + if self.need_run_psi: + mpc_str = '# PSI_OPTION=True\n\n' + else: + mpc_str = '' + mpc_str += '# BIT_LENGTH = 128\n\n' + mpc_str += '# This file is generated automatically by ams\n' + mpc_str += f"'''\n{format_sql_str}\n'''\n\n" + mpc_str += "from ppc import *\n\n" + # mpc_str += "program.use_trunc_pr = True\n" + # mpc_str += "program.use_split(3)\n" + mpc_str += "n_threads = 8\n" + mpc_str += f"value_type = {VALUE_TYPE}\n\n" + if VALUE_TYPE == 'pfix': + mpc_str += f"pfix.set_precision(16, 47)\n\n" + + for i in range(table_count): + mpc_str += f"SOURCE{i} = {i}\n" + mpc_str += f"source{i}_record_count = $(source{i}_record_count)\n" + mpc_str += f"source{i}_column_count = $(source{i}_column_count)\n" + mpc_str += f"source{i}_record = Matrix(source{i}_record_count, source{i}_column_count, value_type)\n\n" + + if self.sql_pattern == SqlPattern.BASIC_ARITH_OPE: + mpc_str += "# basic arithmetic operation means that all parties have same number of record\n" + mpc_str += "result_record = $(source0_record_count)\n" + mpc_str += f"results = Matrix(result_record, {result_column_count}, value_type)\n\n\n" + elif self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + mpc_str += f"results = Matrix(1, {result_column_count}, value_type)\n\n\n" + elif self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + mpc_str += "# group by means all parties have same number of record\n" + mpc_str += "source_record_count = $(source0_record_count)\n" + mpc_str += "result_record = cint(source_record_count)\n" + mpc_str += f"results = Matrix(source_record_count, {result_column_count}, value_type)\n\n\n" + + mpc_str += "def read_data_collection(data_collection, party_id):\n" + mpc_str += f"{INDENT}if data_collection.sizes[0] > 0:\n" + mpc_str += f"{INDENT}{INDENT}data_collection.input_from(party_id)\n\n\n" + + return mpc_str + + def generate_function_code(self, mpc_str): + if self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + mpc_str += mpc_func_str.FUNC_COMPUTE_SUM + mpc_str += mpc_func_str.FUNC_COMPUTE_COUNT + mpc_str += mpc_func_str.FUNC_COMPUTE_AVG + mpc_str += mpc_func_str.FUNC_COMPUTE_MAX + mpc_str += mpc_func_str.FUNC_COMPUTE_MIN + elif self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + mpc_str += mpc_func_str.GROUP_BY_GLOBAL_VARIABLE + mpc_str += mpc_func_str.FUNC_COMPUTE_GROUP_BY_INDEXES + mpc_str += mpc_func_str.FUNC_COMPUTE_SUM_WITH_GROUP_BY + mpc_str += mpc_func_str.FUNC_COMPUTE_COUNT_WITH_GROUP_BY + mpc_str += mpc_func_str.FUNC_COMPUTE_AVG_WITH_GROUP_BY + mpc_str += mpc_func_str.FUNC_COMPUTE_MAX_WITH_GROUP_BY + mpc_str += mpc_func_str.FUNC_COMPUTE_MIN_WITH_GROUP_BY + return mpc_str + + def generate_result_calculation_code(self, operators, mpc_str): + for i in 
range(len(operators)): + tokens = sqlparse.parse(operators[i])[0].tokens + participants_set = set() + formula_str = self.generate_formula(tokens, '', participants_set) + if len(participants_set) == 1: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "disabled query pattern") + + if self.sql_pattern == SqlPattern.BASIC_ARITH_OPE or self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + mpc_str += f"def calculate_result_{i}():\n" \ + f"{INDENT}@for_range_opt_multithread(n_threads, result_record)\n" \ + f"{INDENT}def _(i):\n" \ + f"{INDENT}{INDENT}results[i][{i}] = {formula_str}\n\n\n" + elif self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + mpc_str += f"def calculate_result_{i}():\n" \ + f"{INDENT}results[0][{i}] = {formula_str}\n\n\n" + return mpc_str + + def generate_result_print_code(self, mpc_str): + field_print_str = f"{PPC_RESULT_FIELDS_FLAG} = ['{self.result_fields[0]}'" + for i in range(1, len(self.result_fields)): + field_print_str += f", '{self.result_fields[i]}'" + field_print_str += ']' + + if self.sql_pattern == SqlPattern.BASIC_ARITH_OPE or self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + result_print_str = f"{PPC_RESULT_VALUES_FLAG} = [results[i][0].reveal()" + for i in range(1, len(self.result_fields)): + result_print_str += f", results[i][{i}].reveal()" + result_print_str += ']' + mpc_str += f"def print_results():\n" \ + f"{INDENT}{field_print_str}\n" \ + f"{INDENT}{DISPLAY_FIELDS_FUNC}({PPC_RESULT_FIELDS_FLAG})\n\n" \ + f"{INDENT}@for_range_opt(result_record)\n" \ + f"{INDENT}def _(i):\n" \ + f"{INDENT}{INDENT}{result_print_str}\n" \ + f"{INDENT}{INDENT}{DISPLAY_RESULT_FUNC}({PPC_RESULT_VALUES_FLAG})\n\n\n" + elif self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + result_print_str = f"{PPC_RESULT_VALUES_FLAG} = [results[0][0].reveal()" + for i in range(1, len(self.result_fields)): + result_print_str += f", results[0][{i}].reveal()" + result_print_str += ']' + mpc_str += f"def print_results():\n" \ + f"{INDENT}{field_print_str}\n" \ + f"{INDENT}{DISPLAY_FIELDS_FUNC}({PPC_RESULT_FIELDS_FLAG})\n\n" \ + f"{INDENT}{result_print_str}\n" \ + f"{INDENT}{DISPLAY_RESULT_FUNC}({PPC_RESULT_VALUES_FLAG})\n\n\n" + return mpc_str + + def generate_mpc_execution_code(self, mpc_str): + mpc_str += 'def ppc_main():\n' + for i in range(len(self.table_set)): + mpc_str += f"{INDENT}read_data_collection(source{i}_record, SOURCE{i})\n" + + if self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + mpc_str += f"\n{INDENT}compute_group_by_indexes(source{self.group_by_column[0]}_record, " \ + f"{self.group_by_column[1]})\n\n" + + for i in range(len(self.result_fields)): + mpc_str += f"{INDENT}calculate_result_{i}()\n" + + mpc_str += f"\n{INDENT}print_results()\n\n\n" + + mpc_str += "ppc_main()\n" + return mpc_str + + def generate_formula(self, tokens, formula_str, participants_set): + for token in tokens: + if token.ttype == Punctuation \ + or token.ttype == Operator \ + or token.ttype == Token.Literal.Number.Integer \ + or token.ttype == Token.Operator.Comparison: + formula_str += token.value + elif type(token) == Function: + formula_str += self.handle_function(token) + elif type(token) == Identifier and token.tokens[0].ttype == Name and len(token.tokens) >= 3: + (table_number, field_number) = self.handle_basic_identifier(token) + if self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + if table_number != self.group_by_column[0] or field_number != self.group_by_column[1]: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "bad sql statement") + 
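+                    # a bare (non-aggregated) column is only allowed when it is the GROUP BY key itself; it is emitted as group_indexes_key[i]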
self.update_max_field(table_number, field_number) + formula_str += GROUP_BY_COLUMN_CODE + elif self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "disabled query pattern") + elif self.sql_pattern == SqlPattern.BASIC_ARITH_OPE: + if token.value == token.parent.value: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), + "disabled query pattern") + self.update_max_field(table_number, field_number) + formula_str += f"source{table_number}_record[i][{field_number}]" + participants_set.add(table_number) + elif hasattr(token, 'tokens'): + formula_str = self.generate_formula(token.tokens, formula_str, participants_set) + return formula_str + + def handle_function(self, token): + tokens = token.tokens + func_name = tokens[0].value + (table_number, field_number) = self.handle_parenthesis(tokens[1]) + self.update_max_field(table_number, field_number) + return self.func_to_formula(func_name, table_number, field_number) + + def func_to_formula(self, func, table_number, field_number): + if self.sql_pattern == SqlPattern.AGGR_FUNC_ONLY: + formula = { + 'COUNT': f"{mpc_func_str.FUNC_COMPUTE_COUNT_NAME}(source{table_number}_record_count)", + 'SUM': f"{mpc_func_str.FUNC_COMPUTE_SUM_NAME}(source{table_number}_record, " + f"source{table_number}_record_count, {field_number})", + 'AVG': f"{mpc_func_str.FUNC_COMPUTE_AVG_NAME}(source{table_number}_record, " + f"source{table_number}_record_count, {field_number})", + 'MAX': f"{mpc_func_str.FUNC_COMPUTE_MAX_NAME}(source{table_number}_record, " + f"source{table_number}_record_count, {field_number})", + 'MIN': f"{mpc_func_str.FUNC_COMPUTE_MIN_NAME}(source{table_number}_record, " + f"source{table_number}_record_count, {field_number})" + } + elif self.sql_pattern == SqlPattern.AGGR_FUNC_WITH_GROUP_BY: + formula = { + 'COUNT': f"{mpc_func_str.FUNC_COMPUTE_COUNT_WITH_GROUP_BY_NAME}(i)", + 'SUM': f"{mpc_func_str.FUNC_COMPUTE_SUM_WITH_GROUP_BY_NAME}(source{table_number}_record," + f" {field_number}, i)", + 'AVG': f"{mpc_func_str.FUNC_COMPUTE_AVG_WITH_GROUP_BY_NAME}(source{table_number}_record," + f" {field_number}, i)", + 'MAX': f"{mpc_func_str.FUNC_COMPUTE_MAX_WITH_GROUP_BY_NAME}(source{table_number}_record," + f" {field_number}, i)", + 'MIN': f"{mpc_func_str.FUNC_COMPUTE_MIN_WITH_GROUP_BY_NAME}(source{table_number}_record," + f" {field_number}, i)" + } + else: + formula = {} + + return formula.get(func, '') + + def handle_parenthesis(self, token): + for token in token.tokens: + if type(token) == Identifier: + (table_number, field_number) = self.handle_basic_identifier(token) + return table_number, field_number + + def handle_basic_identifier(self, token): + tokens = token.tokens + + if tokens[0].value not in self.table_set: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "table name not matched") + if tokens[1].value != '.': + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), "invalid identifier") + + field_num = get_column_number(tokens[2].value) + + return int(tokens[0].value[-1]), field_num + + +def get_column_number(field_name): + field_len = len(field_name) + field_num = 0 + for i in range(field_len, 0, -1): + try: + int(field_name[i - 1:field_len]) + except ValueError: + if i == field_len: + raise PpcException(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), + f"invalid field suffix of table column '{field_name}'") + field_num = int(field_name[i:field_len]) + break + return field_num diff --git a/python/ppc_scheduler/mpc_generator/mpc_func_str.py 
b/python/ppc_scheduler/mpc_generator/mpc_func_str.py new file mode 100644 index 00000000..87501124 --- /dev/null +++ b/python/ppc_scheduler/mpc_generator/mpc_func_str.py @@ -0,0 +1,64 @@ +import os + +FILE_PATH = os.path.abspath(__file__) + +CURRENT_PATH = os.path.abspath(os.path.dirname(FILE_PATH) + os.path.sep + ".") + +AGGR_FUNC_SAMPLE_PATH = f"{CURRENT_PATH}{os.sep}mpc_sample{os.sep}aggr_func_only.mpc" +GROUP_BY_SAMPLE_PATH = f"{CURRENT_PATH}{os.sep}mpc_sample{os.sep}aggr_func_with_group_by.mpc" + +FUNC_COMPUTE_SUM_NAME = 'compute_sum' +FUNC_COMPUTE_COUNT_NAME = 'compute_count' +FUNC_COMPUTE_AVG_NAME = 'compute_avg' +FUNC_COMPUTE_MAX_NAME = 'compute_max' +FUNC_COMPUTE_MIN_NAME = 'compute_min' + +FUNC_COMPUTE_GROUP_BY_INDEXES_NAME = 'compute_group_by_indexes' +FUNC_COMPUTE_SUM_WITH_GROUP_BY_NAME = 'compute_sum_with_group_by' +FUNC_COMPUTE_COUNT_WITH_GROUP_BY_NAME = 'compute_count_with_group_by' +FUNC_COMPUTE_AVG_WITH_GROUP_BY_NAME = 'compute_avg_with_group_by' +FUNC_COMPUTE_MAX_WITH_GROUP_BY_NAME = 'compute_max_with_group_by' +FUNC_COMPUTE_MIN_WITH_GROUP_BY_NAME = 'compute_min_with_group_by' + +with open(AGGR_FUNC_SAMPLE_PATH, "r") as file: + AGGR_FUNC_SAMPLE_STR = file.read() + +with open(GROUP_BY_SAMPLE_PATH, "r") as file: + GROUP_BY_SAMPLE_STR = file.read() + + +def get_body_str_by_name(start_str, end_str, sql_pattern): + if sql_pattern == 1: + source_str = AGGR_FUNC_SAMPLE_STR + elif sql_pattern == 2: + source_str = GROUP_BY_SAMPLE_STR + else: + return '' + + start_index = source_str.find(start_str) + source_str = source_str[start_index:] + + end_index = source_str.find(end_str) + len(end_str) + return source_str[:end_index] + + +def get_func_str_by_name(func_name, sql_pattern): + start_str = f"def {func_name}" + end_str = "\n\n\n" + return get_body_str_by_name(start_str, end_str, sql_pattern) + + +FUNC_COMPUTE_SUM = get_func_str_by_name(FUNC_COMPUTE_SUM_NAME, 1) +FUNC_COMPUTE_COUNT = get_func_str_by_name(FUNC_COMPUTE_COUNT_NAME, 1) +FUNC_COMPUTE_AVG = get_func_str_by_name(FUNC_COMPUTE_AVG_NAME, 1) +FUNC_COMPUTE_MAX = get_func_str_by_name(FUNC_COMPUTE_MAX_NAME, 1) +FUNC_COMPUTE_MIN = get_func_str_by_name(FUNC_COMPUTE_MIN_NAME, 1) + +GROUP_BY_GLOBAL_VARIABLE = get_body_str_by_name("# matrix of indexes", "\n\n\n", 2) + +FUNC_COMPUTE_GROUP_BY_INDEXES = get_func_str_by_name(FUNC_COMPUTE_GROUP_BY_INDEXES_NAME, 2) +FUNC_COMPUTE_SUM_WITH_GROUP_BY = get_func_str_by_name(FUNC_COMPUTE_SUM_WITH_GROUP_BY_NAME, 2) +FUNC_COMPUTE_COUNT_WITH_GROUP_BY = get_func_str_by_name(FUNC_COMPUTE_COUNT_WITH_GROUP_BY_NAME, 2) +FUNC_COMPUTE_AVG_WITH_GROUP_BY = get_func_str_by_name(FUNC_COMPUTE_AVG_WITH_GROUP_BY_NAME, 2) +FUNC_COMPUTE_MAX_WITH_GROUP_BY = get_func_str_by_name(FUNC_COMPUTE_MAX_WITH_GROUP_BY_NAME, 2) +FUNC_COMPUTE_MIN_WITH_GROUP_BY = get_func_str_by_name(FUNC_COMPUTE_MIN_WITH_GROUP_BY_NAME, 2) diff --git a/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_only.mpc b/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_only.mpc new file mode 100644 index 00000000..9ddb3c34 --- /dev/null +++ b/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_only.mpc @@ -0,0 +1,136 @@ +# PSI_OPTION=True + +# BIT_LENGTH = 128 + +# This file is generated automatically by ams +''' +SELECT COUNT(s1.field3) + COUNT(s2.field3) AS r0, + SUM(s1.field3) + COUNT(s0.field0) AS 'count', + (MAX(s0.field1) + MAX(s2.field1)) / 2 AS r1, + (AVG(s1.field2) + AVG(s2.field2)) / 2 AS r2, + MIN(s1.field0) - MIN(s0.field0) AS r3 +FROM (source0 AS s0 + INNER JOIN source1 AS s1 ON s0.id = s1.id) +INNER JOIN source2 AS s2 ON 
s0.id = s2.id; +''' + +from ppc import * + +n_threads = 8 +value_type = pfix + +pfix.set_precision(16, 47) + +SOURCE0 = 0 +source0_record_count = $(source0_record_count) +source0_column_count = 2 +source0_record = Matrix(source0_record_count, source0_column_count, value_type) + +SOURCE1 = 1 +source1_record_count = $(source1_record_count) +source1_column_count = 4 +source1_record = Matrix(source1_record_count, source1_column_count, value_type) + +SOURCE2 = 2 +source2_record_count = $(source2_record_count) +source2_column_count = 4 +source2_record = Matrix(source2_record_count, source2_column_count, value_type) + +results = Matrix(1, 5, value_type) + + +def read_data_collection(data_collection, party_id): + if data_collection.sizes[0] > 0: + data_collection.input_from(party_id) + + +def compute_sum(source, record_count, col_index): + records_sum = Array(1, value_type) + records_sum[0] = source[0][col_index] + + @for_range(1, record_count) + def _(i): + records_sum[0] = records_sum[0] + source[i][col_index] + + return records_sum[0] + + +def compute_count(record_count): + return record_count + + +def compute_avg(source, record_count, col_index): + records_sum = Array(1, value_type) + records_sum[0] = source[0][col_index] + + @for_range(1, record_count) + def _(i): + records_sum[0] = records_sum[0] + source[i][col_index] + + return records_sum[0] / record_count + + +def compute_max(source, record_count, col_index): + max_record = Array(1, value_type) + max_record[0] = source[0][col_index] + + @for_range(1, record_count) + def _(i): + max_record[0] = condition(max_record[0] < source[i][col_index], source[i][col_index], max_record[0]) + + return max_record[0] + + +def compute_min(source, record_count, col_index): + min_record = Array(1, value_type) + min_record[0] = source[0][col_index] + + @for_range(1, record_count) + def _(i): + min_record[0] = condition(min_record[0] > source[i][col_index], source[i][col_index], min_record[0]) + + return min_record[0] + + +def calculate_result_0(): + results[0][0] = compute_count(source1_record_count)+compute_count(source2_record_count) + + +def calculate_result_1(): + results[0][1] = compute_sum(source1_record, source1_record_count, 3)+compute_count(source0_record_count) + + +def calculate_result_2(): + results[0][2] = (compute_max(source0_record, source0_record_count, 1)+compute_max(source2_record, source2_record_count, 1))/2 + + +def calculate_result_3(): + results[0][3] = (compute_avg(source1_record, source1_record_count, 2)+compute_avg(source2_record, source2_record_count, 2))/2 + + +def calculate_result_4(): + results[0][4] = compute_min(source1_record, source1_record_count, 0)-compute_min(source0_record, source0_record_count, 0) + + +def print_results(): + result_fields = ['r0', 'count', 'r1', 'r2', 'r3'] + set_display_field_names(result_fields) + + result_values = [results[0][0].reveal(), results[0][1].reveal(), results[0][2].reveal(), results[0][3].reveal(), results[0][4].reveal()] + display_data(result_values) + + +def ppc_main(): + read_data_collection(source0_record, SOURCE0) + read_data_collection(source1_record, SOURCE1) + read_data_collection(source2_record, SOURCE2) + calculate_result_0() + calculate_result_1() + calculate_result_2() + calculate_result_3() + calculate_result_4() + + print_results() + + +ppc_main() diff --git a/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_with_group_by.mpc b/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_with_group_by.mpc new file mode 100644 index 00000000..f1a01339 --- /dev/null +++ 
b/python/ppc_scheduler/mpc_generator/mpc_sample/aggr_func_with_group_by.mpc @@ -0,0 +1,234 @@ +# PSI_OPTION=True + +# BIT_LENGTH = 128 + +# This file is generated automatically by ams +''' +SELECT 3*s1.field4 AS r0, + COUNT(s1.field4) AS 'count', + AVG(s0.field1) * 2 + s1.field4 AS r1, + (SUM(s0.field2) + SUM(s1.field2))/(COUNT(s1.field3) + 100/(MIN(s0.field1)+MIN(s1.field1))) + 10, + MAX(s1.field1), + MIN(s2.field2) +FROM (source0 AS s0 + INNER JOIN source1 AS s1 ON s0.id = s1.id) +INNER JOIN source2 AS s2 ON s0.id = s2.id +GROUP BY s1.field4; +''' + +from ppc import * + +n_threads = 8 +value_type = pfix + +pfix.set_precision(16, 47) + +SOURCE0 = 0 +source0_record_count = $(source0_record_count) +source0_column_count = 3 +source0_record = Matrix(source0_record_count, source0_column_count, value_type) + +SOURCE1 = 1 +source1_record_count = $(source1_record_count) +source1_column_count = 5 +source1_record = Matrix(source1_record_count, source1_column_count, value_type) + +SOURCE2 = 2 +source2_record_count = $(source2_record_count) +source2_column_count = 3 +source2_record = Matrix(source2_record_count, source2_column_count, value_type) + +# group by means all parties have same number of record +source_record_count = $(source0_record_count) +result_record = cint(source_record_count) +results = Matrix(source_record_count, 6, value_type) + + +def read_data_collection(data_collection, party_id): + if data_collection.sizes[0] > 0: + data_collection.input_from(party_id) + + +# matrix of indexes after group by: +# 0 1 2 +# 0 count1 start_index end_index +# 1 count2 start_index end_index +# 2 count3 start_index end_index +# ... +# source_record_count - 1 ... +group_indexes_key = Array(source_record_count, value_type) +group_indexes_matrix = Matrix(source_record_count, 3, pint) +group_column = Array(source_record_count, value_type) + + +def compute_group_by_indexes(source, col_index): + # group_count group_index group_flag + group_states = Array(3, cint) + + @for_range_opt(source_record_count) + def _(i): + group_column[i] = source[i][col_index] + group_states[1] = 0 + group_states[2] = 0 + + @for_range(group_states[0]) + def _(j): + @if_(pint(group_indexes_key[j] == source[i][col_index]).reveal()) + def _(): + group_states[1] = j + group_states[2] = 1 + + @if_e(group_states[2] == 0) + def _(): + # new item + group_indexes_key[group_states[0]] = source[i][col_index] + group_indexes_matrix[group_states[0]][0] = 1 + group_indexes_matrix[group_states[0]][1] = i + group_indexes_matrix[group_states[0]][2] = i + group_states[0] = group_states[0] + 1 + + @else_ + def _(): + group_indexes_matrix[group_states[1]][0] = group_indexes_matrix[group_states[1]][0] + 1 + group_indexes_matrix[group_states[1]][2] = i + + global result_record + result_record = group_states[0] + + +def compute_sum_with_group_by(source, col_index, group_row_index): + records_sum = Array(1, value_type) + + start_index = group_indexes_matrix[group_row_index][1].reveal() + end_index = group_indexes_matrix[group_row_index][2].reveal() + + records_sum[0] = source[start_index][col_index] + + @for_range(start_index + 1, end_index + 1) + def _(i): + @if_(pint(group_indexes_key[group_row_index] == group_column[i]).reveal()) + def _(): + records_sum[0] = records_sum[0] + source[i][col_index] + + return records_sum[0] + + +def compute_count_with_group_by(group_row_index): + return group_indexes_matrix[group_row_index][0] + + +def compute_avg_with_group_by(source, col_index, group_row_index): + records_sum = Array(1, value_type) + + start_index = 
group_indexes_matrix[group_row_index][1].reveal() + end_index = group_indexes_matrix[group_row_index][2].reveal() + + records_sum[0] = source[start_index][col_index] + + @for_range(start_index + 1, end_index + 1) + def _(i): + @if_(pint(group_indexes_key[group_row_index] == group_column[i]).reveal()) + def _(): + records_sum[0] = records_sum[0] + source[i][col_index] + + return value_type(records_sum[0] / group_indexes_matrix[group_row_index][0]) + + +def compute_max_with_group_by(source, col_index, group_row_index): + max_records = Array(1, value_type) + + start_index = group_indexes_matrix[group_row_index][1].reveal() + end_index = group_indexes_matrix[group_row_index][2].reveal() + + max_records[0] = source[start_index][col_index] + + @for_range(start_index + 1, end_index + 1) + def _(i): + @if_(pint(group_indexes_key[group_row_index] == group_column[i]).reveal()) + def _(): + max_records[0] = condition(max_records[0] < source[i][col_index], source[i][col_index], max_records[0]) + + return max_records[0] + + +def compute_min_with_group_by(source, col_index, group_row_index): + min_records = Array(1, value_type) + + start_index = group_indexes_matrix[group_row_index][1].reveal() + end_index = group_indexes_matrix[group_row_index][2].reveal() + + min_records[0] = source[start_index][col_index] + + @for_range(start_index + 1, end_index + 1) + def _(i): + @if_(pint(group_indexes_key[group_row_index] == group_column[i]).reveal()) + def _(): + min_records[0] = condition(min_records[0] > source[i][col_index], source[i][col_index], min_records[0]) + + return min_records[0] + + +def calculate_result_0(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][0] = 3*group_indexes_key[i] + + +def calculate_result_1(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][1] = compute_count_with_group_by(i) + + +def calculate_result_2(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][2] = compute_avg_with_group_by(source0_record, 1, i)*2+group_indexes_key[i] + + +def calculate_result_3(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][3] = (compute_sum_with_group_by(source0_record, 2, i)+compute_sum_with_group_by(source1_record, 2, i))/(compute_count_with_group_by(i)+100/(compute_min_with_group_by(source0_record, 1, i)+compute_min_with_group_by(source1_record, 1, i)))+10 + + +def calculate_result_4(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][4] = compute_max_with_group_by(source1_record, 1, i) + + +def calculate_result_5(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][5] = compute_min_with_group_by(source2_record, 2, i) + + +def print_results(): + result_fields = ['r0', 'count', 'r1', 'result3', 'result4', 'result5'] + set_display_field_names(result_fields) + + @for_range_opt(result_record) + def _(i): + result_values = [results[i][0].reveal(), results[i][1].reveal(), results[i][2].reveal(), results[i][3].reveal(), results[i][4].reveal(), results[i][5].reveal()] + display_data(result_values) + + +def ppc_main(): + read_data_collection(source0_record, SOURCE0) + read_data_collection(source1_record, SOURCE1) + read_data_collection(source2_record, SOURCE2) + + compute_group_by_indexes(source1_record, 4) + + calculate_result_0() + calculate_result_1() + calculate_result_2() + calculate_result_3() + calculate_result_4() + calculate_result_5() + + print_results() + + +ppc_main() diff --git 
a/python/ppc_scheduler/mpc_generator/mpc_sample/basic_arith_ope.mpc b/python/ppc_scheduler/mpc_generator/mpc_sample/basic_arith_ope.mpc new file mode 100644 index 00000000..fa2fc97d --- /dev/null +++ b/python/ppc_scheduler/mpc_generator/mpc_sample/basic_arith_ope.mpc @@ -0,0 +1,78 @@ +# PSI_OPTION=True + +# BIT_LENGTH = 128 + +# This file is generated automatically by ams +''' +SELECT 3*(s1.field3 + s2.field3) - s0.field3 AS r0, + (s0.field1 + s2.field1) / 2 * s1.field1 AS r1 +FROM (source0 AS s0 + INNER JOIN source1 AS s1 ON s0.id = s1.id) +INNER JOIN source2 AS s2 ON s0.id = s2.id; +''' + +from ppc import * + +n_threads = 8 +value_type = pfix + +pfix.set_precision(16, 47) + +SOURCE0 = 0 +source0_record_count = $(source0_record_count) +source0_column_count = 4 +source0_record = Matrix(source0_record_count, source0_column_count, value_type) + +SOURCE1 = 1 +source1_record_count = $(source1_record_count) +source1_column_count = 4 +source1_record = Matrix(source1_record_count, source1_column_count, value_type) + +SOURCE2 = 2 +source2_record_count = $(source2_record_count) +source2_column_count = 4 +source2_record = Matrix(source2_record_count, source2_column_count, value_type) + +# basic arithmetic operation means that all parties have same number of record +result_record = $(source0_record_count) +results = Matrix(result_record, 2, value_type) + + +def read_data_collection(data_collection, party_id): + if data_collection.sizes[0] > 0: + data_collection.input_from(party_id) + + +def calculate_result_0(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][0] = 3*(source1_record[i][3]+source2_record[i][3])-source0_record[i][3] + + +def calculate_result_1(): + @for_range_opt_multithread(n_threads, result_record) + def _(i): + results[i][1] = (source0_record[i][1]+source2_record[i][1])/2*source1_record[i][1] + + +def print_results(): + result_fields = ['r0', 'r1'] + set_display_field_names(result_fields) + + @for_range_opt(result_record) + def _(i): + result_values = [results[i][0].reveal(), results[i][1].reveal()] + display_data(result_values) + + +def ppc_main(): + read_data_collection(source0_record, SOURCE0) + read_data_collection(source1_record, SOURCE1) + read_data_collection(source2_record, SOURCE2) + calculate_result_0() + calculate_result_1() + + print_results() + + +ppc_main() diff --git a/python/ppc_scheduler/mpc_generator/test_generator.py b/python/ppc_scheduler/mpc_generator/test_generator.py new file mode 100644 index 00000000..9a37a2e1 --- /dev/null +++ b/python/ppc_scheduler/mpc_generator/test_generator.py @@ -0,0 +1,101 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import unittest + +from ppc_scheduler.mpc_generator.generator import CodeGenerator +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode + +FILE_PATH = os.path.abspath(__file__) + +CURRENT_PATH = os.path.abspath(os.path.dirname(FILE_PATH) + os.path.sep + ".") + +BASIC_ARITH_OPE_PATH = f"{CURRENT_PATH}{os.sep}mpc_sample{os.sep}basic_arith_ope.mpc" +AGGR_FUNC_SAMPLE_PATH = f"{CURRENT_PATH}{os.sep}mpc_sample{os.sep}aggr_func_only.mpc" +GROUP_BY_SAMPLE_PATH = f"{CURRENT_PATH}{os.sep}mpc_sample{os.sep}aggr_func_with_group_by.mpc" + +with open(BASIC_ARITH_OPE_PATH, "r") as file: + BASIC_ARITH_OPE_STR = file.read() + +with open(AGGR_FUNC_SAMPLE_PATH, "r") as file: + AGGR_FUNC_SAMPLE_STR = file.read() + +with open(GROUP_BY_SAMPLE_PATH, "r") as file: + GROUP_BY_SAMPLE_STR = file.read() + + +class TestGenerator(unittest.TestCase): + + def test_bad_sql(self): + try: + sql 
= "select from a from b where c = d" + code_generator = CodeGenerator(sql) + code_generator.sql_to_mpc_code() + except PpcException as e: + self.assertEqual(PpcErrorCode.ALGORITHM_BAD_SQL.get_code(), e.code) + + sql = "select a0.f1 + b1.f1 from a0, b1 where a0.id=b1.id" + code_generator = CodeGenerator(sql) + self.assertIsNotNone(code_generator.sql_to_mpc_code()) + + def test_unsupported_keyword(self): + try: + sql = "select s0.f1 + s1.f1 from s0, s1 where s0.f1 > 1 and s0.f1 < 10" + code_generator = CodeGenerator(sql) + code_generator.sql_to_mpc_code() + except PpcException as e: + self.assertEqual("keyword 'AND' not supported", e.message) + + def test_disabled_query_pattern(self): + try: + sql = "select s0.f1, 3 + s1.f1 from s0, s1" + code_generator = CodeGenerator(sql) + code_generator.sql_to_mpc_code() + except PpcException as e: + self.assertEqual("disabled query pattern", e.message) + + try: + sql = "select s0.f1, s1.f1 + s1.f1 from s0, s1" + code_generator = CodeGenerator(sql) + code_generator.sql_to_mpc_code() + except PpcException as e: + self.assertEqual("disabled query pattern", e.message) + + def test_basic_pattern(self): + sql = "SELECT 3*(s1.field3 + s2.field3) - s0.field3 AS r0, \ + (s0.field1 + s2.field1) / 2 * s1.field1 AS r1\ + FROM (source0 AS s0\ + INNER JOIN source1 AS s1 ON s0.id = s1.id)\ + INNER JOIN source2 AS s2 ON s0.id = s2.id;" + code_generator = CodeGenerator(sql) + self.assertEqual(BASIC_ARITH_OPE_STR, code_generator.sql_to_mpc_code()) + + def test_single_aggre_pattern(self): + sql = "SELECT COUNT(s1.field3) + COUNT(s2.field3) AS r0,\ + SUM(s1.field3) + COUNT(s0.field0) AS 'count',\ + (MAX(s0.field1) + MAX(s2.field1)) / 2 AS r1,\ + (AVG(s1.field2) + AVG(s2.field2)) / 2 AS r2,\ + MIN(s1.field0) - MIN(s0.field0) AS r3\ + FROM (source0 AS s0\ + INNER JOIN source1 AS s1 ON s0.id = s1.id)\ + INNER JOIN source2 AS s2 ON s0.id = s2.id;" + code_generator = CodeGenerator(sql) + self.assertEqual(AGGR_FUNC_SAMPLE_STR, code_generator.sql_to_mpc_code()) + + def test_group_by_pattern(self): + sql = "SELECT 3*s1.field4 AS r0,\ + COUNT(s1.field4) AS 'count', \ + AVG(s0.field1) * 2 + s1.field4 AS r1,\ + (SUM(s0.field2) + SUM(s1.field2))/(COUNT(s1.field3) + 100/(MIN(s0.field1)+MIN(s1.field1))) + 10,\ + MAX(s1.field1),\ + MIN(s2.field2)\ + FROM (source0 AS s0\ + INNER JOIN source1 AS s1 ON s0.id = s1.id)\ + INNER JOIN source2 AS s2 ON s0.id = s2.id\ + GROUP BY s1.field4;" + code_generator = CodeGenerator(sql) + self.assertEqual(GROUP_BY_SAMPLE_STR, code_generator.sql_to_mpc_code()) + + +if __name__ == '__main__': + unittest.main(verbosity=1) diff --git a/python/ppc_scheduler/node/__init__.py b/python/ppc_scheduler/node/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/node/computing_node_client/__init__.py b/python/ppc_scheduler/node/computing_node_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/node/computing_node_client/model_node_client.py b/python/ppc_scheduler/node/computing_node_client/model_node_client.py new file mode 100644 index 00000000..03456c9b --- /dev/null +++ b/python/ppc_scheduler/node/computing_node_client/model_node_client.py @@ -0,0 +1,87 @@ +import time + +from ppc_common.ppc_utils import http_utils +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode + +RUN_MODEL_API_PREFIX = "/api/ppc-model/pml/run-model-task/" +GET_MODEL_LOG_API_PREFIX = "/api/ppc-model/pml/record-model-log/" + + +class ModelClient: + def __init__(self, log, endpoint, 
polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5): + self.log = log + self.endpoint = endpoint + self.polling_interval_s = polling_interval_s + self.max_retries = max_retries + self.retry_delay_s = retry_delay_s + self._completed_status = 'COMPLETED' + self._failed_status = 'FAILED' + + def run(self, args): + task_id = args['task_id'] + try: + self.log.info(f"ModelApi: begin to run model task {task_id}") + response = self._send_request_with_retry(http_utils.send_post_request, + endpoint=self.endpoint, + uri=RUN_MODEL_API_PREFIX + task_id, + params=args) + check_response(response) + return self._poll_task_status(task_id) + except Exception as e: + self.log.error(f"ModelApi: run model task error, task: {task_id}, error: {e}") + raise e + + def kill(self, job_id): + try: + self.log.info(f"ModelApi: begin to kill model task {job_id}") + response = self._send_request_with_retry(http_utils.send_delete_request, + endpoint=self.endpoint, + uri=RUN_MODEL_API_PREFIX + job_id) + check_response(response) + self.log.info(f"ModelApi: model task {job_id} was killed") + return response + except Exception as e: + self.log.warn(f"ModelApi: kill model task {job_id} failed, error: {e}") + raise e + + def _poll_task_status(self, task_id): + while True: + response = self._send_request_with_retry(http_utils.send_get_request, + endpoint=self.endpoint, + uri=RUN_MODEL_API_PREFIX + task_id) + check_response(response) + if response['data']['status'] == self._completed_status: + self.log.info(f"task {task_id} completed, response: {response['data']}") + return response + elif response['data']['status'] == self._failed_status: + self.log.warn(f"task {task_id} failed, response: {response['data']}") + raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data']) + else: + time.sleep(self.polling_interval_s) + + def get_remote_log(self, job_id): + response = self._send_request_with_retry(http_utils.send_get_request, + endpoint=self.endpoint, + uri=GET_MODEL_LOG_API_PREFIX + job_id) + check_response(response) + return response['data'] + + def _send_request_with_retry(self, request_func, *args, **kwargs): + attempt = 0 + while attempt < self.max_retries: + try: + response = request_func(*args, **kwargs) + return response + except Exception as e: + self.log.warn(f"Request failed: {e}, attempt {attempt + 1}/{self.max_retries}") + attempt += 1 + if attempt < self.max_retries: + time.sleep(self.retry_delay_s) + else: + self.log.warn(f"Request failed after {self.max_retries} attempts") + raise e + + +def check_response(response): + if response['errorCode'] != 0: + raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['message']) diff --git a/python/ppc_scheduler/node/computing_node_client/mpc_node_client.py b/python/ppc_scheduler/node/computing_node_client/mpc_node_client.py new file mode 100644 index 00000000..f64ece20 --- /dev/null +++ b/python/ppc_scheduler/node/computing_node_client/mpc_node_client.py @@ -0,0 +1,32 @@ +import random + +from ppc_common.ppc_utils import http_utils, utils +from ppc_scheduler.node.computing_node_client.utils import check_privacy_service_response + + +class MpcClient: + def __init__(self, endpoint): + self.endpoint = endpoint + + def run(self, job_info, token): + params = { + 'jsonrpc': '2', + 'method': 'run', + 'token': token, + 'id': random.randint(1, 65535), + 'params': job_info + } + response = http_utils.send_post_request(self.endpoint, None, params) + check_privacy_service_response(response) + return response['result'] + + def kill(self, 
job_id, token): + params = { + 'jsonrpc': '2', + 'method': 'kill', + 'token': token, + 'id': random.randint(1, 65535), + 'params': {'jobId': job_id} + } + http_utils.send_post_request(self.endpoint, None, params) + return utils.make_response(0, "success", None) diff --git a/python/ppc_scheduler/node/computing_node_client/psi_node_client.py b/python/ppc_scheduler/node/computing_node_client/psi_node_client.py new file mode 100644 index 00000000..a4506a4c --- /dev/null +++ b/python/ppc_scheduler/node/computing_node_client/psi_node_client.py @@ -0,0 +1,66 @@ +import random +import time + +from ppc_common.ppc_utils import http_utils +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_scheduler.node.computing_node_client.utils import check_privacy_service_response + + +class PsiClient: + def __init__(self, log, endpoint, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5): + self.log = log + self.endpoint = endpoint + self.polling_interval_s = polling_interval_s + self.max_retries = max_retries + self.retry_delay_s = retry_delay_s + self._async_run_task_method = 'asyncRunTask' + self._get_task_status_method = 'getTaskStatus' + self._completed_status = 'COMPLETED' + self._failed_status = 'FAILED' + + def run(self, job_info, token): + params = { + 'jsonrpc': '1', + 'method': self._async_run_task_method, + 'token': token, + 'id': random.randint(1, 65535), + 'params': job_info + } + response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params) + check_privacy_service_response(response) + return self._poll_task_status(job_info['taskID'], token) + + def _poll_task_status(self, task_id, token): + while True: + params = { + 'jsonrpc': '1', + 'method': self._get_task_status_method, + 'token': token, + 'id': random.randint(1, 65535), + 'params': { + 'taskID': task_id, + } + } + response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params) + check_privacy_service_response(response) + if response['result']['status'] == self._completed_status: + return response['result'] + elif response['result']['status'] == self._failed_status: + self.log.warn(f"task {task_id} failed, response: {response['data']}") + raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data']) + time.sleep(self.polling_interval_s) + + def _send_request_with_retry(self, request_func, *args, **kwargs): + attempt = 0 + while attempt < self.max_retries: + try: + response = request_func(*args, **kwargs) + return response + except Exception as e: + self.log.warn(f"Request failed: {e}, attempt {attempt + 1}/{self.max_retries}") + attempt += 1 + if attempt < self.max_retries: + time.sleep(self.retry_delay_s) + else: + self.log.warn(f"Request failed after {self.max_retries} attempts") + raise e diff --git a/python/ppc_scheduler/node/computing_node_client/utils.py b/python/ppc_scheduler/node/computing_node_client/utils.py new file mode 100644 index 00000000..3d3527c8 --- /dev/null +++ b/python/ppc_scheduler/node/computing_node_client/utils.py @@ -0,0 +1,8 @@ +from ppc_common.ppc_utils.exception import PpcErrorCode, PpcException + + +def check_privacy_service_response(response): + if 'result' not in response.keys(): + raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), "http request error") + elif 0 != response['result']['code'] or response['result']['status'] == 'FAILED': + raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['result']['message']) diff --git 
a/python/ppc_scheduler/node/node_manager.py b/python/ppc_scheduler/node/node_manager.py new file mode 100644 index 00000000..fa2b2247 --- /dev/null +++ b/python/ppc_scheduler/node/node_manager.py @@ -0,0 +1,32 @@ +from ppc_scheduler.database import computing_node_mapper +from ppc_scheduler.workflow.common.worker_type import WorkerType + + +class ComputingNodeManager: + type_map = { + WorkerType.T_PSI: 'PSI', + WorkerType.T_MPC: 'MPC', + WorkerType.T_PREPROCESSING: 'MODEL', + WorkerType.T_FEATURE_ENGINEERING: 'MODEL', + WorkerType.T_TRAINING: 'MODEL', + WorkerType.T_PREDICTION: 'MODEL', + } + + def __init__(self, components): + self.components = components + + def add_node(self, node_id: str, url: str, worker_type: str): + with self.components.create_sql_session() as session: + computing_node_mapper.insert_computing_node(session, node_id, url, self.type_map[worker_type], 0) + + def remove_node(self, url: str, worker_type: str): + with self.components.create_sql_session() as session: + computing_node_mapper.delete_computing_node(session, url, self.type_map[worker_type]) + + def get_node(self, worker_type: str): + with self.components.create_sql_session() as session: + return computing_node_mapper.get_and_update_min_loading_url(session, self.type_map[worker_type]) + + def release_node(self, url: str, worker_type: str): + with self.components.create_sql_session() as session: + return computing_node_mapper.release_loading(session, url, self.type_map[worker_type]) diff --git a/python/ppc_scheduler/ppc_scheduler_app.py b/python/ppc_scheduler/ppc_scheduler_app.py new file mode 100644 index 00000000..fb44f888 --- /dev/null +++ b/python/ppc_scheduler/ppc_scheduler_app.py @@ -0,0 +1,48 @@ +# Note: here can't be refactored by autopep + +from ppc_scheduler.endpoints.restx import api +from ppc_scheduler.endpoints.job_controller import ns as job_namespace +from ppc_scheduler.common.global_context import components +from paste.translogger import TransLogger +from flask import Flask, Blueprint +from cheroot.wsgi import Server as WSGIServer +from cheroot.ssl.builtin import BuiltinSSLAdapter +import os +import sys +sys.path.append("../") + + +app = Flask(__name__) + + +def initialize_app(app): + # 初始化应用功能组件 + components.init_all() + + app.config.update(components.config_data) + blueprint = Blueprint('api', __name__, url_prefix='/api') + api.init_app(blueprint) + api.add_namespace(job_namespace) + app.register_blueprint(blueprint) + + +if __name__ == '__main__': + initialize_app(app) + + app.config['SECRET_KEY'] = os.urandom(24) + server = WSGIServer((app.config['HOST'], app.config['HTTP_PORT']), + TransLogger(app, setup_console_handler=False), numthreads=2) + + ssl_switch = app.config['SSL_SWITCH'] + protocol = 'http' + if ssl_switch == 1: + protocol = 'https' + server.ssl_adapter = BuiltinSSLAdapter( + certificate=app.config['SSL_CRT'], + private_key=app.config['SSL_KEY'], + certificate_chain=app.config['CA_CRT']) + + message = f"Starting ppc scheduler server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']}" + print(message) + components.logger().info(message) + server.start() diff --git a/python/ppc_scheduler/workflow/__init__.py b/python/ppc_scheduler/workflow/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/workflow/common/__init__.py b/python/ppc_scheduler/workflow/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/workflow/common/codec.py b/python/ppc_scheduler/workflow/common/codec.py new file mode 
100644 index 00000000..3fde280a --- /dev/null +++ b/python/ppc_scheduler/workflow/common/codec.py @@ -0,0 +1,57 @@ +from ppc_common.ppc_protos.generated.ppc_pb2 import JobWorkerOutputs, JobWorkerUpstreams, \ + InputStatement, JobWorkerInputsStatement +from ppc_common.ppc_utils import utils + + +def deserialize_worker_outputs(outputs_str): + outputs = [] + outputs_pb = JobWorkerOutputs() + utils.str_to_pb(outputs_pb, outputs_str) + for output in outputs_pb.outputs: + outputs.append(output) + return outputs + + +def serialize_worker_outputs_for_db(outputs): + outputs_pb = JobWorkerOutputs() + for output in outputs: + outputs_pb.outputs.append(output) + return utils.pb_to_str(outputs_pb) + + +def deserialize_upstreams(upstreams_str): + upstreams = [] + upstream_pb = JobWorkerUpstreams() + utils.str_to_pb(upstream_pb, upstreams_str) + for upstream in upstream_pb.upstreams: + upstreams.append(upstream) + return upstreams + + +def serialize_upstreams_for_db(upstreams): + upstreams_pb = JobWorkerUpstreams() + for upstream in upstreams: + upstreams_pb.upstreams.append(upstream) + return utils.pb_to_str(upstreams_pb) + + +def deserialize_inputs_statement(inputs_statement_str): + inputs_statement = [] + inputs_statement_pb = JobWorkerInputsStatement() + utils.str_to_pb(inputs_statement_pb, inputs_statement_str) + for input_statement_pb in inputs_statement_pb.inputs_statement: + inputs_statement.append({ + 'upstream': input_statement_pb.upstream, + 'output_index': input_statement_pb.output_index + }) + return inputs_statement + + +def serialize_inputs_statement_for_db(inputs_statement): + inputs_statement_pb = JobWorkerInputsStatement() + for input_statement in inputs_statement: + input_statement_pb = InputStatement() + input_statement_pb.upstream = input_statement['upstream'] + input_statement_pb.output_index = input_statement['output_index'] + inputs_statement_pb.inputs_statement.append(input_statement_pb) + return utils.pb_to_str(inputs_statement_pb) diff --git a/python/ppc_scheduler/workflow/common/default_flow_config.py b/python/ppc_scheduler/workflow/common/default_flow_config.py new file mode 100644 index 00000000..6797bbc4 --- /dev/null +++ b/python/ppc_scheduler/workflow/common/default_flow_config.py @@ -0,0 +1,225 @@ +flow_dict = { + "PSI": [ + { + "index": 1, + "type": "T_PSI", + "isParamsProvided": False, + "params": { + "type": 0, + "algorithm": 0, + "syncResult": True, + "parties": [ + { + "index": "", + "partyIndex": 1 + }, + { + "index": "", + "partyIndex": 0, + "data": { + "index": "", + "input": { + "type": 2, + "path": "" + }, + "output": { + "type": 2, + "path": "" + } + } + } + ] + } + } + ], + + "MPC": [ + { + "index": 1, + "type": "T_MPC" + } + ], + + "PSI_MPC": [ + { + "index": 1, + "type": "T_PSI" + }, + { + "index": 2, + "type": "T_MPC", + "upstreams": [ + { + "index": 1, + "output_input_map": [ + "0:0" + ] + } + ] + } + ], + + "PREPROCESSING": [ + { + "index": 1, + "type": "T_PREPROCESSING" + } + ], + + "FEATURE_ENGINEERING": [ + { + "index": 1, + "type": "T_PREPROCESSING" + }, + { + "index": 2, + "type": "T_FEATURE_ENGINEERING", + "upstreams": [ + { + "index": 1 + } + ] + } + ], + + "TRAINING": [ + { + "index": 1, + "type": "T_PREPROCESSING" + }, + { + "index": 2, + "type": "T_TRAINING", + "upstreams": [ + { + "index": 1 + } + ] + } + ], + + "PREDICTION": [ + { + "index": 1, + "type": "T_PREPROCESSING" + }, + { + "index": 2, + "type": "T_PREDICTION", + "upstreams": [ + { + "index": 1 + } + ] + } + ], + + "FEATURE_ENGINEERING_TRAINING": [ + { + "index": 1, + "type": 
"T_PREPROCESSING" + }, + { + "index": 2, + "type": "T_FEATURE_ENGINEERING", + "upstreams": [ + { + "index": 1 + } + ] + }, + { + "index": 3, + "type": "T_TRAINING", + "upstreams": [ + { + "index": 2 + } + ] + } + ], + + "PSI_FEATURE_ENGINEERING": [ + { + "index": 1, + "type": "T_PSI" + }, + { + "index": 2, + "type": "T_PREPROCESSING", + "upstreams": [ + { + "index": 1 + } + ] + }, + { + "index": 3, + "type": "T_FEATURE_ENGINEERING", + "upstreams": [ + { + "index": 2 + } + ] + } + ], + + "PSI_TRAINING": [ + { + "index": 1, + "type": "T_PSI" + }, + { + "index": 2, + "type": "T_PREPROCESSING", + "upstreams": [ + { + "index": 1 + } + ] + }, + { + "index": 3, + "type": "T_TRAINING", + "upstreams": [ + { + "index": 2 + } + ] + } + ], + + "PSI_FEATURE_ENGINEERING_TRAINING": [ + { + "index": 1, + "type": "T_PSI" + }, + { + "index": 2, + "type": "T_PREPROCESSING", + "upstreams": [ + { + "index": 1 + } + ] + }, + { + "index": 3, + "type": "T_FEATURE_ENGINEERING", + "upstreams": [ + { + "index": 2 + } + ] + }, + { + "index": 4, + "type": "T_TRAINING", + "upstreams": [ + { + "index": 3 + } + ] + } + ] +} diff --git a/python/ppc_scheduler/workflow/common/flow_utils.py b/python/ppc_scheduler/workflow/common/flow_utils.py new file mode 100644 index 00000000..3d05e11e --- /dev/null +++ b/python/ppc_scheduler/workflow/common/flow_utils.py @@ -0,0 +1,34 @@ +from ppc_scheduler.workflow.common.worker_type import WorkerType + + +def cat_worker_id(job_id, index, worker_type): + return f"{job_id}_{index}_{worker_type}" + + +def success_id(job_id): + return cat_worker_id(job_id, 0, WorkerType.T_ON_SUCCESS) + + +def failure_id(job_id): + return cat_worker_id(job_id, 0, WorkerType.T_ON_FAILURE) + + +def to_origin_inputs(worker_inputs): + inputs = [] + for each in worker_inputs: + output_index = each['output_index'] + upstream_outputs = each['upstream_outputs'] + inputs.append(upstream_outputs[output_index]) + return inputs + + +def to_worker_inputs(job_workers, inputs_statement): + worker_inputs = [] + for each in inputs_statement: + output_index = each['output_index'] + upstream_unit = each['upstream_unit'] + worker_inputs.append({ + 'output_index': output_index, + 'upstream_outputs': job_workers[upstream_unit] + }) + return worker_inputs diff --git a/python/ppc_scheduler/workflow/common/job_context.py b/python/ppc_scheduler/workflow/common/job_context.py new file mode 100644 index 00000000..0a6c6feb --- /dev/null +++ b/python/ppc_scheduler/workflow/common/job_context.py @@ -0,0 +1,115 @@ +import os + +from ppc_scheduler.job.job_type import JobType +from ppc_scheduler.mpc_generator.generator import CodeGenerator +from ppc_scheduler.workflow.common.default_flow_config import flow_dict +from ppc_scheduler.common.global_context import components + + +class JobContext: + PSI_PREPARE_FILE = "psi_inputs" + PSI_RESULT_INDEX_FILE = "psi_result_index" + PSI_RESULT_FILE = "psi_result" + MPC_PREPARE_FILE = "mpc_prepare.csv" + MPC_RESULT_FILE = "mpc_result.csv" + MPC_OUTPUT_FILE = "mpc_output.txt" + HDFS_STORAGE_PATH = "/user/ppc/" + + def __init__(self, args, workspace): + self.args = args + #todo: 确保java服务给过来的任务信息包含如下字段,如果是建模相关的任务,还需要job_context.model_config_dict = args['model_config'] + # 如果是mpc任务,还需要有args['sql'],或者args['mpc_content'] + self.job_id: str = args['job_id'] + self.user_name: str = args['user_name'] + self.dataset_id: str = args['dataset_id'] + + self.psi_fields: str = args['psi_fields'] + + self.result_receiver_list: list = args['result_receiver_list'] + self.participant_id_list: list = 
args['participant_id_list'] + self.job_type = args['job_type'] + self.dataset_record_count = args['dataset_record_count'] + + self.my_index = None + self.need_run_psi = False + self.need_run_fe = False + self.mpc_content = None + + self.model_config_dict: dict = {} + self.tag_provider_agency_id = None + self.job_subtype = None + self.predict_algorithm = None + + self.worker_configs: list = [] + self.workflow_view_path = 'workflow_view' + + self.workspace = workspace + self.job_cache_dir = "{}{}{}".format(self.workspace, os.sep, self.job_id) + self.dataset_file_path = "{}{}{}".format(self.job_cache_dir, os.sep, self.dataset_id) + self.psi_prepare_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.PSI_PREPARE_FILE) + self.psi_result_index_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.PSI_RESULT_INDEX_FILE) + self.psi_result_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.PSI_RESULT_FILE) + self.mpc_file_name = "{}.mpc".format(self.job_id) + self.mpc_model_module_name = "{}.json".format(self.job_id) + self.mpc_file_path = "{}{}{}".format(self.job_cache_dir, os.sep, self.mpc_file_name) + self.mpc_prepare_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.MPC_PREPARE_FILE) + self.mpc_result_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.MPC_RESULT_FILE) + self.mpc_output_path = "{}{}{}".format(self.job_cache_dir, os.sep, JobContext.MPC_OUTPUT_FILE) + + @staticmethod + def load_from_args(args, workspace): + job_context = JobContext(args, workspace) + job_context.my_index = job_context.participant_id_list.index(components.config_data['AGENCY_ID']) + + if job_context.job_type == JobType.PREPROCESSING or \ + job_context.job_type == JobType.TRAINING or \ + job_context.job_type == JobType.PREDICTION or \ + job_context.job_type == JobType.FEATURE_ENGINEERING: + + job_context.model_config_dict = args['model_config'] + job_context.tag_provider_agency_id = job_context.participant_id_list[0] + if 'job_subtype' in job_context.model_config_dict: + job_context.job_subtype = job_context.model_config_dict['job_subtype'] + if 'predict_algorithm' in job_context.model_config_dict: + job_context.predict_algorithm = job_context.model_config_dict['predict_algorithm'] + if 'use_psi' in job_context.model_config_dict: + job_context.need_run_psi = job_context.model_config_dict['use_psi'] == 1 + if 'use_iv' in job_context.model_config_dict: + job_context.need_run_fe = job_context.model_config_dict['use_iv'] == 1 + + if job_context.job_type == JobType.PSI: + job_context.worker_configs = flow_dict['PSI'] + elif job_context.job_type == JobType.MPC: + if 'sql' in job_context.args: + job_context.sql = args['sql'] + job_context.mpc_content = CodeGenerator(job_context.sql) + else: + job_context.mpc_content = job_context.args['mpc_content'] + if "PSI_OPTION=True" in job_context.mpc_content: + job_context.need_run_psi = True + job_context.worker_configs = flow_dict['PSI_MPC'] + else: + job_context.worker_configs = flow_dict['MPC'] + elif job_context.job_type == JobType.PREPROCESSING: + job_context.worker_configs = flow_dict['PREPROCESSING'] + elif job_context.job_type == JobType.FEATURE_ENGINEERING: + if job_context.need_run_psi: + job_context.worker_configs = flow_dict['PSI_FEATURE_ENGINEERING'] + else: + job_context.worker_configs = flow_dict['FEATURE_ENGINEERING'] + elif job_context.job_type == JobType.TRAINING: + if job_context.need_run_psi: + if job_context.need_run_fe: + job_context.worker_configs = flow_dict['PSI_FEATURE_ENGINEERING_TRAINING'] + else: + 
job_context.worker_configs = flow_dict['PSI_TRAINING'] + else: + if job_context.need_run_fe: + job_context.worker_configs = flow_dict['FEATURE_ENGINEERING_TRAINING'] + else: + job_context.worker_configs = flow_dict['TRAINING'] + elif job_context.job_type == JobType.PREDICTION: + job_context.worker_configs = flow_dict['PREDICTION'] + else: + raise Exception("Unsupported job type {}".format(job_context.job_type)) + return job_context diff --git a/python/ppc_scheduler/workflow/common/worker_status.py b/python/ppc_scheduler/workflow/common/worker_status.py new file mode 100644 index 00000000..f915f473 --- /dev/null +++ b/python/ppc_scheduler/workflow/common/worker_status.py @@ -0,0 +1,7 @@ +class WorkerStatus: + PENDING = 'PENDING' + RUNNING = 'RUNNING' + SUCCESS = 'SUCCESS' + FAILURE = 'FAILURE' + TIMEOUT = 'TIMEOUT' + KILLED = 'KILLED' diff --git a/python/ppc_scheduler/workflow/common/worker_type.py b/python/ppc_scheduler/workflow/common/worker_type.py new file mode 100644 index 00000000..d10f1f8f --- /dev/null +++ b/python/ppc_scheduler/workflow/common/worker_type.py @@ -0,0 +1,17 @@ +class WorkerType: + # generic job worker + T_API = 'T_API' + T_PYTHON = 'T_PYTHON' + T_SHELL = 'T_SHELL' + + # specific job worker + T_PSI = 'T_PSI' + T_MPC = 'T_MPC' + T_PREPROCESSING = 'T_PREPROCESSING' + T_FEATURE_ENGINEERING = 'T_FEATURE_ENGINEERING' + T_TRAINING = 'T_TRAINING' + T_PREDICTION = 'T_PREDICTION' + + # finish job + T_ON_SUCCESS = 'T_ON_SUCCESS' + T_ON_FAILURE = 'T_ON_FAILURE' diff --git a/python/ppc_scheduler/workflow/constructor.py b/python/ppc_scheduler/workflow/constructor.py new file mode 100644 index 00000000..ca6cf22b --- /dev/null +++ b/python/ppc_scheduler/workflow/constructor.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +from ppc_common.db_models.job_worker_record import JobWorkerRecord +from ppc_common.ppc_utils import utils +from ppc_scheduler.common.global_context import components +from ppc_scheduler.database import job_worker_mapper +from ppc_scheduler.workflow.common import codec, flow_utils +from ppc_scheduler.workflow.common.job_context import JobContext +from ppc_scheduler.workflow.common.worker_status import WorkerStatus + + +class Constructor: + def __init__(self): + self.log = components.logger() + + def build_flow_context(self, job_context: JobContext): + self.log.info(f"start build_flow_context, job_id: {job_context.job_id}") + job_id = job_context.job_id + flow_context = {} + index_type_map = {} + for worker_config in job_context.worker_configs: + index_type_map[worker_config['index']] = worker_config['type'] + + for worker_config in job_context.worker_configs: + worker_type = worker_config['type'] + worker_id = flow_utils.cat_worker_id(job_id, worker_config['index'], worker_type) + upstreams = [] + inputs_statement = [] + inputs_statement_tuple = [] + if 'upstream' in worker_config: + for upstream_config in worker_config["upstreams"]: + index = upstream_config['index'] + upstream_id = flow_utils.cat_worker_id(job_id, index, index_type_map[index]) + upstreams.append(upstream_id), + if 'output_input_map' in upstream_config: + for mapping in upstream_config.get("output_input_map", []): + output_index, input_index = mapping.split(":") + inputs_statement_tuple.append((upstream_id, int(output_index), int(input_index))) + + inputs_statement_tuple.sort(key=lambda x: x[2]) + for upstream_id, output_index, _ in inputs_statement_tuple: + inputs_statement.append( + { + 'output_index': output_index, + 'upstream': upstream_id + } + ) + worker_context = 
self._construct_context(job_context, worker_id, worker_type, + upstreams, inputs_statement) + flow_context[worker_id] = worker_context + self.log.info(f"end build_flow_context, flow_context:\n{flow_context}") + return flow_context + + def _construct_context(self, job_context, worker_id, worker_type, upstreams, inputs_statement): + context = { + 'type': worker_type, + 'status': WorkerStatus.PENDING, + 'upstreams': upstreams, + 'inputs_statement': inputs_statement + } + + with components.create_sql_session() as session: + worker_record = job_worker_mapper.query_job_worker(session, + job_context.job_id, worker_id) + if worker_record is None: + worker_record = JobWorkerRecord( + worker_id=worker_id, + job_id=job_context.job_id, + type=worker_type, + status=WorkerStatus.PENDING, + upstreams=codec.serialize_upstreams_for_db(upstreams), + inputs_statement=codec.serialize_inputs_statement_for_db(inputs_statement), + create_time=utils.make_timestamp(), + update_time=utils.make_timestamp() + ) + session.add(worker_record) + session.commit() + else: + context['status'] = worker_record.status + context['upstreams'] = codec.deserialize_upstreams(worker_record.upstreams) + context['inputs_statement'] = codec.deserialize_inputs_statement(worker_record.inputs_statement) + + self.log.debug(f"Load worker_context successfully, worker_id: {worker_id}, context:\n{context}") + return context diff --git a/python/ppc_scheduler/workflow/scheduler.py b/python/ppc_scheduler/workflow/scheduler.py new file mode 100644 index 00000000..ed2e1fc4 --- /dev/null +++ b/python/ppc_scheduler/workflow/scheduler.py @@ -0,0 +1,87 @@ +from prefect import Flow +from prefect.executors import LocalDaskExecutor +from prefect.triggers import all_successful, any_failed + +from ppc_scheduler.workflow.common import flow_utils +from ppc_scheduler.workflow.common.job_context import JobContext +from ppc_scheduler.workflow.common.worker_status import WorkerStatus +from ppc_scheduler.workflow.common.worker_type import WorkerType +from ppc_scheduler.workflow.constructor import Constructor +from ppc_scheduler.workflow.worker.worker_factory import WorkerFactory + + +class Scheduler: + def __init__(self, workspace): + self.workspace = workspace + self.constructor = Constructor() + + def schedule_job_flow(self, args): + job_context = JobContext.load_from_args(args, self.workspace) + flow_context = self.constructor.build_flow_context(job_context) + self._run(job_context, flow_context) + + @staticmethod + def _run(job_context, flow_context): + job_workers = {} + job_id = job_context.job_id + job_flow = Flow(f"job_flow_{job_id}") + + # create a final job worker to handle success + finish_job_on_success = WorkerFactory.build_worker( + job_context, + flow_utils.success_id(job_id), + WorkerType.T_ON_SUCCESS) + + finish_job_on_success.trigger = all_successful + finish_job_on_success.bind(worker_status=WorkerStatus.PENDING, worker_inputs=[], flow=job_flow) + job_flow.add_task(finish_job_on_success) + + # set reference task to bind job flow status + job_flow.set_reference_tasks([finish_job_on_success]) + + # create a final job worker to handle failure + finish_job_on_failure = WorkerFactory.build_worker( + job_context, + flow_utils.failure_id(job_id), + WorkerType.T_ON_FAILURE) + + # run finish_job_on_failure when any job worker failed + finish_job_on_failure.trigger = any_failed + finish_job_on_failure.bind(worker_status=WorkerStatus.PENDING, worker_inputs=[], flow=job_flow) + job_flow.add_task(finish_job_on_failure) + + # create main job workers + for
worker_id in flow_context: + worker_type = flow_context[worker_id]['type'] + job_worker = WorkerFactory.build_worker(job_context, worker_id, worker_type) + job_flow.add_task(job_worker) + job_workers[worker_id] = job_worker + + # set upstream for final job + finish_job_on_success.set_upstream(job_worker, flow=job_flow) + finish_job_on_failure.set_upstream(job_worker, flow=job_flow) + + # customize main job workers + for worker_id in flow_context: + # set upstream + upstreams = flow_context[worker_id]['upstreams'] + for upstream in upstreams: + if upstream not in job_workers: + raise Exception(-1, f"upstream job worker not found: {upstream}, " + f"job_id: {job_context.job_id}") + job_workers[worker_id].set_upstream(job_workers[upstream], flow=job_flow) + + # bind worker inputs + inputs_statement = flow_context[worker_id]['inputs_statement'] + worker_inputs = flow_utils.to_worker_inputs(job_workers, inputs_statement) + job_workers[worker_id].bind(worker_status=flow_context[worker_id]['status'], + worker_inputs=worker_inputs, flow=job_flow) + + # enable parallel execution + job_flow.executor = LocalDaskExecutor() + + # run dag workflow + job_flow_state = job_flow.run() + + # save workflow view as file + job_flow.visualize(job_flow_state, job_context.workflow_view_path, 'svg') diff --git a/python/ppc_scheduler/workflow/worker/__init__.py b/python/ppc_scheduler/workflow/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/workflow/worker/api_worker.py b/python/ppc_scheduler/workflow/worker/api_worker.py new file mode 100644 index 00000000..2698133b --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/api_worker.py @@ -0,0 +1,10 @@ +from ppc_scheduler.workflow.worker.worker import Worker + + +class ApiWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs): + ... 
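ApiWorker.engine_run (like PythonWorker and ShellWorker later in this change) is left as a stub. As an illustration only, a concrete subclass could call a remote HTTP service through the existing http_utils helper; the 'api_endpoint', 'api_uri' and 'api_params' keys used below are hypothetical and are not defined anywhere in this patch.

from ppc_common.ppc_utils import http_utils
from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode
from ppc_scheduler.workflow.worker.worker import Worker


class SampleApiWorker(Worker):
    # Illustrative sketch: invoke a generic HTTP API and expose its payload as the worker output.
    def engine_run(self, worker_inputs) -> list:
        args = self.job_context.args
        # 'api_endpoint', 'api_uri' and 'api_params' are assumed keys, not part of this patch
        response = http_utils.send_post_request(args['api_endpoint'], args['api_uri'], args['api_params'])
        if response.get('errorCode', 0) != 0:
            raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response.get('message'))
        return [response.get('data')]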
diff --git a/python/ppc_scheduler/workflow/worker/engine/__init__.py b/python/ppc_scheduler/workflow/worker/engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_scheduler/workflow/worker/engine/model_engine.py b/python/ppc_scheduler/workflow/worker/engine/model_engine.py new file mode 100644 index 00000000..2542c611 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/engine/model_engine.py @@ -0,0 +1,101 @@ +import os +import time + +from ppc_scheduler.workflow.common.job_context import JobContext +from ppc_scheduler.workflow.common.worker_type import WorkerType + + +class ModelWorkerEngine: + def __init__(self, model_client, worker_type, components, job_context: JobContext): + self.model_client = model_client + self.worker_type = worker_type + self.components = components + self.job_context = job_context + self.log = self.components.logger() + + def run(self) -> list: + if self.worker_type == WorkerType.T_PREPROCESSING: + self._run_preprocessing() + elif self.worker_type == WorkerType.T_FEATURE_ENGINEERING: + self._run_feature_engineering() + elif self.worker_type == WorkerType.T_TRAINING: + self._run_training() + elif self.worker_type == WorkerType.T_PREDICTION: + self._run_prediction() + else: + raise ValueError(f"Unsupported worker type: {self.worker_type}") + return [] + + def _run_preprocessing(self): + start = time.time() + job_id = self.job_context.job_id + task_id = job_id + '_d' + user_name = self.job_context.user_name + dataset_storage_path = os.path.join(user_name, self.job_context.dataset_id) + args = { + 'job_id': job_id, + 'task_id': task_id, + 'task_type': 'PREPROCESSING', + 'dataset_id': self.job_context.dataset_id, + 'dataset_storage_path': dataset_storage_path, + 'job_algorithm_type': self.job_context.job_type, + 'need_run_psi': self.job_context.need_run_psi, + 'model_dict': self.job_context.model_config_dict + } + self.log.info(f"start prepare_xgb, job_id: {job_id}, task_id: {task_id}, args: {args}") + self.model_client.run(args) + self.log.info( + f"call compute_xgb_job service success, job: {job_id}, " + f"task_id: {task_id}, timecost: {time.time() - start}") + + def _run_feature_engineering(self): + start = time.time() + job_id = self.job_context.job_id + task_id = job_id + '_f' + args = { + 'job_id': job_id, + 'task_id': task_id, + 'task_type': 'FEATURE_ENGINEERING', + 'is_label_holder': self.job_context.tag_provider_agency_id == self.components.config_data['AGENCY_ID'], + 'result_receiver_id_list': self.job_context.result_receiver_list, + 'participant_id_list': self.job_context.participant_id_list, + 'model_dict': self.job_context.model_config_dict + } + self.log.info(f"start feature_engineering, job_id: {job_id}, task_id: {task_id}, args: {args}") + self.model_client.run(args) + self.log.info( + f"call compute_xgb_job service success, job: {job_id}, " + f"task_id: {task_id}, timecost: {time.time() - start}") + + def _run_training(self): + # todo 支持LR + task_id = self.job_context.job_id + '_t' + task_type = 'XGB_TRAINING' + xgb_predict_algorithm = '' + self._run_model(task_id, task_type, xgb_predict_algorithm) + + def _run_prediction(self): + # todo 支持LR + task_id = self.job_context.job_id + '_p' + task_type = 'XGB_PREDICTING' + xgb_predict_algorithm = self.job_context.predict_algorithm + self._run_model(task_id, task_type, xgb_predict_algorithm) + + def _run_model(self, task_id, task_type, model_algorithm): + job_id = self.job_context.job_id + args = { + "job_id": job_id, + 'task_id': task_id, + 'task_type': task_type, + 
'is_label_holder': self.job_context.tag_provider_agency_id == self.components.config_data['AGENCY_ID'], + 'result_receiver_id_list': self.job_context.result_receiver_list, + 'participant_id_list': self.job_context.participant_id_list, + 'model_predict_algorithm': model_algorithm, + "algorithm_type": self.job_context.job_type, + "algorithm_subtype": self.job_context.job_subtype, + "model_dict": self.job_context.model_config_dict + } + self.log.info(f"start run xgb task, job_id, job: {job_id}, " + f"task_id: {task_id}, task_type: {task_type}, args: {args}") + self.model_client.run(args) + self.log.info(f"call compute_xgb_job service success, job: {job_id}") diff --git a/python/ppc_scheduler/workflow/worker/engine/mpc_engine.py b/python/ppc_scheduler/workflow/worker/engine/mpc_engine.py new file mode 100644 index 00000000..b159c7b2 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/engine/mpc_engine.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +import os + +import pandas as pd + +from ppc_common.ppc_dataset import dataset_helper_factory +from ppc_common.ppc_utils import utils +from ppc_scheduler.mpc_generator.generator import CodeGenerator +from ppc_scheduler.workflow.common.job_context import JobContext + + +class MpcWorkerEngine: + def __init__(self, mpc_client, worker_type, components, job_context: JobContext): + self.mpc_client = mpc_client + self.worker_type = worker_type + self.components = components + self.job_context = job_context + self.log = self.components.logger() + + def run(self) -> list: + bit_length = self._prepare_mpc_file() + if self.job_context.need_run_psi: + self._prepare_mpc_after_psi() + else: + self._prepare_mpc_without_psi() + self._run_mpc_job(bit_length) + self._finish_mpc_job() + return [self.job_context.mpc_output_path] + + def _prepare_mpc_file(self): + # if self.job_context.sql is not None: + # # compile sql to mpc content + # mpc_content = CodeGenerator(self.job_context.sql) + # else: + # mpc_content = self.job_context.mpc_content + utils.write_content_to_file(self.job_context.mpc_content, self.job_context.mpc_file_path) + + self.components.storage_client.upload_file(self.job_context.mpc_file_path, + self.job_context.job_id + os.sep + self.job_context.mpc_file_name) + return self._get_share_bytes_length(self.job_context.mpc_content) + + def _run_mpc_job(self, bit_length): + job_id = self.job_context.job_id + mpc_record = 0 + self.log.info(f"start compute_mpc_job, job id: {job_id}") + if self.job_context.mpc_prepare_path: + mpc_record = sum(1 for _ in open(self.job_context.mpc_prepare_path)) + self.log.info(f"compute_mpc_job, mpc record: {mpc_record}") + utils.replace(self.job_context.mpc_file_path, mpc_record) + self._replace_mpc_field_holder() + self.components.storage_client.upload_file(self.job_context.mpc_file_path, + job_id + os.sep + self.job_context.mpc_file_name) + job_info = { + "jobId": job_id, + "mpcNodeUseGateway": False, + "receiverNodeIp": "", + "mpcNodeDirectPort": self.components.config_data["MPC_NODE_DIRECT_PORT"], + "participantCount": len(self.job_context.participant_id_list), + "selfIndex": self.job_context.my_index, + "isMalicious": self.components.config_data["IS_MALICIOUS"], + "bitLength": bit_length, + "inputFileName": "{}-P{}-0".format(JobContext.MPC_PREPARE_FILE, self.job_context.my_index), + "outputFileName": JobContext.MPC_OUTPUT_FILE + } + + self.log.info(f"call compute_mpc_job service, model run, params: {job_info}") + self.mpc_client.run(job_info, self.components.config_data['PPCS_RPC_TOKEN']) + 
self.components.storage_client.download_file(job_id + os.sep + JobContext.MPC_OUTPUT_FILE, + self.job_context.mpc_output_path) + self.log.info(f"call compute_mpc_job service success") + + def _finish_mpc_job(self): + job_id = self.job_context.job_id + index_file = None + if self.components.config_data['AGENCY_ID'] in self.job_context.result_receiver_list: + if self.job_context.need_run_psi: + if not utils.file_exists(self.job_context.psi_result_path): + self.log.info( + f"download finish_mpc_job psi_result_path, job_id={job_id}, " + f"download {job_id + os.sep + JobContext.PSI_RESULT_FILE}") + self.components.storage_client.download_file(job_id + os.sep + JobContext.PSI_RESULT_FILE, + self.job_context.psi_result_path) + self._order_psi_csv() + index_file = self.job_context.psi_result_path + if not utils.file_exists(self.job_context.mpc_output_path): + self.log.info( + f"download finish_mpc_job mpc_output_path, job_id={job_id}, " + f"download {job_id + os.sep + JobContext.MPC_OUTPUT_FILE}") + self.components.storage_client.download_file(job_id + os.sep + JobContext.MPC_OUTPUT_FILE, + self.job_context.mpc_output_path) + + self.log.info( + f"finish_mpc_job mpc_output_path, job_id={job_id}") + self._parse_and_write_final_mpc_result(self.job_context.mpc_output_path, index_file) + self.components.storage_client.upload_file(self.job_context.mpc_result_path, + job_id + os.sep + JobContext.MPC_RESULT_FILE) + self.log.info(f"finish_mpc_job success, job_id={job_id}") + + def _get_share_bytes_length(self, algorithm_content): + target = '# BIT_LENGTH = ' + if target in algorithm_content: + start = algorithm_content.find(target) + end = algorithm_content.find('\n', start + len(target)) + bit_length = int(algorithm_content[start + len(target): end].strip()) + self.log.info(f"OUTPUT_BIT_LENGTH = {bit_length}") + return bit_length + else: + self.log.info(f"OUTPUT_BIT_LENGTH = 64") + return 64 + + def _get_dataset_column_count(self): + with open(self.job_context.mpc_file_path, "r") as file: + mpc_str = file.read() + + lines = mpc_str.split('\n') + for line in lines: + if f"source{self.job_context.my_index}_column_count =" in line or \ + f"source{self.job_context.my_index}_column_count=" in line: + index = line.find('=') + return int(line[index + 1:].strip('\n').strip()) + + def _make_dataset_to_mpc_data_plus_psi_data(self, my_dataset_number): + chunk_list = pd.read_csv(self.job_context.dataset_file_path, delimiter=utils.CSV_SEP, + chunksize=self.components.dataset_handler_initializer.file_chunk_config.read_chunk_size) + psi_data = pd.read_csv(self.job_context.psi_result_path, delimiter=utils.CSV_SEP) + for chunk in chunk_list: + self._make_dataset_field_normalized(chunk) + mpc_data_df = pd.merge(chunk, psi_data, on=['id']).sort_values( + by='id', ascending=True) + self._save_selected_column_data( + my_dataset_number, mpc_data_df, self.job_context.mpc_prepare_path) + + @staticmethod + def _make_dataset_field_normalized(dataset_df): + data_field = dataset_df.columns.values + if 'id' in data_field: + data_field_normalized_names = ['id'] + size = len(data_field) + for i in range(size - 1): + data_field_normalized_names.append( + utils.NORMALIZED_NAMES.format(i)) + else: + data_field_normalized_names = [] + size = len(data_field) + for i in range(size): + data_field_normalized_names.append( + utils.NORMALIZED_NAMES.format(i)) + dataset_df.columns = data_field_normalized_names + + @staticmethod + def _save_selected_column_data(my_dataset_number, data_df, mpc_prepare_path): + column_list = [] + for i in range(0, 
int(my_dataset_number)): + column_list.append(utils.NORMALIZED_NAMES.format(i)) + result_new = pd.DataFrame(data_df, columns=column_list) + # sep must be space (ppc-mpc inputs) + result_new.to_csv(mpc_prepare_path, sep=' ', mode='a', header=False, index=None) + + def _prepare_mpc_after_psi(self): + job_id = self.job_context.job_id + self.log.info(f"start prepare_mpc_after_psi, job_id={job_id}") + my_dataset_number = self._get_dataset_column_count() + + dataset_helper_factory.download_dataset( + dataset_helper_factory=self.components.dataset_handler_initializer.dataset_helper_factory, + dataset_user=self.job_context.user_name, + dataset_id=self.job_context.dataset_id, + dataset_local_path=self.job_context.dataset_file_path, + log_keyword="prepare_mpc_after_psi", + logger=self.log) + + if not utils.file_exists(self.job_context.psi_result_path): + self.log.info( + f"prepare_mpc_after_psi, download psi_result_path ,job_id={job_id}, " + f"download {job_id + os.sep + JobContext.PSI_RESULT_FILE}") + + self.components.storage_client.download_file(job_id + os.sep + JobContext.PSI_RESULT_FILE, + self.job_context.psi_result_path) + + self._make_dataset_to_mpc_data_plus_psi_data(my_dataset_number) + + hdfs_mpc_prepare_path = "{}-P{}-0".format(job_id + os.sep + JobContext.MPC_PREPARE_FILE, + self.job_context.my_index) + self.components.storage_client.upload_file(self.job_context.mpc_prepare_path, hdfs_mpc_prepare_path) + self.log.info(f"call prepare_mpc_after_psi success: job_id={job_id}") + + def _make_dataset_to_mpc_data_direct(self, my_dataset_number): + chunk_list = pd.read_csv(self.job_context.dataset_file_path, delimiter=utils.CSV_SEP, + chunksize=self.components.dataset_handler_initializer.file_chunk_config.read_chunk_size) + for chunk in chunk_list: + self._make_dataset_field_normalized(chunk) + self._save_selected_column_data(my_dataset_number, chunk, self.job_context.mpc_prepare_path) + + def _prepare_mpc_without_psi(self): + job_id = self.job_context.job_id + self.log.info(f"start prepare_mpc_without_psi, job_id={job_id}") + my_dataset_number = self._get_dataset_column_count() + dataset_helper_factory.download_dataset( + dataset_helper_factory=self.components.dataset_handler_initializer.dataset_helper_factory, + dataset_user=self.job_context.user_name, + dataset_id=self.job_context.dataset_id, + dataset_local_path=self.job_context.dataset_file_path, + log_keyword="prepare_mpc_without_psi", + logger=self.log) + + self._make_dataset_to_mpc_data_direct(my_dataset_number) + + hdfs_mpc_prepare_path = "{}-P{}-0".format(job_id + os.sep + JobContext.MPC_PREPARE_FILE, + self.job_context.my_index) + self.components.storage_client.upload_file(self.job_context.mpc_prepare_path, hdfs_mpc_prepare_path) + self.log.info(f"call prepare_mpc_without_psi success: job_id={job_id}") + + def _order_psi_csv(self): + """ + order_psi_csv + """ + data = pd.read_csv(self.job_context.psi_result_path, delimiter=utils.CSV_SEP) + data.sort_values(by="id").to_csv(self.job_context.psi_result_path, + sep=utils.CSV_SEP, header=True, index=None) + + def _parse_and_write_final_mpc_result(self, mpc_output_path, index_file): + self.log.info("run parse_and_write_final_mpc_result") + final_result_fields = 'id' + need_add_fields = True + column_count = 0 + for row_data in open(mpc_output_path): + if row_data.__contains__(utils.PPC_RESULT_FIELDS_FLAG): + need_add_fields = False + final_result_fields += ',' + \ + row_data[row_data.find( + '=') + 1:].strip().replace(utils.BLANK_SEP, utils.CSV_SEP) + elif 
row_data.__contains__(utils.PPC_RESULT_VALUES_FLAG): + column_count = len(row_data.split( + '=')[1].strip().split(utils.BLANK_SEP)) + break + + if need_add_fields: + for i in range(column_count): + final_result_fields += ',' + 'result' + str(i) + + id_list = [] + if self.job_context.need_run_psi: + df = pd.read_csv(index_file, delimiter=utils.CSV_SEP) + for result_id in df["id"]: + id_list.append(result_id) + + with open(self.job_context.mpc_result_path, "w", encoding='utf-8') as file: + file.write(final_result_fields + '\n') + row_count = 0 + for row_data in open(mpc_output_path): + if row_data.__contains__(utils.PPC_RESULT_VALUES_FLAG): + values = row_data.split('=')[1].strip().split(utils.BLANK_SEP) + if self.job_context.need_run_psi: + if row_count >= len(id_list): + row = str(id_list[-1] + row_count - len(id_list) + 1) + else: + row = str(id_list[row_count]) + else: + row = str(row_count) + + for value in values: + try: + row += (',' + value) + except: + row += (',%s' % value) + file.write(row + '\n') + row_count += 1 + file.close() + + self.log.info("finish parse_and_write_final_mpc_result") + + def _replace_mpc_field_holder(self): + party_count = len(self.job_context.participant_id_list) + if self.job_context.need_run_psi: + dataset_record_count = 0 + if self.job_context.mpc_prepare_path: + dataset_record_count = sum(1 for _ in open(self.job_context.mpc_prepare_path)) + for i in range(party_count): + utils.replace(self.job_context.mpc_file_path, + dataset_record_count, f'$(source{i}_record_count)') + else: + for i in range(party_count): + utils.replace(self.job_context.mpc_file_path, + self.job_context.dataset_record_count, f'$(source{i}_record_count)') diff --git a/python/ppc_scheduler/workflow/worker/engine/psi_engine.py b/python/ppc_scheduler/workflow/worker/engine/psi_engine.py new file mode 100644 index 00000000..acc5ed29 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/engine/psi_engine.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +import codecs +import os +import time + +from ppc_common.ppc_dataset import dataset_helper_factory +from ppc_common.ppc_utils import utils, common_func +from ppc_scheduler.workflow.common.job_context import JobContext + + +class PsiWorkerEngine: + def __init__(self, psi_client, worker_type, components, job_context: JobContext): + self.psi_client = psi_client + self.worker_type = worker_type + self.components = components + self.job_context = job_context + self.log = self.components.logger() + + def run(self) -> list: + job_id = self.job_context.job_id + start_time = time.time() + self.origin_dataset_to_psi_inputs() + + self.log.info(f"compute two party psi, job_id={job_id}") + if len(self.job_context.participant_id_list) == 2: + self._run_two_party_psi() + else: + self._run_multi_party_psi() + time_costs = time.time() - start_time + self.log.info(f"computing psi finished, job_id={job_id}, timecost: {time_costs}s") + return [JobContext.HDFS_STORAGE_PATH + job_id + os.sep + self.job_context.PSI_RESULT_INDEX_FILE] + + def _run_two_party_psi(self): + job_id = self.job_context.job_id + agency_id = self.components.config_data['AGENCY_ID'] + job_info = { + "taskID": job_id, + "type": 0, + "algorithm": 0, + "syncResult": True, + "lowBandwidth": False, + "parties": [ + { + "id": self.job_context.participant_id_list[1 - self.job_context.my_index], + "partyIndex": 1 - self.job_context.my_index + }, + { + "id": agency_id, + "partyIndex": self.job_context.my_index, + "data": + { + "id": self.job_context.job_id, + "input": { + "type": 2, + "path": 
self.job_context.HDFS_STORAGE_PATH + job_id + os.sep + JobContext.PSI_PREPARE_FILE + }, + "output": { + "type": 2, + "path": self.job_context.HDFS_STORAGE_PATH + job_id + os.sep + JobContext.PSI_RESULT_FILE + } + } + } + ] + } + psi_result = self.psi_client.run(job_info, self.components.config_data['PPCS_RPC_TOKEN']) + self.log.info(f"call psi service successfully, job_id={job_id}, result: {psi_result}") + + def _run_multi_party_psi(self): + job_id = self.job_context.job_id + participant_number = len(self.job_context.participant_id_list) + participant_list = [] + # parties_index = [0, 1 ··· 1, 2] + parties_index = participant_number * [1] # role: partner + parties_index[0] = 0 # role: calculator + parties_index[-1] = 2 # role: master + for index, agency_id in enumerate(self.job_context.participant_id_list): + party_map = {} + if self.job_context.my_index == index: + party_map["id"] = agency_id + party_map["partyIndex"] = parties_index[index] + party_map["data"] = { + "id": self.job_context.job_id, + "input": { + "type": 2, + "path": self.job_context.HDFS_STORAGE_PATH + job_id + os.sep + JobContext.PSI_PREPARE_FILE + }, + "output": { + "type": 2, + "path": self.job_context.HDFS_STORAGE_PATH + job_id + os.sep + JobContext.PSI_RESULT_FILE + } + } + else: + party_map["id"] = agency_id + party_map["partyIndex"] = parties_index[index] + participant_list.append(party_map) + + job_info = { + "taskID": job_id, + "type": 0, + "algorithm": 4, + "syncResult": True, + "receiverList": self.job_context.result_receiver_list, + "parties": participant_list + } + psi_result = self.psi_client.run(job_info, self.components.config_data['PPCS_RPC_TOKEN']) + self.log.info(f"call psi service successfully, job_id={job_id}, result: {psi_result}") + + def origin_dataset_to_psi_inputs(self): + dataset_helper_factory.download_dataset( + dataset_helper_factory=self.components.dataset_handler_initializer.dataset_helper_factory, + dataset_user=self.job_context.user_name, + dataset_id=self.job_context.dataset_id, + dataset_local_path=self.job_context.dataset_file_path, + log_keyword="prepare_dataset", + logger=self.log) + + field = (self.job_context.psi_fields.split(utils.CSV_SEP)[self.job_context.my_index]).lower() + if field == '': + field = 'id' + prepare_file = open(self.job_context.psi_prepare_path, 'w') + psi_split_reg = "===" + file_encoding = common_func.get_file_encoding(self.job_context.dataset_file_path) + if psi_split_reg in field: + field_multi = field.split(psi_split_reg) + with codecs.open(self.job_context.dataset_file_path, "r", file_encoding) as dataset: + fields = next(dataset).lower() + fields_list = fields.strip().split(utils.CSV_SEP) + id_idx_list = [] + for filed_idx in field_multi: + id_idx_list.append(fields_list.index(filed_idx)) + for line in dataset: + if line.strip() == "": + continue + final_str = "" + for id_idx_multi in id_idx_list: + if len(line.strip().split(utils.CSV_SEP, id_idx_multi + 1)) < id_idx_multi + 1: + continue + final_str = "{}-{}".format(final_str, line.strip().split( + utils.CSV_SEP, id_idx_multi + 1)[id_idx_multi]).strip("\r\n") + print(final_str, file=prepare_file) + prepare_file.close() + else: + with codecs.open(self.job_context.dataset_file_path, "r", file_encoding) as dataset: + # ignore lower/upper case + fields = next(dataset).lower() + fields_list = fields.strip().split(utils.CSV_SEP) + id_idx = fields_list.index(field) + for line in dataset: + if line.strip() == "": + continue + if len(line.strip().split(utils.CSV_SEP, id_idx + 1)) < id_idx + 1: + continue + 
print(line.strip().split(utils.CSV_SEP, id_idx + 1) + [id_idx], file=prepare_file) + prepare_file.close() + self.components.storage_client.upload_file(self.job_context.psi_prepare_path, + self.job_context.job_id + os.sep + JobContext.PSI_PREPARE_FILE) + utils.delete_file(self.job_context.psi_prepare_path) diff --git a/python/ppc_scheduler/workflow/worker/exit_worker.py b/python/ppc_scheduler/workflow/worker/exit_worker.py new file mode 100644 index 00000000..9cff596b --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/exit_worker.py @@ -0,0 +1,33 @@ +import os +import time + +from ppc_common.ppc_utils import utils +from ppc_scheduler.common import log_utils +from ppc_scheduler.workflow.common.worker_type import WorkerType +from ppc_scheduler.workflow.worker.worker import Worker + + +class ExitWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs): + log_utils.upload_job_log(self.components.storage_client, self.job_context.job_id) + self._save_workflow_view_file() + if self.worker_type == WorkerType.T_ON_FAILURE: + # notice job manager that this job has failed + raise Exception() + + def _save_workflow_view_file(self): + file = f"{self.job_context.workflow_view_path}.svg" + try_count = 10 + while try_count > 0: + if utils.file_exists(file): + break + time.sleep(1) + try_count -= 1 + + self.components.storage_client.upload_file(file, + self.job_context.job_id + os.sep + + self.job_context.workflow_view_path) diff --git a/python/ppc_scheduler/workflow/worker/model_worker.py b/python/ppc_scheduler/workflow/worker/model_worker.py new file mode 100644 index 00000000..e843bd08 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/model_worker.py @@ -0,0 +1,19 @@ +from ppc_scheduler.node.computing_node_client import ModelClient +from ppc_scheduler.workflow.worker.engine.model_engine import ModelWorkerEngine +from ppc_scheduler.workflow.worker.worker import Worker + + +class ModelWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs): + node_endpoint = self.node_manager.get_node(self.worker_type) + model_client = ModelClient(self.components.logger(), node_endpoint) + model_engine = ModelWorkerEngine(model_client, self.worker_type, self.components, self.job_context) + try: + outputs = model_engine.run() + return outputs + finally: + self.node_manager.release_node(node_endpoint, self.worker_type) diff --git a/python/ppc_scheduler/workflow/worker/mpc_worker.py b/python/ppc_scheduler/workflow/worker/mpc_worker.py new file mode 100644 index 00000000..c5a98c5f --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/mpc_worker.py @@ -0,0 +1,19 @@ +from ppc_scheduler.node.computing_node_client import MpcClient +from ppc_scheduler.workflow.worker.engine.mpc_engine import MpcWorkerEngine +from ppc_scheduler.workflow.worker.worker import Worker + + +class MpcWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs) -> list: + node_endpoint = self.node_manager.get_node(self.worker_type) + mpc_client = MpcClient(node_endpoint) + mpc_engine = MpcWorkerEngine(mpc_client, self.worker_type, 
self.components, self.job_context) + try: + outputs = mpc_engine.run() + return outputs + finally: + self.node_manager.release_node(node_endpoint, self.worker_type) diff --git a/python/ppc_scheduler/workflow/worker/psi_worker.py b/python/ppc_scheduler/workflow/worker/psi_worker.py new file mode 100644 index 00000000..e8f18cca --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/psi_worker.py @@ -0,0 +1,19 @@ +from ppc_scheduler.node.computing_node_client.psi_node_client import PsiClient +from ppc_scheduler.workflow.worker.engine.psi_engine import PsiWorkerEngine +from ppc_scheduler.workflow.worker.worker import Worker + + +class PsiWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs) -> list: + node_endpoint = self.node_manager.get_node(self.worker_type) + psi_client = PsiClient(self.components.logger(), node_endpoint) + psi_engine = PsiWorkerEngine(psi_client, self.worker_type, self.components, self.job_context) + try: + outputs = psi_engine.run() + return outputs + finally: + self.node_manager.release_node(node_endpoint, self.worker_type) diff --git a/python/ppc_scheduler/workflow/worker/python_worker.py b/python/ppc_scheduler/workflow/worker/python_worker.py new file mode 100644 index 00000000..a6a36013 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/python_worker.py @@ -0,0 +1,10 @@ +from ppc_scheduler.workflow.worker.worker import Worker + + +class PythonWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs): + ... diff --git a/python/ppc_scheduler/workflow/worker/shell_worker.py b/python/ppc_scheduler/workflow/worker/shell_worker.py new file mode 100644 index 00000000..832b8794 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/shell_worker.py @@ -0,0 +1,10 @@ +from ppc_scheduler.workflow.worker.worker import Worker + + +class ShellWorker(Worker): + + def __init__(self, components, job_context, worker_id, worker_type, *args, **kwargs): + super().__init__(components, job_context, worker_id, worker_type, *args, **kwargs) + + def engine_run(self, worker_inputs): + ... 
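Each concrete worker above follows the same pattern: acquire a computing node endpoint from ComputingNodeManager, run its engine, and release the node in a finally block. A minimal sketch of how nodes could be registered and handed out through that manager is shown below; the node ids and urls are placeholders, and an initialized components object is assumed.

from ppc_scheduler.common.global_context import components
from ppc_scheduler.node.node_manager import ComputingNodeManager
from ppc_scheduler.workflow.common.worker_type import WorkerType

components.init_all()
node_manager = ComputingNodeManager(components)

# register two placeholder computing nodes (ids and urls are illustrative only)
node_manager.add_node(node_id='psi-001', url='127.0.0.1:10200', worker_type=WorkerType.T_PSI)
node_manager.add_node(node_id='model-001', url='127.0.0.1:10202', worker_type=WorkerType.T_TRAINING)

# get_node() hands out the url of the node of that type with the smallest loading;
# release_node() gives that capacity back once the worker is done
endpoint = node_manager.get_node(WorkerType.T_PSI)
try:
    pass  # call the PSI service at `endpoint` here
finally:
    node_manager.release_node(endpoint, WorkerType.T_PSI)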
diff --git a/python/ppc_scheduler/workflow/worker/worker.py b/python/ppc_scheduler/workflow/worker/worker.py new file mode 100644 index 00000000..b9823ce2 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/worker.py @@ -0,0 +1,97 @@ +import time + +from func_timeout import FunctionTimedOut +from prefect import Task +from prefect.engine import signals + +from ppc_scheduler.database import job_worker_mapper +from ppc_scheduler.node.node_manager import ComputingNodeManager +from ppc_scheduler.workflow.common import codec, flow_utils +from ppc_scheduler.workflow.common.worker_status import WorkerStatus +from ppc_scheduler.workflow.common.worker_type import WorkerType + + +class Worker(Task): + def __init__(self, components, job_context, worker_id, worker_type, retries=0, retry_delay_s=0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.components = components + self.node_manager = ComputingNodeManager(components) + self.log = components.logger() + self.job_context = job_context + self.worker_id = worker_id + self.worker_type = worker_type + self.retries = retries + self.retry_delay_s = retry_delay_s + + def engine_run(self, worker_inputs) -> list: + # to be implemented by subclasses + ... + + def run(self, worker_status, worker_inputs): + try: + # the job has been killed + if self.components.thread_event_manager.event_status(self.job_context.job_id): + self._write_failed_status(WorkerStatus.KILLED) + self.log.warn( + f"worker was killed, job_id: {self.job_context.job_id}, worker: {self.worker_id}") + raise signals.FAIL(message='killed!') + + if worker_status == WorkerStatus.SUCCESS: + # this worker has already finished; return the outputs saved in db directly + return self._load_output_from_db() + + inputs = [] + if self.worker_type != WorkerType.T_ON_SUCCESS \ + and self.worker_type != WorkerType.T_ON_FAILURE: + inputs = flow_utils.to_origin_inputs(worker_inputs) + + outputs = self._try_run_task(inputs) + self._save_worker_result(outputs) + return outputs + except FunctionTimedOut: + self._write_failed_status(WorkerStatus.TIMEOUT) + self.log.error( + f"worker timed out, job_id: {self.job_context.job_id}, worker: {self.worker_id}") + raise signals.FAIL(message='timeout!') + except BaseException as be: + self._write_failed_status(WorkerStatus.FAILURE) + self.log.error(f"[OnError]job worker failed, job_id: {self.job_context.job_id}, worker: {self.worker_id}") + self.log.exception(be) + raise signals.FAIL(message='failed!') + + def _try_run_task(self, inputs): + self.log.info(f"job_id: {self.job_context.job_id}, worker: {self.worker_id}, inputs: {inputs}") + # run the engine, retrying on failure when retries are configured + if self.retries: + attempt = 0 + while attempt <= self.retries: + try: + outputs = self.engine_run(inputs) + return outputs + except Exception as e: + attempt += 1 + if attempt > self.retries: + self.log.warn( + f"worker failed after {self.retries} retries, " + f"job_id: {self.job_context.job_id}, worker: {self.worker_id}") + raise e + else: + time.sleep(self.retry_delay_s) + else: + outputs = self.engine_run(inputs) + self.log.info(f"job_id: {self.job_context.job_id}, worker: {self.worker_id}, outputs: {outputs}") + return outputs + + def _load_output_from_db(self): + with self.components.create_sql_session() as session: + worker_record = job_worker_mapper.query_job_worker(session, self.job_context.job_id, self.worker_id) + return codec.deserialize_worker_outputs(worker_record.outputs) + + def _save_worker_result(self, outputs): + with self.components.create_sql_session() as session: +
job_worker_mapper.update_job_worker(session, self.job_context.job_id, self.worker_id, + WorkerStatus.SUCCESS, outputs) + + def _write_failed_status(self, status): + with self.components.create_sql_session() as session: + job_worker_mapper.update_job_worker(session, self.job_context.job_id, self.worker_id, status, []) diff --git a/python/ppc_scheduler/workflow/worker/worker_factory.py b/python/ppc_scheduler/workflow/worker/worker_factory.py new file mode 100644 index 00000000..016ce9b4 --- /dev/null +++ b/python/ppc_scheduler/workflow/worker/worker_factory.py @@ -0,0 +1,32 @@ +from ppc_scheduler.common.global_context import components +from ppc_scheduler.workflow.common.worker_type import WorkerType +from ppc_scheduler.workflow.worker.api_worker import ApiWorker +from ppc_scheduler.workflow.worker.exit_worker import ExitWorker +from ppc_scheduler.workflow.worker.model_worker import ModelWorker +from ppc_scheduler.workflow.worker.mpc_worker import MpcWorker +from ppc_scheduler.workflow.worker.psi_worker import PsiWorker +from ppc_scheduler.workflow.worker.python_worker import PythonWorker +from ppc_scheduler.workflow.worker.shell_worker import ShellWorker + + +class WorkerFactory: + + @staticmethod + def build_worker(job_context, worker_id, worker_type): + if worker_type == WorkerType.T_API: + return ApiWorker(components, job_context, worker_id, worker_type) + elif worker_type == WorkerType.T_PYTHON: + return PythonWorker(components, job_context, worker_id, worker_type) + elif worker_type == WorkerType.T_SHELL: + return ShellWorker(components, job_context, worker_id, worker_type) + elif worker_type == WorkerType.T_PSI: + return PsiWorker(components, job_context, worker_id, worker_type) + elif worker_type == WorkerType.T_MPC: + return MpcWorker(components, job_context, worker_id, worker_type) + elif worker_type == WorkerType.T_PREPROCESSING or \ + worker_type == WorkerType.T_FEATURE_ENGINEERING or \ + worker_type == WorkerType.T_TRAINING or \ + worker_type == WorkerType.T_PREDICTION: + return ModelWorker(components, job_context, worker_id, worker_type) + else: + return ExitWorker(components, job_context, worker_id, worker_type) diff --git a/python/requirements.txt b/python/requirements.txt index 14e6b826..e464227f 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -24,9 +24,8 @@ promise~=2.3 # protobuf>=4.21.6,<5.0dev protobuf>=5.27.1 pycryptodome==3.9.9 -PyJWT==2.4.0 -PyYAML==5.4.1 -# sha3==0.2.1 +pyjwt +pyyaml mysqlclient==2.1.0 waitress==2.1.2 sqlparse~=0.4.1 @@ -38,7 +37,7 @@ google~=3.0.0 paste~=3.5.0 func_timeout==4.3.0 cheroot==8.5.2 -prefect==0.14.15 +prefect==1.4.0 gmssl~=3.2.1 readerwriterlock~=1.0.4 jsoncomment~=0.2.3 @@ -54,12 +53,11 @@ networkx pydot snowland-smx numpy==1.23.1 -graphviz~=0.20.1 +graphviz grpcio==1.62.1 grpcio-tools==1.62.1 xlrd~=1.0.0 MarkupSafe>=2.1.1 -Werkzeug==2.3.8 urllib3==1.26.18 phe chardet
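How these pieces are driven end to end is not part of this patch (the flow construction lives elsewhere in ppc_scheduler), but since requirements.txt now pins prefect==1.4.0 and Worker subclasses prefect.Task, the wiring would look roughly like the sketch below. The job id, the worker ids, and the shape of worker_inputs are placeholders; the real scheduler presumably derives them from t_job_worker and flow_utils.

from prefect import Flow

from ppc_scheduler.workflow.common.worker_status import WorkerStatus
from ppc_scheduler.workflow.common.worker_type import WorkerType
from ppc_scheduler.workflow.worker.worker_factory import WorkerFactory


def build_demo_flow(job_context):
    # Worker ids are illustrative; the factory picks the concrete Worker subclass
    # from the worker type (T_PSI -> PsiWorker, T_MPC -> MpcWorker, ...).
    psi_worker = WorkerFactory.build_worker(job_context, "j-001-psi", WorkerType.T_PSI)
    mpc_worker = WorkerFactory.build_worker(job_context, "j-001-mpc", WorkerType.T_MPC)

    with Flow("j-001") as flow:
        # Any worker_status other than WorkerStatus.SUCCESS forces a fresh run;
        # SUCCESS makes Worker.run return the outputs already stored in t_job_worker.
        psi_outputs = psi_worker(worker_status=None, worker_inputs=[])
        # The exact structure expected by flow_utils.to_origin_inputs is defined
        # outside this patch; a list of upstream results is assumed here.
        mpc_worker(worker_status=None, worker_inputs=[psi_outputs])
    return flow

Running the resulting flow with flow.run() executes the workers in dependency order, with Prefect 1.x passing each task's return value downstream as worker_inputs.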