From d95e5f3e42ff9f597bc3300c397e375a41e4045e Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Wed, 21 Aug 2024 15:55:44 +0800 Subject: [PATCH] add ppml and ppc-model-gateway (#2) --- python/__init__.py | 0 python/ppc_common/__init__.py | 0 python/ppc_common/application-sample.yml | 6 + python/ppc_common/application.yml | 6 + python/ppc_common/config.py | 18 + python/ppc_common/db_models/__init__.py | 6 + .../ppc_common/db_models/file_object_meta.py | 12 + python/ppc_common/db_models/file_path.py | 10 + .../ppc_common/db_models/job_unit_record.py | 14 + python/ppc_common/deps_services/__init__.py | 0 .../ppc_common/deps_services/file_object.py | 85 ++ .../ppc_common/deps_services/hdfs_storage.py | 97 ++ .../ppc_common/deps_services/mysql_storage.py | 65 ++ .../deps_services/serialize_type.py | 7 + .../deps_services/sharding_file_object.py | 429 ++++++++ .../deps_services/sql_storage_api.py | 57 ++ .../ppc_common/deps_services/storage_api.py | 53 + .../deps_services/storage_loader.py | 14 + .../deps_services/tests/mysql_storage_test.py | 104 ++ .../tests/sharding_file_object_test.py | 79 ++ .../ppc_common/ppc_async_executor/__init__.py | 0 .../ppc_async_executor/async_executor.py | 35 + .../async_subprocess_executor.py | 60 ++ .../async_thread_executor.py | 75 ++ .../ppc_async_executor/test/__init__.py | 0 .../test/async_executor_unittest.py | 109 +++ .../thread_event_manager.py | 34 + python/ppc_common/ppc_config/__init__.py | 0 .../ppc_config/file_chunk_config.py | 27 + .../ppc_config/sql_storage_config_loader.py | 37 + python/ppc_common/ppc_crypto/__init__.py | 0 python/ppc_common/ppc_crypto/crypto_utils.py | 130 +++ python/ppc_common/ppc_crypto/ihc_cipher.py | 93 ++ python/ppc_common/ppc_crypto/ihc_codec.py | 29 + .../ppc_common/ppc_crypto/paillier_cipher.py | 45 + .../ppc_common/ppc_crypto/paillier_codec.py | 34 + python/ppc_common/ppc_crypto/phe_cipher.py | 25 + python/ppc_common/ppc_crypto/phe_factory.py | 25 + python/ppc_common/ppc_crypto/test/__init__.py | 0 
.../ppc_crypto/test/phe_unittest.py | 105 ++ python/ppc_common/ppc_dataset/__init__.py | 0 .../ppc_common/ppc_dataset/dataset_helper.py | 60 ++ .../ppc_dataset/dataset_helper_factory.py | 61 ++ python/ppc_common/ppc_initialize/__init__.py | 0 .../dataset_handler_initialize.py | 30 + .../tests/dataset_initializer_test.py | 126 +++ python/ppc_common/ppc_ml/__init__.py | 0 python/ppc_common/ppc_ml/feature/__init__.py | 0 .../ppc_ml/feature/feature_importance.py | 246 +++++ .../feature/tests/feature_importance_test.py | 79 ++ .../ppc_common/ppc_ml/model/algorithm_info.py | 25 + python/ppc_common/ppc_mock/__init__.py | 0 python/ppc_common/ppc_mock/mock_objects.py | 59 ++ python/ppc_common/ppc_protos/__init__.py | 0 .../ppc_protos/generated/__init__.py | 0 .../ppc_protos/generated/ppc_model_pb2.py | 51 + .../generated/ppc_model_pb2_grpc.py | 67 ++ .../ppc_protos/generated/ppc_pb2.py | 64 ++ python/ppc_common/ppc_protos/ppc.proto | 160 +++ python/ppc_common/ppc_protos/ppc_model.proto | 83 ++ python/ppc_common/ppc_utils/__init__.py | 1 + .../ppc_common/ppc_utils/anonymous_search.py | 420 ++++++++ python/ppc_common/ppc_utils/audit_utils.py | 56 ++ python/ppc_common/ppc_utils/cem_utils.py | 160 +++ python/ppc_common/ppc_utils/common_func.py | 28 + python/ppc_common/ppc_utils/exception.py | 146 +++ python/ppc_common/ppc_utils/http_utils.py | 102 ++ python/ppc_common/ppc_utils/path.py | 51 + python/ppc_common/ppc_utils/permission.py | 79 ++ python/ppc_common/ppc_utils/plot_utils.py | 230 +++++ .../ppc_utils/ppc_model_config_parser.py | 474 +++++++++ .../ppc_model_config_parser_proxy.py | 491 ++++++++++ .../ppc_utils/tests/thread_safe_list_test.py | 48 + .../ppc_common/ppc_utils/tests/utils_test.py | 40 + .../ppc_common/ppc_utils/thread_safe_list.py | 56 ++ python/ppc_common/ppc_utils/utils.py | 914 ++++++++++++++++++ python/ppc_model/__init__.py | 0 python/ppc_model/common/__init__.py | 0 python/ppc_model/common/base_context.py | 83 ++ python/ppc_model/common/context.py | 14 + 
python/ppc_model/common/global_context.py | 13 + python/ppc_model/common/initializer.py | 108 +++ python/ppc_model/common/mock/__init__.py | 0 .../ppc_model/common/mock/rpc_client_mock.py | 31 + python/ppc_model/common/model_result.py | 212 ++++ python/ppc_model/common/model_setting.py | 90 ++ python/ppc_model/common/protocol.py | 88 ++ python/ppc_model/conf/application-sample.yml | 43 + python/ppc_model/conf/logging.conf | 40 + python/ppc_model/datasets/__init__.py | 0 .../datasets/data_reduction/__init__.py | 0 .../data_reduction/feature_selection.py | 30 + .../datasets/data_reduction/sampling.py | 86 ++ .../datasets/data_reduction/test/__init__.py | 0 .../test/test_data_reduction.py | 61 ++ python/ppc_model/datasets/dataset.py | 232 +++++ .../datasets/feature_binning/__init__.py | 0 .../feature_binning/feature_binning.py | 131 +++ .../datasets/feature_binning/test/__init__.py | 0 .../test/test_feature_binning.py | 102 ++ python/ppc_model/datasets/test/__init__.py | 0 .../ppc_model/datasets/test/test_dataset.py | 213 ++++ .../ppc_model/feature_engineering/__init__.py | 0 .../feature_engineering_context.py | 48 + .../feature_engineering_engine.py | 41 + .../feature_engineering/test/__init__.py | 0 .../test/feature_engineering_unittest.py | 153 +++ .../feature_engineering/vertical/__init__.py | 0 .../vertical/active_party.py | 202 ++++ .../vertical/passive_party.py | 177 ++++ .../feature_engineering/vertical/utils.py | 76 ++ python/ppc_model/interface/__init__.py | 0 python/ppc_model/interface/model_base.py | 33 + python/ppc_model/interface/rpc_client.py | 8 + python/ppc_model/interface/task_engine.py | 9 + python/ppc_model/metrics/__init__.py | 0 python/ppc_model/metrics/evaluation.py | 276 ++++++ python/ppc_model/metrics/loss.py | 32 + python/ppc_model/metrics/model_plot.py | 181 ++++ python/ppc_model/metrics/test/__init__.py | 0 python/ppc_model/metrics/test/test_metrics.py | 165 ++++ .../model_result/task_result_handler.py | 388 ++++++++ 
python/ppc_model/network/__init__.py | 0 python/ppc_model/network/grpc/__init__.py | 0 python/ppc_model/network/grpc/grpc_client.py | 82 ++ python/ppc_model/network/grpc/grpc_server.py | 17 + python/ppc_model/network/http/__init__.py | 0 python/ppc_model/network/http/body_schema.py | 19 + .../network/http/model_controller.py | 95 ++ python/ppc_model/network/http/restx.py | 35 + python/ppc_model/network/stub.py | 218 +++++ python/ppc_model/network/test/__init__.py | 0 .../ppc_model/network/test/stub_unittest.py | 83 ++ python/ppc_model/ppc_model_app.py | 112 +++ python/ppc_model/preprocessing/__init__.py | 0 .../local_processing_party.py | 100 ++ .../local_processing/preprocessing.py | 660 +++++++++++++ .../local_processing/psi_select.py | 95 ++ .../local_processing/standard_type_enum.py | 7 + .../preprocessing/preprocessing_engine.py | 19 + .../preprocessing/processing_context.py | 27 + .../preprocessing/tests/test_preprocessing.py | 672 +++++++++++++ python/ppc_model/secure_lgbm/__init__.py | 0 .../ppc_model/secure_lgbm/monitor/__init__.py | 0 .../ppc_model/secure_lgbm/monitor/callback.py | 83 ++ python/ppc_model/secure_lgbm/monitor/core.py | 144 +++ .../secure_lgbm/monitor/early_stopping.py | 122 +++ .../secure_lgbm/monitor/evaluation_monitor.py | 123 +++ .../secure_lgbm/monitor/feature/__init__.py | 0 .../feature/feature_evaluation_info.py | 90 ++ .../monitor/feature/test/__init__.py | 0 .../test/feature_evalution_info_test.py | 73 ++ .../monitor/train_callback_unittest.py | 105 ++ .../secure_lgbm/secure_lgbm_context.py | 254 +++++ .../secure_lgbm_prediction_engine.py | 38 + .../secure_lgbm_training_engine.py | 40 + python/ppc_model/secure_lgbm/test/__init__.py | 0 .../secure_lgbm/test/test_cipher_packing.py | 61 ++ .../secure_lgbm/test/test_pack_gh.py | 43 + .../secure_lgbm/test/test_save_load_model.py | 92 ++ .../test/test_secure_lgbm_context.py | 77 ++ .../test_secure_lgbm_performance_training.py | 172 ++++ .../test/test_secure_lgbm_training.py | 178 ++++ 
.../secure_lgbm/vertical/__init__.py | 4 + .../secure_lgbm/vertical/active_party.py | 461 +++++++++ .../ppc_model/secure_lgbm/vertical/booster.py | 345 +++++++ .../secure_lgbm/vertical/passive_party.py | 271 ++++++ python/ppc_model/task/__init__.py | 0 python/ppc_model/task/task_manager.py | 183 ++++ python/ppc_model/task/test/__init__.py | 0 .../task/test/task_manager_unittest.py | 155 +++ python/ppc_model/tools/start.sh | 39 + python/ppc_model/tools/stop.sh | 19 + python/ppc_model_gateway/__init__.py | 0 python/ppc_model_gateway/clients/__init__.py | 0 .../clients/client_manager.py | 57 ++ .../conf/application-sample.yml | 19 + python/ppc_model_gateway/conf/logging.conf | 40 + python/ppc_model_gateway/config.py | 60 ++ .../ppc_model_gateway/endpoints/__init__.py | 0 .../endpoints/node_to_partner.py | 36 + .../endpoints/partner_to_node.py | 34 + .../endpoints/response_builder.py | 8 + .../ppc_model_gateway_app.py | 90 ++ python/ppc_model_gateway/test/__init__.py | 0 python/ppc_model_gateway/test/client.py | 41 + python/ppc_model_gateway/test/server.py | 36 + python/ppc_model_gateway/tools/gen_cert.sh | 70 ++ python/ppc_model_gateway/tools/start.sh | 35 + python/ppc_model_gateway/tools/stop.sh | 11 + python/requirements.txt | 65 ++ python/tools/fake_id_data.py | 187 ++++ python/tools/fake_ml_train_data.py | 170 ++++ python/tools/requirements.txt | 1 + 194 files changed, 16071 insertions(+) create mode 100644 python/__init__.py create mode 100644 python/ppc_common/__init__.py create mode 100644 python/ppc_common/application-sample.yml create mode 100644 python/ppc_common/application.yml create mode 100644 python/ppc_common/config.py create mode 100644 python/ppc_common/db_models/__init__.py create mode 100644 python/ppc_common/db_models/file_object_meta.py create mode 100644 python/ppc_common/db_models/file_path.py create mode 100644 python/ppc_common/db_models/job_unit_record.py create mode 100644 python/ppc_common/deps_services/__init__.py create mode 100644 
python/ppc_common/deps_services/file_object.py create mode 100644 python/ppc_common/deps_services/hdfs_storage.py create mode 100644 python/ppc_common/deps_services/mysql_storage.py create mode 100644 python/ppc_common/deps_services/serialize_type.py create mode 100644 python/ppc_common/deps_services/sharding_file_object.py create mode 100644 python/ppc_common/deps_services/sql_storage_api.py create mode 100644 python/ppc_common/deps_services/storage_api.py create mode 100644 python/ppc_common/deps_services/storage_loader.py create mode 100644 python/ppc_common/deps_services/tests/mysql_storage_test.py create mode 100644 python/ppc_common/deps_services/tests/sharding_file_object_test.py create mode 100644 python/ppc_common/ppc_async_executor/__init__.py create mode 100644 python/ppc_common/ppc_async_executor/async_executor.py create mode 100644 python/ppc_common/ppc_async_executor/async_subprocess_executor.py create mode 100644 python/ppc_common/ppc_async_executor/async_thread_executor.py create mode 100644 python/ppc_common/ppc_async_executor/test/__init__.py create mode 100644 python/ppc_common/ppc_async_executor/test/async_executor_unittest.py create mode 100644 python/ppc_common/ppc_async_executor/thread_event_manager.py create mode 100644 python/ppc_common/ppc_config/__init__.py create mode 100644 python/ppc_common/ppc_config/file_chunk_config.py create mode 100644 python/ppc_common/ppc_config/sql_storage_config_loader.py create mode 100644 python/ppc_common/ppc_crypto/__init__.py create mode 100644 python/ppc_common/ppc_crypto/crypto_utils.py create mode 100644 python/ppc_common/ppc_crypto/ihc_cipher.py create mode 100644 python/ppc_common/ppc_crypto/ihc_codec.py create mode 100644 python/ppc_common/ppc_crypto/paillier_cipher.py create mode 100644 python/ppc_common/ppc_crypto/paillier_codec.py create mode 100644 python/ppc_common/ppc_crypto/phe_cipher.py create mode 100644 python/ppc_common/ppc_crypto/phe_factory.py create mode 100644 
python/ppc_common/ppc_crypto/test/__init__.py create mode 100644 python/ppc_common/ppc_crypto/test/phe_unittest.py create mode 100644 python/ppc_common/ppc_dataset/__init__.py create mode 100644 python/ppc_common/ppc_dataset/dataset_helper.py create mode 100644 python/ppc_common/ppc_dataset/dataset_helper_factory.py create mode 100644 python/ppc_common/ppc_initialize/__init__.py create mode 100644 python/ppc_common/ppc_initialize/dataset_handler_initialize.py create mode 100644 python/ppc_common/ppc_initialize/tests/dataset_initializer_test.py create mode 100644 python/ppc_common/ppc_ml/__init__.py create mode 100644 python/ppc_common/ppc_ml/feature/__init__.py create mode 100644 python/ppc_common/ppc_ml/feature/feature_importance.py create mode 100644 python/ppc_common/ppc_ml/feature/tests/feature_importance_test.py create mode 100644 python/ppc_common/ppc_ml/model/algorithm_info.py create mode 100644 python/ppc_common/ppc_mock/__init__.py create mode 100644 python/ppc_common/ppc_mock/mock_objects.py create mode 100644 python/ppc_common/ppc_protos/__init__.py create mode 100644 python/ppc_common/ppc_protos/generated/__init__.py create mode 100644 python/ppc_common/ppc_protos/generated/ppc_model_pb2.py create mode 100644 python/ppc_common/ppc_protos/generated/ppc_model_pb2_grpc.py create mode 100644 python/ppc_common/ppc_protos/generated/ppc_pb2.py create mode 100644 python/ppc_common/ppc_protos/ppc.proto create mode 100644 python/ppc_common/ppc_protos/ppc_model.proto create mode 100644 python/ppc_common/ppc_utils/__init__.py create mode 100644 python/ppc_common/ppc_utils/anonymous_search.py create mode 100644 python/ppc_common/ppc_utils/audit_utils.py create mode 100644 python/ppc_common/ppc_utils/cem_utils.py create mode 100644 python/ppc_common/ppc_utils/common_func.py create mode 100644 python/ppc_common/ppc_utils/exception.py create mode 100644 python/ppc_common/ppc_utils/http_utils.py create mode 100644 python/ppc_common/ppc_utils/path.py create mode 100644 
python/ppc_common/ppc_utils/permission.py create mode 100644 python/ppc_common/ppc_utils/plot_utils.py create mode 100644 python/ppc_common/ppc_utils/ppc_model_config_parser.py create mode 100644 python/ppc_common/ppc_utils/ppc_model_config_parser_proxy.py create mode 100644 python/ppc_common/ppc_utils/tests/thread_safe_list_test.py create mode 100644 python/ppc_common/ppc_utils/tests/utils_test.py create mode 100644 python/ppc_common/ppc_utils/thread_safe_list.py create mode 100644 python/ppc_common/ppc_utils/utils.py create mode 100644 python/ppc_model/__init__.py create mode 100644 python/ppc_model/common/__init__.py create mode 100644 python/ppc_model/common/base_context.py create mode 100644 python/ppc_model/common/context.py create mode 100644 python/ppc_model/common/global_context.py create mode 100644 python/ppc_model/common/initializer.py create mode 100644 python/ppc_model/common/mock/__init__.py create mode 100644 python/ppc_model/common/mock/rpc_client_mock.py create mode 100644 python/ppc_model/common/model_result.py create mode 100644 python/ppc_model/common/model_setting.py create mode 100644 python/ppc_model/common/protocol.py create mode 100644 python/ppc_model/conf/application-sample.yml create mode 100644 python/ppc_model/conf/logging.conf create mode 100644 python/ppc_model/datasets/__init__.py create mode 100644 python/ppc_model/datasets/data_reduction/__init__.py create mode 100644 python/ppc_model/datasets/data_reduction/feature_selection.py create mode 100644 python/ppc_model/datasets/data_reduction/sampling.py create mode 100644 python/ppc_model/datasets/data_reduction/test/__init__.py create mode 100644 python/ppc_model/datasets/data_reduction/test/test_data_reduction.py create mode 100644 python/ppc_model/datasets/dataset.py create mode 100644 python/ppc_model/datasets/feature_binning/__init__.py create mode 100644 python/ppc_model/datasets/feature_binning/feature_binning.py create mode 100644 
python/ppc_model/datasets/feature_binning/test/__init__.py create mode 100644 python/ppc_model/datasets/feature_binning/test/test_feature_binning.py create mode 100644 python/ppc_model/datasets/test/__init__.py create mode 100644 python/ppc_model/datasets/test/test_dataset.py create mode 100644 python/ppc_model/feature_engineering/__init__.py create mode 100644 python/ppc_model/feature_engineering/feature_engineering_context.py create mode 100644 python/ppc_model/feature_engineering/feature_engineering_engine.py create mode 100644 python/ppc_model/feature_engineering/test/__init__.py create mode 100644 python/ppc_model/feature_engineering/test/feature_engineering_unittest.py create mode 100644 python/ppc_model/feature_engineering/vertical/__init__.py create mode 100644 python/ppc_model/feature_engineering/vertical/active_party.py create mode 100644 python/ppc_model/feature_engineering/vertical/passive_party.py create mode 100644 python/ppc_model/feature_engineering/vertical/utils.py create mode 100644 python/ppc_model/interface/__init__.py create mode 100644 python/ppc_model/interface/model_base.py create mode 100644 python/ppc_model/interface/rpc_client.py create mode 100644 python/ppc_model/interface/task_engine.py create mode 100644 python/ppc_model/metrics/__init__.py create mode 100644 python/ppc_model/metrics/evaluation.py create mode 100644 python/ppc_model/metrics/loss.py create mode 100644 python/ppc_model/metrics/model_plot.py create mode 100644 python/ppc_model/metrics/test/__init__.py create mode 100644 python/ppc_model/metrics/test/test_metrics.py create mode 100644 python/ppc_model/model_result/task_result_handler.py create mode 100644 python/ppc_model/network/__init__.py create mode 100644 python/ppc_model/network/grpc/__init__.py create mode 100644 python/ppc_model/network/grpc/grpc_client.py create mode 100644 python/ppc_model/network/grpc/grpc_server.py create mode 100644 python/ppc_model/network/http/__init__.py create mode 100644 
python/ppc_model/network/http/body_schema.py create mode 100644 python/ppc_model/network/http/model_controller.py create mode 100644 python/ppc_model/network/http/restx.py create mode 100644 python/ppc_model/network/stub.py create mode 100644 python/ppc_model/network/test/__init__.py create mode 100644 python/ppc_model/network/test/stub_unittest.py create mode 100644 python/ppc_model/ppc_model_app.py create mode 100644 python/ppc_model/preprocessing/__init__.py create mode 100644 python/ppc_model/preprocessing/local_processing/local_processing_party.py create mode 100644 python/ppc_model/preprocessing/local_processing/preprocessing.py create mode 100644 python/ppc_model/preprocessing/local_processing/psi_select.py create mode 100644 python/ppc_model/preprocessing/local_processing/standard_type_enum.py create mode 100644 python/ppc_model/preprocessing/preprocessing_engine.py create mode 100644 python/ppc_model/preprocessing/processing_context.py create mode 100644 python/ppc_model/preprocessing/tests/test_preprocessing.py create mode 100644 python/ppc_model/secure_lgbm/__init__.py create mode 100644 python/ppc_model/secure_lgbm/monitor/__init__.py create mode 100644 python/ppc_model/secure_lgbm/monitor/callback.py create mode 100644 python/ppc_model/secure_lgbm/monitor/core.py create mode 100644 python/ppc_model/secure_lgbm/monitor/early_stopping.py create mode 100644 python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py create mode 100644 python/ppc_model/secure_lgbm/monitor/feature/__init__.py create mode 100644 python/ppc_model/secure_lgbm/monitor/feature/feature_evaluation_info.py create mode 100644 python/ppc_model/secure_lgbm/monitor/feature/test/__init__.py create mode 100644 python/ppc_model/secure_lgbm/monitor/feature/test/feature_evalution_info_test.py create mode 100644 python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py create mode 100644 python/ppc_model/secure_lgbm/secure_lgbm_context.py create mode 100644 
python/ppc_model/secure_lgbm/secure_lgbm_prediction_engine.py create mode 100644 python/ppc_model/secure_lgbm/secure_lgbm_training_engine.py create mode 100644 python/ppc_model/secure_lgbm/test/__init__.py create mode 100644 python/ppc_model/secure_lgbm/test/test_cipher_packing.py create mode 100644 python/ppc_model/secure_lgbm/test/test_pack_gh.py create mode 100644 python/ppc_model/secure_lgbm/test/test_save_load_model.py create mode 100644 python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py create mode 100644 python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py create mode 100644 python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py create mode 100644 python/ppc_model/secure_lgbm/vertical/__init__.py create mode 100644 python/ppc_model/secure_lgbm/vertical/active_party.py create mode 100644 python/ppc_model/secure_lgbm/vertical/booster.py create mode 100644 python/ppc_model/secure_lgbm/vertical/passive_party.py create mode 100644 python/ppc_model/task/__init__.py create mode 100644 python/ppc_model/task/task_manager.py create mode 100644 python/ppc_model/task/test/__init__.py create mode 100644 python/ppc_model/task/test/task_manager_unittest.py create mode 100644 python/ppc_model/tools/start.sh create mode 100644 python/ppc_model/tools/stop.sh create mode 100644 python/ppc_model_gateway/__init__.py create mode 100644 python/ppc_model_gateway/clients/__init__.py create mode 100644 python/ppc_model_gateway/clients/client_manager.py create mode 100644 python/ppc_model_gateway/conf/application-sample.yml create mode 100644 python/ppc_model_gateway/conf/logging.conf create mode 100644 python/ppc_model_gateway/config.py create mode 100644 python/ppc_model_gateway/endpoints/__init__.py create mode 100644 python/ppc_model_gateway/endpoints/node_to_partner.py create mode 100644 python/ppc_model_gateway/endpoints/partner_to_node.py create mode 100644 python/ppc_model_gateway/endpoints/response_builder.py create mode 100644 
python/ppc_model_gateway/ppc_model_gateway_app.py create mode 100644 python/ppc_model_gateway/test/__init__.py create mode 100644 python/ppc_model_gateway/test/client.py create mode 100644 python/ppc_model_gateway/test/server.py create mode 100644 python/ppc_model_gateway/tools/gen_cert.sh create mode 100644 python/ppc_model_gateway/tools/start.sh create mode 100644 python/ppc_model_gateway/tools/stop.sh create mode 100644 python/requirements.txt create mode 100644 python/tools/fake_id_data.py create mode 100644 python/tools/fake_ml_train_data.py create mode 100644 python/tools/requirements.txt diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/__init__.py b/python/ppc_common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/application-sample.yml b/python/ppc_common/application-sample.yml new file mode 100644 index 00000000..b30c0146 --- /dev/null +++ b/python/ppc_common/application-sample.yml @@ -0,0 +1,6 @@ +# mysql or dm +DB_TYPE: "mysql" +SQLALCHEMY_DATABASE_URI: "mysql://root:12345678@127.0.0.1:3306/ppc?autocommit=true&charset=utf8" +# SQLALCHEMY_DATABASE_URI: "dm+dmPython://ppcv16:ppc12345678@127.0.0.1:5236" + +MPC_BIT_LENGTH: [@IDC_PPCS_COMMON_MPC_BIT] diff --git a/python/ppc_common/application.yml b/python/ppc_common/application.yml new file mode 100644 index 00000000..0ada951d --- /dev/null +++ b/python/ppc_common/application.yml @@ -0,0 +1,6 @@ +# mysql or dm +DB_TYPE: "mysql" +SQLALCHEMY_DATABASE_URI: "mysql://root:12345678@127.0.0.1:3306/ppc?autocommit=true&charset=utf8" +# SQLALCHEMY_DATABASE_URI: "dm+dmPython://ppcv16:ppc12345678@127.0.0.1:5236" + +MPC_BIT_LENGTH: 1000 diff --git a/python/ppc_common/config.py b/python/ppc_common/config.py new file mode 100644 index 00000000..e8e565ae --- /dev/null +++ b/python/ppc_common/config.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +import os +import yaml + + +dirName, _ = 
os.path.split(os.path.abspath(__file__)) +config_path = '{}/application.yml'.format(dirName) + +CONFIG_DATA = {} + + +def read_config(): + global CONFIG_DATA + with open(config_path, 'rb') as f: + CONFIG_DATA = yaml.safe_load(f.read()) + + +read_config() diff --git a/python/ppc_common/db_models/__init__.py b/python/ppc_common/db_models/__init__.py new file mode 100644 index 00000000..591b4182 --- /dev/null +++ b/python/ppc_common/db_models/__init__.py @@ -0,0 +1,6 @@ +from flask_sqlalchemy import SQLAlchemy + + +db = SQLAlchemy() + +# __all__ = ['computation_provider','data_provider','job_computation_queue','job_data_queue', 'job_result'] diff --git a/python/ppc_common/db_models/file_object_meta.py b/python/ppc_common/db_models/file_object_meta.py new file mode 100644 index 00000000..ef48ef9a --- /dev/null +++ b/python/ppc_common/db_models/file_object_meta.py @@ -0,0 +1,12 @@ +from ppc_common.db_models import db +from sqlalchemy import text + + +class FileObjectMeta(db.Model): + __tablename__ = 't_file_object' + file_path = db.Column(db.String(255), primary_key=True) + file_count = db.Column(db.Integer) + create_time = db.Column(db.TIMESTAMP( + True), nullable=False, server_default=text('NOW()')) + last_update_time = db.Column(db.TIMESTAMP(True), nullable=False, server_default=text( + 'CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')) diff --git a/python/ppc_common/db_models/file_path.py b/python/ppc_common/db_models/file_path.py new file mode 100644 index 00000000..4c626750 --- /dev/null +++ b/python/ppc_common/db_models/file_path.py @@ -0,0 +1,10 @@ +from ppc_common.db_models import db + + +class FilePathRecord(db.Model): + __tablename__ = 't_file_path' + path = db.Column(db.String(255), primary_key=True) + storage_type = db.Column(db.String(255)) + file_id = db.Column(db.String(255)) + file_hash = db.Column(db.String(255)) + create_time = db.Column(db.BigInteger) diff --git a/python/ppc_common/db_models/job_unit_record.py 
b/python/ppc_common/db_models/job_unit_record.py new file mode 100644 index 00000000..f953ecbb --- /dev/null +++ b/python/ppc_common/db_models/job_unit_record.py @@ -0,0 +1,14 @@ + +from ppc_common.db_models import db + + +class JobUnitRecord(db.Model): + __tablename__ = 't_job_unit' + unit_id = db.Column(db.String(100), primary_key=True) + job_id = db.Column(db.String(255), index=True) + type = db.Column(db.String(255)) + status = db.Column(db.String(255), index=True) + upstream_units = db.Column(db.Text) + inputs_statement = db.Column(db.Text) + outputs = db.Column(db.Text) + update_time = db.Column(db.BigInteger) diff --git a/python/ppc_common/deps_services/__init__.py b/python/ppc_common/deps_services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/deps_services/file_object.py b/python/ppc_common/deps_services/file_object.py new file mode 100644 index 00000000..1e69fef0 --- /dev/null +++ b/python/ppc_common/deps_services/file_object.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +from enum import Enum +from abc import ABC, abstractmethod + + +class SplitMode(Enum): + NONE = 0 # not split + SIZE = 1 # split by size + LINES = 2 # split by lines + + +class FileObject(ABC): + @abstractmethod + def split(self, split_mode, granularity): + """ + split large file into many small files with given granularity + """ + pass + + @abstractmethod + def download(self, enforce_flush=False): + """ + download the files + """ + pass + + @abstractmethod + def upload(self, split_mode, granularity): + """ + upload the files + """ + pass + + @abstractmethod + def delete(self): + """ + delete the files + """ + pass + + @abstractmethod + def existed(self) -> bool: + """ + check the file object exist or not + """ + pass + + @abstractmethod + def rename(self, storage_path: str): + """ + rename the file object + """ + pass + + @abstractmethod + def hit_local_cache(self): + """hit the local cache or not + """ + pass + + @abstractmethod + def 
get_local_path(self): + """get the local path + """ + pass + + @abstractmethod + def get_remote_path(self): + """get the remote path + """ + pass + + @abstractmethod + def get_data(self): + """get data + """ + pass + + @abstractmethod + def save_data(self, data, split_mode, granularity): + pass + + @abstractmethod + def update_data(self, updated_data, split_mode, granularity): + pass diff --git a/python/ppc_common/deps_services/hdfs_storage.py b/python/ppc_common/deps_services/hdfs_storage.py new file mode 100644 index 00000000..370ed2d3 --- /dev/null +++ b/python/ppc_common/deps_services/hdfs_storage.py @@ -0,0 +1,97 @@ +import os +from typing import AnyStr + +from hdfs.client import InsecureClient +from ppc_common.ppc_utils import common_func +from ppc_common.deps_services.storage_api import StorageApi, StorageType + +from ppc_common.ppc_utils import utils + + +class HdfsStorage(StorageApi): + + DEFAULT_HDFS_USER = "ppc" + DEFAULT_HDFS_USER_PATH = "/user/" + + def __init__(self, endpoint, hdfs_user, hdfs_home=None): + + # self.client = Client('http://127.0.0.1:9870') + self.endpoint = endpoint + self._user = common_func.get_config_value( + "HDFS_USER", HdfsStorage.DEFAULT_HDFS_USER, hdfs_user, False) + self._hdfs_storage_path = hdfs_home + if hdfs_home is None: + self._hdfs_storage_path = os.path.join( + HdfsStorage.DEFAULT_HDFS_USER_PATH, self._user) + + self.client = InsecureClient(endpoint, user=self._user) + # print(self.client.list('/')) + # print(self.client.list('/user/root/')) + + def get_home_path(self): + return self._hdfs_storage_path + + def storage_type(self): + return StorageType.HDFS + + def download_file(self, hdfs_path, local_file_path, enable_cache=False): + # hit the cache + if enable_cache is True and utils.file_exists(local_file_path): + return + if utils.file_exists(local_file_path): + utils.delete_file(local_file_path) + local_path = os.path.dirname(local_file_path) + if len(local_path) > 0 and not os.path.exists(local_path): + 
os.makedirs(local_path) + self.client.download(os.path.join(self._hdfs_storage_path, + hdfs_path), local_file_path) + return + + def upload_file(self, local_file_path, hdfs_path): + self.make_file_path(hdfs_path) + self.client.upload(os.path.join(self._hdfs_storage_path, hdfs_path), + local_file_path, overwrite=True) + return + + def make_file_path(self, hdfs_path): + hdfs_dir = os.path.dirname(hdfs_path) + if self.client.status(os.path.join(self._hdfs_storage_path, hdfs_dir), strict=False) is None: + self.client.makedirs(os.path.join( + self._hdfs_storage_path, hdfs_dir)) + return + + def delete_file(self, hdfs_path): + self.client.delete(os.path.join( + self._hdfs_storage_path, hdfs_path), recursive=True) + return + + def save_data(self, data: AnyStr, hdfs_path): + self.make_file_path(hdfs_path) + self.client.write(os.path.join(self._hdfs_storage_path, + hdfs_path), data, overwrite=True) + return + + def get_data(self, hdfs_path) -> AnyStr: + with self.client.read(os.path.join(self._hdfs_storage_path, hdfs_path)) as reader: + content = reader.read() + return content + + def mkdir(self, hdfs_dir): + self.client.makedirs(hdfs_dir) + + def file_existed(self, hdfs_path): + if self.client.status(os.path.join(self._hdfs_storage_path, hdfs_path), strict=False) is None: + return False + return True + + def file_rename(self, old_hdfs_path, hdfs_path): + old_path = os.path.join(self._hdfs_storage_path, old_hdfs_path) + new_path = os.path.join(self._hdfs_storage_path, hdfs_path) + # return for the file not exists + if not self.file_existed(old_path): + return + parent_path = os.path.dirname(new_path) + if len(parent_path) > 0 and not self.file_existed(parent_path): + self.mkdir(parent_path) + self.client.rename(old_path, new_path) + return diff --git a/python/ppc_common/deps_services/mysql_storage.py b/python/ppc_common/deps_services/mysql_storage.py new file mode 100644 index 00000000..ada66a53 --- /dev/null +++ b/python/ppc_common/deps_services/mysql_storage.py @@ -0,0 
+1,65 @@ +# -*- coding: utf-8 -*- +from ppc_common.deps_services.sql_storage_api import SQLStorageAPI +from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy import create_engine +from sqlalchemy import delete +from sqlalchemy import text +from contextlib import contextmanager + + +class MySQLStorage(SQLStorageAPI): + def __init__(self, storage_config): + self._engine_url = storage_config.engine_url + self._storage_config = storage_config + connect_args = {} + if storage_config.db_name is not None: + connect_args = {'schema': storage_config.db_name} + self._mysql_engine = create_engine(self._engine_url, pool_recycle=self._storage_config.pool_recycle, + pool_size=self._storage_config.pool_size, max_overflow=self._storage_config.max_overflow, + pool_timeout=self._storage_config.pool_timeout, connect_args=connect_args) + self._session_factory = sessionmaker(bind=self._mysql_engine) + # Note: scoped_session is threadLocal + self._session = scoped_session(self._session_factory) + + @contextmanager + def _get_session(self): + session = self._session() + try: + yield session + session.commit() + except Exception: + session.rollback() + self._session.remove() + raise + finally: + session.close() + + def query(self, object, condition): + """ + query according to the condition + """ + with self._get_session() as session: + return session.query(object).filter(condition) + + def merge(self, record): + """merge the given record to db + Args: + record (Any): the record should been inserted + """ + with self._get_session() as session: + session.merge(record) + + def execute(self, sql: str): + text_sql = text(sql) + with self._get_session() as session: + session.execute(text_sql) + + def delete(self, object, condition): + """delete according to condition + Args: + object (Any): the object + condition (Any): the condition + """ + stmt = delete(object).where(condition) + with self._get_session() as session: + session.execute(stmt) diff --git 
# -*- coding: utf-8 -*-
from enum import Enum


class SerializeType(Enum):
    """Supported on-disk serialization formats for exchanged datasets."""

    # Comma-separated values.
    CSV = 'csv'
    # NOTE(review): the JSON member's wire value is 'gain', not 'json'.
    # It looks like a leftover/typo, but other components may already
    # persist or expect this exact value, so it is preserved verbatim —
    # confirm before changing.
    JSON = 'gain'
class FileMeta:
    """Bookkeeping for one sharded file object: its local path, remote path
    and the number of sub-file chunks it was split into."""

    # Directory (under the local file's directory) holding chunk files.
    SUB_FILE_OBJECT_DIR = "object"

    def __init__(self, local_path, remote_path, file_count, logger):
        """
        :param local_path: local file path; when None a random name derived
                           from the remote path is generated
        :param remote_path: remote storage path; defaults to local_path
        :param file_count: number of chunks (may be None until known)
        :param logger: logger instance
        """
        self.logger = logger
        self._local_path = local_path
        self._remote_path = remote_path
        if self._local_path is None:
            # Random suffix so concurrent downloads of the same remote
            # object never collide on the local name.
            self._local_path = os.path.basename(
                self.remote_path) + "." + str(uuid.uuid4())
            logger.info(f"generate random local path: {self._local_path}")
        self._local_file_dir = os.path.dirname(self._local_path)
        self._local_file_name = os.path.basename(self._local_path)
        if self._remote_path is None:
            self._remote_path = local_path
        # Bug fix: always initialize _file_count. The original assigned it
        # only when file_count was not None, so reading the file_count
        # property before the setter ran raised AttributeError.
        self._file_count = file_count

    @property
    def local_path(self):
        return self._local_path

    @local_path.setter
    def local_path(self, local_path):
        self._local_path = local_path

    @property
    def remote_path(self):
        return self._remote_path

    @remote_path.setter
    def remote_path(self, remote_path):
        self._remote_path = remote_path

    @property
    def file_count(self):
        return self._file_count

    @file_count.setter
    def file_count(self, file_count):
        self._file_count = file_count

    def get_local_sub_file_name(self, file_index):
        """Local path of chunk `file_index` (chunks are named by index)."""
        return os.path.join(self.get_local_sub_files_dir(), str(file_index))

    def get_local_sub_files_dir(self):
        """Local directory caching this file's chunks."""
        return os.path.join(
            self._local_file_dir, FileMeta.SUB_FILE_OBJECT_DIR, self._local_file_name)

    def mk_local_sub_files_dir(self):
        """Create the local chunk cache directory when missing."""
        local_sub_files_dir = self.get_local_sub_files_dir()
        if os.path.exists(local_sub_files_dir):
            return
        os.makedirs(local_sub_files_dir)

    def remove_local_cache(self):
        """Remove the local chunk cache directory when present."""
        local_sub_files_dir = self.get_local_sub_files_dir()
        if os.path.exists(local_sub_files_dir) is False:
            return
        shutil.rmtree(local_sub_files_dir)

    def get_remote_sub_file_name(self, file_index):
        """Remote path of chunk `file_index` (stored under remote_path)."""
        return os.path.join(self._remote_path, str(file_index))
    @property
    def file_meta(self):
        # Expose the FileMeta bookkeeping object (paths + chunk count).
        return self._file_meta

    def split(self, split_mode, granularity):
        """
        Split the local file into small sub-files.

        Returns the list of generated chunk paths, or None when no split is
        performed (NONE mode, missing granularity, or the file is smaller
        than the granularity).
        NOTE(review): a split_mode other than NONE/SIZE/LINES leaves
        file_list unbound and raises UnboundLocalError — confirm callers
        only pass these three modes.
        """
        if split_mode is SplitMode.NONE or granularity is None:
            return None
        local_file_size = os.stat(self._file_meta.local_path).st_size
        if local_file_size < granularity:
            self.logger.info(
                f"upload small files directly without split: {self._file_meta.local_path}")
            return None
        if split_mode is SplitMode.SIZE:
            file_list = self._split_by_size(granularity)
        if split_mode is SplitMode.LINES:
            file_list = self._split_by_lines(granularity)
        self._file_meta.file_count = len(file_list)
        return file_list

    # @profile
    def _split_by_size(self, granularity):
        """Cut the local file into `granularity`-byte chunks under the local
        chunk cache dir; returns the chunk paths."""
        file_index = 0
        start_t = time.time()
        file_list = []
        with open(self._file_meta.local_path, "rb") as fp:
            while True:
                start = time.time()
                fp.seek(file_index * granularity)
                data = fp.read(granularity)
                if not data:
                    break
                self._file_meta.mk_local_sub_files_dir()
                sub_file_path = self._file_meta.get_local_sub_file_name(
                    file_index)
                with open(sub_file_path, "wb") as wfp:
                    wfp.write(data)
                file_list.append(sub_file_path)
                self.logger.info(
                    f"split file by size, file: {self._file_meta.local_path}, sub_file: {sub_file_path}, time cost: {time.time() - start} seconds")
                file_index += 1
        self.logger.info(
            f"split file by size, file: {self._file_meta.local_path}, split granularity: {granularity}, sub file count: {file_index}, time cost: {time.time() - start_t} seconds")
        return (file_list)

    # @profile
    def _split_by_lines(self, granularity):
        """Cut the local file into chunks of whole lines; readlines(hint)
        reads roughly `granularity` bytes of complete lines per chunk."""
        file_index = 0
        self._file_meta.mk_local_sub_files_dir()
        start = time.time()
        file_list = []
        with open(self._file_meta.local_path, "rb") as fp:
            while True:
                start_t = time.time()
                lines = fp.readlines(granularity)
                if not lines:
                    break
                file_name = self._file_meta.get_local_sub_file_name(file_index)
                with open(file_name, "wb") as wfp:
                    wfp.writelines(lines)
                file_list.append(file_name)
                self.logger.debug(
                    f"split file by lines, file: {self._file_meta.local_path}, sub file path: {file_name}, timecost: {time.time() - start_t} seconds")
                file_index += 1
        self.logger.info(
            f"split file by lines, file: {self._file_meta.local_path}, split granularity: {granularity}, sub file count: {file_index}, timecost: {time.time() - start} seconds")
        return (file_list)

    def _local_cache_miss(self):
        # True when any expected local chunk file is missing.
        for file_index in range(0, self._file_meta.file_count):
            sub_file_path = self._file_meta.get_local_sub_file_name(file_index)
            if not os.path.exists(sub_file_path):
                return True
        return False

    def _check_uploaded_files(self):
        # Raise when the local chunk set is incomplete before uploading.
        if self._local_cache_miss():
            error_msg = f"check upload file failed, {self._file_meta.local_path} => {self._file_meta.remote_path}"
            raise PpcException(
                PpcErrorCode.FILE_OBJECT_UPLOAD_CHECK_FAILED.get_code(), error_msg)

    def upload(self, split_mode, granularity):
        """split and upload the file

        NOTE(review): returns None in both branches; the sharding unit test
        treats the return value as the chunk list — confirm intended contract.
        """
        if self.split(split_mode, granularity) is not None:
            self._upload_chunks()
            self.logger.info(
                f"Upload success, remove local file cache: {self._file_meta.get_local_sub_files_dir()}")
            self._file_meta.remove_local_cache()
            return
        # upload directly
        start = time.time()
        self.logger.info(
            f"Upload: {self._file_meta.local_path}=>{self._file_meta.remote_path}")
        self._remote_storage_client.upload_file(
            self._file_meta.local_path, self._file_meta.remote_path)
        self.logger.info(
            f"Upload success: {self._file_meta.local_path}=>{self._file_meta.remote_path}, timecost: {time.time() - start}s")

    def _upload_chunks(self):
        """
        Upload every local chunk to the remote storage, then persist the
        (remote_path, file_count) meta record so download() can reassemble.
        """
        start = time.time()
        self._check_uploaded_files()
        for file_index in range(0, self._file_meta.file_count):
            start_t = time.time()
            local_file_path = self._file_meta.get_local_sub_file_name(
                file_index)
            remote_file_path = self._file_meta.get_remote_sub_file_name(
                file_index)
            self.logger.info(f"upload: {local_file_path}=>{remote_file_path}")
            self._remote_storage_client.upload_file(
                local_file_path, remote_file_path)
            self.logger.info(
                f"upload: {local_file_path}=>{remote_file_path} success, timecost: {time.time() - start_t} seconds")
        self.logger.info(
            f"upload: {self._file_meta.local_path}=>{self._file_meta.remote_path} success, timecost: {time.time() - start} seconds, begin to store the meta information")
        start = time.time()
        record = FileObjectMeta(
            file_path=self._file_meta.remote_path, file_count=self._file_meta.file_count)
        self._sql_storage.merge(record)
        self.logger.info(
            f"store meta for {self._file_meta.remote_path} success, timecost: {time.time() - start} seconds")

    def _fetch_file_meta(self):
        """Load the sharding meta record for remote_path; returns False when
        absent (non-sharded or missing file), True after setting file_count."""
        file_meta_info = self._sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path == self._file_meta.remote_path)
        # the file not exists
        if file_meta_info is None or file_meta_info.count() == 0:
            return False
        self.logger.info(
            f"fetch file meta information: {self._remote_storage_client.storage_type()}:{self._file_meta.remote_path}=>{self._file_meta.local_path}, file count: {file_meta_info.first().file_count}")
        self._file_meta.file_count = file_meta_info.first().file_count
        return True

    def download(self, enforce_flush=False):
        """
        Download the file, reassembling chunked objects into local_path.

        :param enforce_flush: when True, re-download even when the local
                              file already exists
        :raises PpcException: when the object exists in neither the meta
                              store nor the remote storage
        """
        if enforce_flush is not True and self.hit_local_cache():
            return
        ret = self._fetch_file_meta()
        start = time.time()
        # the remote file not exists
        if ret is False:
            # no sharding case
            if self._remote_storage_client.file_existed(self._file_meta.remote_path):
                self.logger.info(
                    f"Download file: {self._file_meta.remote_path}=>{self._file_meta.local_path}")
                self._remote_storage_client.download_file(
                    self._file_meta.remote_path, self._file_meta.local_path)
                self.logger.info(
                    f"Download file success: {self._file_meta.remote_path}=>{self._file_meta.local_path}, timecost: {time.time() - start}")
                return
            error_msg = f"Download file from {self._remote_storage_client.storage_type()}: {self._file_meta.remote_path} failed for the file not exists!"
            self.logger.error(error_msg)
            raise PpcException(
                PpcErrorCode.FILE_OBJECT_NOT_EXISTS.get_code(), error_msg)
        # remove the local file
        if os.path.exists(self._file_meta.local_path):
            self.logger.info(
                f"Download: remove the existed local file {self._file_meta.local_path}")
            os.remove(self._file_meta.local_path)
        # download from the remote storage client
        start = time.time()
        # merge the file
        offset = 0
        try:
            with open(self._file_meta.local_path, "wb") as fp:
                for file_index in range(0, self._file_meta.file_count):
                    start_t = time.time()
                    fp.seek(offset)
                    remote_file_path = self._file_meta.get_remote_sub_file_name(
                        file_index)
                    local_file_path = self._file_meta.get_local_sub_file_name(
                        file_index)
                    # NOTE(review): relies on the storage client creating the
                    # local chunk directory on download — confirm for every
                    # StorageApi implementation.
                    self._remote_storage_client.download_file(
                        remote_file_path, local_file_path)
                    with open(local_file_path, "rb") as f:
                        fp.write(f.read())
                    offset += os.stat(local_file_path).st_size
                    self.logger.info(
                        f"Download: {self._remote_storage_client.storage_type()}:{remote_file_path}=>{local_file_path}, timecost: {time.time() - start_t} seconds")
            # remove the local cache
            self._file_meta.remove_local_cache()
            self.logger.info(
                f"Download: {self._remote_storage_client.storage_type()}:{self._file_meta.remote_path}=>{self._file_meta.local_path} success, file count: {self._file_meta.file_count}, timecost: {time.time() - start} seconds")
        except Exception as e:
            self.logger.warn(
                f"Download: {self._remote_storage_client.storage_type()}:{self._file_meta.remote_path}=>{self._file_meta.local_path} failed, {e}")
            self._remove_local_file()
            raise e

    def _remove_local_file(self):
        # Best-effort removal of the (possibly partial) merged local file.
        if not os.path.exists(self._file_meta.local_path):
            return
        os.remove(self._file_meta.local_path)

    def delete(self):
        """
        Delete the remote object: every chunk, the containing remote dir,
        and the sharding meta record (or the single file when non-sharded).
        """
        if self._fetch_file_meta() is False:
            if self._remote_storage_client.file_existed(self._file_meta.remote_path):
                self._remote_storage_client.delete_file(
                    self._file_meta.remote_path)
                self.logger.info(
                    f"Delete file {self._file_meta.remote_path} success")
                return
            self.logger.info(
                f"Delete nothing for file {self._file_meta.remote_path} not exists")
            return
        start = time.time()
        for file_index in range(0, self._file_meta.file_count):
            start_t = time.time()
            remote_file_path = self._file_meta.get_remote_sub_file_name(
                file_index)
            self._remote_storage_client.delete_file(remote_file_path)
            self.logger.info(
                f"Delete: {self._remote_storage_client.storage_type()}:{remote_file_path}, timecost: {time.time() - start_t} seconds")
        self._delete_remote_dir()
        # delete the record
        self._sql_storage.delete(
            FileObjectMeta, FileObjectMeta.file_path == self._file_meta.remote_path)
        self.logger.info(
            f"Delete: {self._remote_storage_client.storage_type()}:{self._file_meta.remote_path} success, file count: {self._file_meta.file_count}, timecost: {time.time() - start} seconds")

    def _delete_remote_dir(self):
        # Remove the remote directory that held the chunks, when present.
        if self._remote_storage_client.file_existed(self._file_meta.remote_path):
            self._remote_storage_client.delete_file(
                self._file_meta.remote_path)

    def existed(self) -> bool:
        """
        check the file object exist or not (meta record first, then the
        remote storage itself for non-sharded objects)
        """
        if self._fetch_file_meta() is True:
            return True
        return self._remote_storage_client.file_existed(self._file_meta.remote_path)

    def hit_local_cache(self):
        """hit the local cache or not

        :raises PpcException: when the object exists nowhere at all
        """
        if not self.existed():
            error_msg = f"The file object: local:{self._file_meta.local_path}, remote: {self._file_meta.remote_path} not exists!"
            self.logger.info(f"find local cache failed: {error_msg}")
            raise PpcException(
                PpcErrorCode.FILE_OBJECT_NOT_EXISTS.get_code(), error_msg)
        return os.path.exists(self._file_meta.local_path)

    def rename(self, storage_path: str):
        """
        Rename the remote object to `storage_path`: chunk-by-chunk for
        sharded objects (meta record is re-keyed), single rename otherwise.
        """
        ret = self._fetch_file_meta()
        start = time.time()
        # the remote file not exists
        if ret is False:
            # no sharding case
            if self._remote_storage_client.file_existed(self._file_meta.remote_path):
                self.logger.info(
                    f"Rename file: {self._file_meta.remote_path}=>{storage_path}")
                self._remote_storage_client.file_rename(
                    self._file_meta.remote_path, storage_path)
                self.logger.info(
                    f"Rename file success: {self._file_meta.remote_path}=>{storage_path}, timecost: {time.time() - start} seconds")
                return
            error_msg = f"Rename file {self._remote_storage_client.storage_type()} => {self._file_meta.remote_path} failed for the file not exists!"
            self.logger.error(error_msg)
            raise PpcException(
                PpcErrorCode.FILE_OBJECT_NOT_EXISTS.get_code(), error_msg)
        # rename from the remote storage client
        start = time.time()
        new_file_meta = copy.deepcopy(self._file_meta)
        new_file_meta.remote_path = storage_path
        for file_index in range(0, self._file_meta.file_count):
            start_t = time.time()
            remote_file_path = self._file_meta.get_remote_sub_file_name(
                file_index)
            new_file_path = new_file_meta.get_remote_sub_file_name(
                file_index)
            self._remote_storage_client.file_rename(
                remote_file_path, new_file_path)
            self.logger.info(
                f"Rename: {self._remote_storage_client.storage_type()}:{remote_file_path}=>{new_file_path}, timecost: {time.time() - start_t} seconds")
        # delete the old record
        self._delete_remote_dir()
        self._sql_storage.delete(
            FileObjectMeta, FileObjectMeta.file_path == self._file_meta.remote_path)
        self._file_meta.remote_path = storage_path
        # update the meta
        # NOTE(review): remote_path was just reassigned, so both sides of the
        # "=>" below print the new path.
        self.logger.info(
            f"Rename: {self._remote_storage_client.storage_type()}:{self._file_meta.remote_path}=>{storage_path} success, file count: {self._file_meta.file_count}, timecost: {time.time() - start} seconds")
        start = time.time()
        record = FileObjectMeta(file_path=storage_path,
                                file_count=self._file_meta.file_count)
        self._sql_storage.merge(record)
        self.logger.info(
            f"Rename: store meta for {self._file_meta.remote_path} => {storage_path} success, timecost: {time.time() - start} seconds")

    def get_data(self):
        """Download the object and return its content as utf-8 bytes; the
        local copy is removed afterwards."""
        self.download()
        with open(self._file_meta.local_path, 'r') as file:
            file_content = file.read()
        file_bytes = file_content.encode('utf-8')
        os.remove(self._file_meta.local_path)
        return file_bytes

    def save_data(self, data, split_mode, granularity):
        """Write `data` (bytes) to the local path, upload it (optionally
        sharded), then remove the local copy."""
        with open(self._file_meta.local_path, 'wb') as file:
            file.write(data)
        self.upload(split_mode, granularity)
        os.remove(self._file_meta.local_path)
        return

    def update_data(self, updated_data, split_mode, granularity):
        """Replace the remote object with `updated_data`; returns the length
        of the provided data."""
        start = time.time()
        # try to remove the remote file
        if self.existed():
            self.logger.info(
                f"UpdateData: remove existed remote file: {self._file_meta.remote_path}")
            self.delete()
        # update the data
        self.logger.info(
            f"UpdateData: update the remote file: {self._file_meta.remote_path}")
        self.save_data(data=updated_data if type(updated_data) is bytes else bytes(updated_data, "utf-8"),
                       split_mode=split_mode, granularity=granularity)
        self.logger.info(
            f"UpdateData success: update the remote file: {self._file_meta.remote_path}, timecost: {time.time() - start}")
        return len(updated_data)

    def get_local_path(self):
        """get the local path
        """
        return self._file_meta.local_path

    def get_remote_path(self):
        """get the remote path
        """
        return self._file_meta.remote_path
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from ppc_common.ppc_utils import common_func


class SQLStorageConfig:
    """
    configuration for sql storage

    Pool settings fall back to the class defaults when the corresponding
    argument is None (resolved via common_func.get_config_value).
    Bug fix: this plain config class no longer inherits ABC (it declares no
    abstract methods), while SQLStorageAPI below now derives from ABC so its
    @abstractmethod decorators are actually enforced.
    """
    DEFAULT_RECYCLE = 7200
    DEFAULT_POOL_SIZE = 16
    DEFAULT_MAX_OVERFLOW = 8
    DEFAULT_POOL_TIMEOUT = 30

    def __init__(self, url, pool_recycle=None, pool_size=None, max_overflow=None, pool_timeout=None, db_type="mysql", db_name=None):
        """
        :param url: SQLAlchemy engine url
        :param pool_recycle: seconds before a pooled connection is recycled
        :param pool_size: connection pool size
        :param max_overflow: extra connections allowed beyond pool_size
        :param pool_timeout: seconds to wait for a pooled connection
        :param db_type: backend type tag, defaults to "mysql"
        :param db_name: optional database/schema name
        """
        self.pool_recycle = common_func.get_config_value("pool_recycle",
                                                         SQLStorageConfig.DEFAULT_RECYCLE, pool_recycle, False)
        self.pool_size = common_func.get_config_value("pool_size",
                                                      SQLStorageConfig.DEFAULT_POOL_SIZE, pool_size, False)
        self.max_overflow = common_func.get_config_value("max_overflow",
                                                         SQLStorageConfig.DEFAULT_MAX_OVERFLOW, max_overflow, False)
        self.pool_timeout = common_func.get_config_value("pool_timeout",
                                                         SQLStorageConfig.DEFAULT_POOL_TIMEOUT, pool_timeout, False)
        self.db_type = db_type
        self.engine_url = url
        self.db_name = db_name


class SQLStorageAPI(ABC):
    """Abstract interface for SQL-backed storage implementations."""

    @abstractmethod
    def query(self, object, condition):
        """
        query the result matching `condition` for mapped class `object`
        """
        pass

    @abstractmethod
    def merge(self, record):
        """insert-or-update the record into db

        Args:
            record (Any): the record need to be inserted
        """
        pass

    @abstractmethod
    def execute(self, sql: str):
        """Execute a raw SQL statement."""
        pass

    @abstractmethod
    def delete(self, object, condition):
        """delete according to condition
        Args:
            object (Any): the object
            condition (Any): the condition
        """
        pass
    @abstractmethod
    def download_file(self, storage_path: str, local_file_path: str, enable_cache=False):
        """Download `storage_path` from the backend into `local_file_path`."""
        pass

    @abstractmethod
    def upload_file(self, local_file_path: str, storage_path: str):
        """Upload a local file to `storage_path`."""
        pass

    @abstractmethod
    def make_file_path(self, storage_path: str):
        """Ensure the parent directory of `storage_path` exists."""
        pass

    @abstractmethod
    def delete_file(self, storage_path: str):
        """Delete `storage_path` from the backend."""
        pass

    @abstractmethod
    def save_data(self, data: AnyStr, storage_path: str):
        """Write raw data to `storage_path`."""
        pass

    @abstractmethod
    def get_data(self, storage_path: str) -> AnyStr:
        """Read and return the content of `storage_path`."""
        pass

    @abstractmethod
    def mkdir(self, storage_path: str):
        """Create `storage_path` as a directory."""
        pass

    @abstractmethod
    def file_existed(self, storage_path: str) -> bool:
        """Return True when `storage_path` exists."""
        pass

    @abstractmethod
    def file_rename(self, old_storage_path: str, storage_path: str):
        """Rename `old_storage_path` to `storage_path`."""
        pass

    @abstractmethod
    def storage_type(self):
        """Return the StorageType of this backend."""
        pass

    @abstractmethod
    def get_home_path(self):
        # Default: empty home path.
        return ""


def load(config: dict, logger):
    """Build a StorageApi implementation from the service config dict.

    Currently only HDFS is supported; HDFS_USER and HDFS_HOME are optional.
    :raises Exception: for unsupported STORAGE_TYPE values
    """
    if config['STORAGE_TYPE'] == StorageType.HDFS.value:
        hdfs_user = common_func.get_config_value(
            'HDFS_USER', None, config, False)
        hdfs_home = common_func.get_config_value(
            "HDFS_HOME", None, config, False)
        return HdfsStorage(config['HDFS_ENDPOINT'], hdfs_user, hdfs_home)
    else:
        # NOTE(review): consider including config['STORAGE_TYPE'] in this
        # message for easier debugging.
        raise Exception('unsupported storage type')
class MySQLStorageWrapper:
    """Test helper bundling a MySQLStorage instance plus a reusable
    insert/query/update/delete scenario over FileObjectMeta records."""

    def __init__(self):
        # use the default config
        self.engine_url = "mysql://root:12345678@127.0.0.1:3306/ppc?autocommit=true&charset=utf8mb4"
        self.sql_storage_config = SQLStorageConfig(url=self.engine_url)
        self.sql_storage = MySQLStorage(self.sql_storage_config)

    def single_thread_test(self, thread_name, path, ut_obj):
        """Run the full CRUD scenario with `path` as the record prefix.

        :param thread_name: label used only for progress prints
        :param path: unique file_path prefix for this runner's records
        :param ut_obj: the unittest.TestCase used for assertions
        """
        print(f"# begin test for thread: {thread_name}, path: {path}")
        record_num = 100
        # insert records
        file_count_set = set()
        for i in range(0, record_num):
            tmp_path = path + "_" + str(i)
            file_object_meta = FileObjectMeta(file_path=tmp_path, file_count=i)
            self.sql_storage.merge(file_object_meta)
            file_count_set.add(i)
        # query records
        result_list = self.sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path.startswith(path))
        ut_obj.assertEqual(result_list.count(), record_num)
        # check the file_count
        for item in result_list:
            ut_obj.assertTrue(item.file_count in file_count_set)
            # check the path
            expected_path = path + "_" + str(item.file_count)
            ut_obj.assertEqual(item.file_path, expected_path)
            file_count_set.remove(item.file_count)
        ut_obj.assertEqual(len(file_count_set), 0)
        # update the file_count
        delta = 100
        for i in range(0, record_num):
            tmp_path = path + "_" + str(i)
            file_count = i + delta
            file_object_meta = FileObjectMeta(
                file_path=tmp_path, file_count=file_count)
            self.sql_storage.merge(file_object_meta)
            file_count_set.add(file_count)
        # query and check
        result_list = self.sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path.startswith(path))
        ut_obj.assertEqual(result_list.count(), record_num)
        for item in result_list:
            ut_obj.assertTrue(item.file_count in file_count_set)
            # check the path
            expected_path = path + "_" + str(item.file_count - delta)
            ut_obj.assertEqual(item.file_path, expected_path)
            file_count_set.remove(item.file_count)
        ut_obj.assertEqual(len(file_count_set), 0)
        # delete test
        tmp_path = path + "_0"
        self.sql_storage.delete(
            FileObjectMeta, FileObjectMeta.file_path == tmp_path)
        result_list = self.sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path == tmp_path)
        ut_obj.assertEqual(result_list.count(), 0)
        result_list = self.sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path.startswith(path))
        ut_obj.assertEqual(result_list.count(), record_num - 1)
        # delete all data throw raw sql
        self.sql_storage.execute(
            f"delete from t_file_object where file_path like '{path}%'")
        result_list = self.sql_storage.query(
            FileObjectMeta, FileObjectMeta.file_path.startswith(path))
        ut_obj.assertEqual(result_list.count(), 0)
        print(f"# test for thread: {thread_name}, path: {path} success")


class TestMySQLStorage(unittest.TestCase):
    """Integration tests (require a local MySQL at 127.0.0.1:3306)."""

    def test_single_thread(self):
        wrapper = MySQLStorageWrapper()
        path = "a/b/c/test_single_thread.csv"
        # Bug fix: the arguments were passed as (path, "single_thread", self),
        # swapping the thread label and the db path prefix relative to the
        # signature single_thread_test(thread_name, path, ut_obj).
        wrapper.single_thread_test("single_thread", path, self)

    def test_multi_thread(self):
        loops = 5
        for j in range(loops):
            thread_list = []
            thread_num = 20
            path = "a/b/c/test_multi_thread.csv"
            wrapper = MySQLStorageWrapper()
            for i in range(thread_num):
                thread_name = "job_" + str(i)
                tmp_path = "thread_" + str(i) + "_" + path
                t = threading.Thread(target=wrapper.single_thread_test, name=thread_name, args=(
                    thread_name, tmp_path, self,))
                thread_list.append(t)
            for t in thread_list:
                t.start()
            for t in thread_list:
                t.join()
class ShardingFileObjectTestWrapper:
    """Test helper wiring a ShardingFileObject to a local MySQL meta store
    and a local HDFS endpoint (integration environment required)."""

    def __init__(self, ut_obj, local_path, remote_path):
        self.ut_obj = ut_obj
        self.engine_url = "mysql://root:12345678@127.0.0.1:3306/ppc?autocommit=true&charset=utf8mb4"
        self.sql_storage_config = SQLStorageConfig(url=self.engine_url)
        logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
        self.logger = logging.getLogger(__name__)
        self.sql_storage = MySQLStorage(self.sql_storage_config)
        remote_storage_url = "http://127.0.0.1:9870"
        remote_storage_client = HdfsStorage(remote_storage_url, "chenyujie")
        self.sharding_file_object = ShardingFileObject(
            local_path, remote_path, remote_storage_client, self.sql_storage, self.logger)

    def upload(self, split_mode, granularity, expected_lines):
        # NOTE(review): ShardingFileObject.upload() returns None in every
        # branch, so file_list is always None here and _check_lines would
        # raise on iteration when split_mode is LINES — confirm whether
        # upload() was meant to return the chunk list.
        file_list = self.sharding_file_object.upload(split_mode, granularity)
        if split_mode is SplitMode.LINES:
            self._check_lines(file_list, expected_lines)

    def _check_lines(self, file_list, expected_lines):
        """Sum the data rows across chunk CSVs (header only in chunk 0) and
        assert the total equals expected_lines."""
        lines = 0
        # check the lines
        i = 0
        columns_info = None
        df = None
        for file in file_list:
            # can be parsed by read_csv
            if i == 0:
                df = pd.read_csv(file, header=0)
                columns_info = df.columns
            else:
                df = pd.read_csv(file, header=None)
                df.columns = columns_info
            lines += len(df)
            i += 1
        self.ut_obj.assertEqual(lines, expected_lines)
from abc import ABC
from typing import Callable


class AsyncExecutor(ABC):
    """Base interface for asynchronous task executors (thread/process)."""

    def execute(self,
                task_id: str,
                target: Callable,
                on_target_finish: Callable[[str, bool, Exception], None],
                args=()):
        """
        Start a new process/thread running the given target function.
        :param task_id: the task id
        :param target: the target function
        :param on_target_finish: callback invoked after the target exits
        :param args: tuple of positional arguments for the target
        """
        pass

    def kill(self, task_id: str) -> bool:
        """
        Force-terminate the target function.
        :param task_id: the task id
        :return whether termination succeeded
        """
        pass

    def kill_all(self):
        """
        Force-terminate all running target functions.
        """
        pass

    def __del__(self):
        # Best-effort cleanup of every outstanding task on destruction.
        self.kill_all()
class AsyncSubprocessExecutor(AsyncExecutor):
    """Runs target callables in child processes; a daemon thread reaps
    finished processes every 3 seconds."""

    def __init__(self, logger):
        self.logger = logger
        # task_id -> multiprocessing.Process, guarded by self.lock
        self.processes = {}
        self.lock = threading.Lock()
        self._cleanup_thread = threading.Thread(target=self._loop_cleanup)
        self._cleanup_thread.daemon = True
        self._cleanup_thread.start()

    def execute(self, task_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None],
                args=()):
        """Spawn a child process running target(*args).

        NOTE(review): on_target_finish is accepted for interface parity but
        is never invoked for subprocess tasks — confirm whether a completion
        callback was intended here.
        """
        process = multiprocessing.Process(target=target, args=args)
        process.start()
        with self.lock:
            self.processes[task_id] = process

    def kill(self, task_id: str):
        """Terminate the process of `task_id`; returns False when unknown."""
        with self.lock:
            if task_id not in self.processes:
                return False
            else:
                process = self.processes[task_id]

        process.terminate()
        self.logger.info(f"Task {task_id} has been terminated!")
        return True

    def kill_all(self):
        """Terminate every tracked process."""
        # Bug fix: snapshot the ids under the lock. The original iterated the
        # live keys() view while the cleanup thread concurrently deletes
        # finished entries, which raises "dictionary changed size during
        # iteration".
        with self.lock:
            task_ids = list(self.processes.keys())

        for task_id in task_ids:
            self.kill(task_id)

    def _loop_cleanup(self):
        # Daemon loop: reap finished processes every 3 seconds.
        while True:
            self._cleanup_finished_processes()
            time.sleep(3)

    def _cleanup_finished_processes(self):
        # Collect finished entries under the lock, then reap each one.
        with self.lock:
            finished_processes = [
                (task_id, proc) for task_id, proc in self.processes.items() if not proc.is_alive()]

        for task_id, process in finished_processes:
            with self.lock:
                process.join()  # ensure the process resources are released
                del self.processes[task_id]
            self.logger.info(f"Cleanup finished task process {task_id}")

    def __del__(self):
        self.kill_all()
class AsyncThreadExecutor(AsyncExecutor):
    """Runs target callables in threads; cooperative termination is signalled
    through per-task events in the shared ThreadEventManager, and a daemon
    thread reaps finished threads every 3 seconds."""

    def __init__(self, event_manager: ThreadEventManager, logger):
        self.event_manager = event_manager
        self.logger = logger
        # task_id -> threading.Thread, guarded by self.lock
        self.threads = {}
        self.lock = threading.Lock()
        self._cleanup_thread = threading.Thread(target=self._loop_cleanup)
        self._cleanup_thread.daemon = True
        self._cleanup_thread.start()

    def execute(self, task_id: str, target: Callable, on_target_finish: Callable[[str, bool, Exception], None],
                args=()):
        """Run target(*args) in a thread; on_target_finish is called with
        (task_id, True) on success or (task_id, False, exception) on error."""
        def thread_target(logger, on_finish, *args):
            try:
                target(*args)
                on_finish(task_id, True)
            except Exception as e:
                logger.warn(traceback.format_exc())
                on_finish(task_id, False, e)

        # Bug fix: register the stop event BEFORE starting the thread. The
        # original registered it afterwards, racing targets that poll
        # event_status(task_id) immediately (as the unit test's task does)
        # against an id the event manager does not know yet.
        stop_event = threading.Event()
        self.event_manager.add_event(task_id, stop_event)

        thread = threading.Thread(target=thread_target, args=(
            self.logger, on_target_finish,) + args)
        thread.start()

        with self.lock:
            self.threads[task_id] = thread

    def kill(self, task_id: str):
        """Signal the task's stop event and join its thread; returns False
        when the task is unknown. The target must poll the event to stop."""
        with self.lock:
            if task_id not in self.threads:
                return False
            else:
                thread = self.threads[task_id]

        self.event_manager.set_event(task_id)
        thread.join()
        self.logger.info(f"Task {task_id} has been stopped!")
        return True

    def kill_all(self):
        """Stop every tracked thread."""
        # Bug fix: snapshot the ids under the lock; iterating the live keys()
        # view while the cleanup thread deletes finished entries raises
        # "dictionary changed size during iteration".
        with self.lock:
            task_ids = list(self.threads.keys())

        for task_id in task_ids:
            self.kill(task_id)

    def _loop_cleanup(self):
        # Daemon loop: reap finished threads every 3 seconds.
        while True:
            self._cleanup_finished_threads()
            time.sleep(3)

    def _cleanup_finished_threads(self):
        with self.lock:
            finished_threads = [
                task_id for task_id, thread in self.threads.items() if not thread.is_alive()]

        for task_id in finished_threads:
            with self.lock:
                del self.threads[task_id]
            self.logger.info(f"Cleanup finished task thread {task_id}")

    def __del__(self):
        self.kill_all()
def task1(shared_status1, key, dur):
    """Demo workload: record phase 1, wait *dur* seconds, then record phase 2."""
    shared_status1.update({key: 1})
    time.sleep(dur)
    shared_status1.update({key: 2})
class ThreadEventManager:
    """Thread-safe registry mapping task ids to their stop Events.

    Event cleanup is the TaskManager's responsibility. Mutations take the
    write lock; status queries only need the read lock.
    """

    def __init__(self):
        self.events: Dict[str, threading.Event] = {}
        self.rw_lock = rwlock.RWLockWrite()

    def add_event(self, task_id: str, event: threading.Event) -> None:
        """Register *event* as the stop flag for *task_id* (replaces any old one)."""
        with self.rw_lock.gen_wlock():
            self.events[task_id] = event

    def remove_event(self, task_id: str):
        """Drop the event for *task_id*; unknown ids are silently ignored."""
        with self.rw_lock.gen_wlock():
            self.events.pop(task_id, None)

    def set_event(self, task_id: str):
        """Signal the stop event of *task_id*.

        :raises KeyError: when no event is registered for *task_id*.
        """
        with self.rw_lock.gen_wlock():
            if task_id not in self.events:
                raise KeyError(f"Task ID {task_id} not found")
            self.events[task_id].set()

    def event_status(self, task_id: str) -> bool:
        """Return True when the task's stop event is set; unknown ids read as False."""
        with self.rw_lock.gen_rlock():
            event = self.events.get(task_id)
            return event.is_set() if event is not None else False
MAX_OVERFLOW_KEY = "SQL_MAX_OVERFLOW_SIZE" + POOL_TIMEOUT_KEY = "SQL_POOL_TIMEOUT_SECONDS" + DATABASE_URL_KEY = "SQLALCHEMY_DATABASE_URI" + DATABASE_TYPE_KEY = "DB_TYPE" + DATABASE_NAME_KEY = "DB_NAME" + + @staticmethod + def load(config): + db_url = common_func.get_config_value( + SQLStorageConfigLoader.DATABASE_URL_KEY, None, config, True) + db_type = common_func.get_config_value( + SQLStorageConfigLoader.DATABASE_TYPE_KEY, None, config, True) + db_name = common_func.get_config_value( + SQLStorageConfigLoader.DATABASE_NAME_KEY, None, config, False) + pool_recycle = common_func.get_config_value( + SQLStorageConfigLoader.POOL_RECYCLE_KEY, None, config, False) + pool_size = common_func.get_config_value( + SQLStorageConfigLoader.POOL_SIZE_KEY, None, config, False) + max_overflow = common_func.get_config_value( + SQLStorageConfigLoader.MAX_OVERFLOW_KEY, None, config, False) + pool_timeout = common_func.get_config_value( + SQLStorageConfigLoader.POOL_TIMEOUT_KEY, None, config, False) + return SQLStorageConfig( + pool_recycle=pool_recycle, + pool_size=pool_size, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + url=db_url, db_type=db_type, db_name=db_name) diff --git a/python/ppc_common/ppc_crypto/__init__.py b/python/ppc_common/ppc_crypto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_crypto/crypto_utils.py b/python/ppc_common/ppc_crypto/crypto_utils.py new file mode 100644 index 00000000..43a13aaa --- /dev/null +++ b/python/ppc_common/ppc_crypto/crypto_utils.py @@ -0,0 +1,130 @@ +import binascii + +import base64 +import random +from Crypto.Cipher import AES +from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5 +from Crypto.PublicKey import RSA + +from ppc_common.config import CONFIG_DATA + +try: + import gmpy2 + + IS_GMP = True +except ImportError: + IS_GMP = False + +_USE_MOD_GMP_SIZE = (1 << (8 * 2)) +_USE_MULMOD_GMP_SIZE = (1 << 1000) + +RSA_PUBLIC_HEADER = "-----BEGIN PUBLIC KEY-----" +RSA_PUBLIC_END = 
def ot_str_to_int(input_str):
    """Encode *input_str* (UTF-8) as a big-endian integer.

    Returns the integer together with the character length of the input,
    which is needed to invert the transformation.
    """
    encoded = input_str.encode('utf-8')
    return int.from_bytes(encoded, 'big'), len(input_str)
# PKCS#7-style padding helpers for the AES block size.
BS = 16


def pad(s):
    """Pad text *s* to a multiple of BS bytes; returns UTF-8 bytes.

    The pad byte equals the pad length, so a full block of padding is
    appended when len(s) is already a multiple of BS.
    """
    fill = BS - len(s) % BS
    return (s + chr(fill) * fill).encode('utf-8')


def unpad(s):
    """Strip the padding suffix: the last byte encodes how many bytes to drop."""
    return s[:-ord(s[-1:])]
self.c_left = c_left + self.c_right = c_right + + def __add__(self, other): + cipher_left = self.c_left + other.c_left + cipher_right = self.c_right + other.c_right + return IhcCiphertext(cipher_left, cipher_right) + + def __eq__(self, other): + return self.c_left == other.c_left and self.c_right == other.c_right + + def encode(self) -> bytes: + # 计算每个整数的字节长度 + len_c_left = (self.c_left.bit_length() + 7) // 8 + len_c_right = (self.c_right.bit_length() + 7) // 8 + + # 将整数转换为字节序列,使用大端字节序和带符号整数 + c_left_bytes = self.c_left.to_bytes(len_c_left, byteorder='big') + c_right_bytes = self.c_right.to_bytes(len_c_right, byteorder='big') + + # 编码整数的长度 + len_bytes = struct.pack('>II', len_c_left, len_c_right) + + # 返回所有数据 + return len_bytes + c_left_bytes + c_right_bytes + + @classmethod + def decode(cls, encoded_data: bytes): + # 解码整数的长度 + len_c_left, len_c_right = struct.unpack('>II', encoded_data[:8]) + + # 根据长度解码整数 + c_left = int.from_bytes(encoded_data[8:8 + len_c_left], byteorder='big') + c_right = int.from_bytes(encoded_data[8 + len_c_left:8 + len_c_left + len_c_right], byteorder='big') + return cls(c_left, c_right) + +class IhcCipher(PheCipher): + def __init__(self, key_length: int = 256, iter_round: int = 16) -> None: + super().__init__(key_length) + key = secrets.randbits(key_length) + self.public_key = key + self.private_key = key + self.iter_round = iter_round + self.key_length = key_length + + self.max_mod = 1 << key_length + + def encrypt(self, number: int) -> IhcCiphertext: + random_u = secrets.randbits(self.key_length) + x_this = number + x_last = random_u + for i in range(0, self.iter_round): + x_tmp = (self.private_key * x_this - x_last) % self.max_mod + x_last = x_this + x_this = x_tmp + # cipher = IhcCiphertext(x_this, x_last, self.max_mod) + cipher = IhcCiphertext(x_this, x_last) + return cipher + + def decrypt(self, cipher: IhcCiphertext) -> int: + x_this = cipher.c_right + x_last = cipher.c_left + for i in range(0, self.iter_round-1): + x_tmp = 
class IhcCodec:
    """(De)serialization helpers for IHC ciphertexts.

    IHC is a symmetric scheme whose key never leaves the process, so the
    "encryption key" encoders intentionally exchange empty byte strings;
    the method shapes mirror the Paillier codec so callers can swap codecs.
    """

    @staticmethod
    def _int_to_bytes(x):
        return x.to_bytes((x.bit_length() + 7) // 8, 'big')

    @staticmethod
    def _bytes_to_int(x):
        return int.from_bytes(x, 'big')

    @staticmethod
    def encode_enc_key(public_key) -> bytes:
        # nothing to share for a symmetric key
        return b''

    @staticmethod
    def decode_enc_key(public_key_bytes) -> bytes:
        return b''

    @staticmethod
    def encode_cipher(cipher: IhcCiphertext, be_secure=True) -> Tuple[bytes, bytes]:
        # the second element fills the Paillier "exponent" slot; unused here
        return cipher.encode(), b''

    @staticmethod
    def decode_cipher(enc_key, ciphertext: bytes, exponent) -> IhcCiphertext:
        return IhcCiphertext.decode(ciphertext)
self.private_key = paillier.generate_paillier_keypair( + n_length=self.key_length) + + def encrypt(self, number) -> EncryptedNumber: + return self.public_key.encrypt(int(number)) + + def decrypt(self, cipher: EncryptedNumber) -> int: + return self.private_key.decrypt(cipher) + + def encrypt_batch(self, numbers) -> list: + return [self.encrypt(num) for num in numbers] + + def decrypt_batch(self, ciphers) -> list: + return [self.decrypt(cipher) for cipher in ciphers] + + def encrypt_batch_parallel(self, numbers) -> list: + num_cores = os.cpu_count() + batch_size = math.ceil(len(numbers) / num_cores) + batches = [numbers[i:i + batch_size] for i in range(0, len(numbers), batch_size)] + with ProcessPoolExecutor(max_workers=num_cores) as executor: + futures = [executor.submit(self.encrypt_batch, batch) for batch in batches] + result = [future.result() for future in futures] + return [item for sublist in result for item in sublist] + + def decrypt_batch_parallel(self, ciphers) -> list: + num_cores = os.cpu_count() + batch_size = math.ceil(len(ciphers) / num_cores) + batches = [ciphers[i:i + batch_size] for i in range(0, len(ciphers), batch_size)] + with ProcessPoolExecutor(max_workers=num_cores) as executor: + futures = [executor.submit(self.decrypt_batch, batch) for batch in batches] + result = [future.result() for future in futures] + return [item for sublist in result for item in sublist] diff --git a/python/ppc_common/ppc_crypto/paillier_codec.py b/python/ppc_common/ppc_crypto/paillier_codec.py new file mode 100644 index 00000000..fb5f3bf0 --- /dev/null +++ b/python/ppc_common/ppc_crypto/paillier_codec.py @@ -0,0 +1,34 @@ +from typing import Tuple + +from phe import PaillierPublicKey, paillier, EncryptedNumber + + +class PaillierCodec: + @staticmethod + def _int_to_bytes(x): + return x.to_bytes((x.bit_length() + 7) // 8, 'big') + + @staticmethod + def _bytes_to_int(x): + return int.from_bytes(x, 'big') + + @staticmethod + def encode_enc_key(public_key: 
class PheCipher(ABC):
    """Interface for partially homomorphic encryption schemes.

    Concrete implementations (IHC, Paillier) override every method; the
    base implementations are placeholders that return None.
    """

    def __init__(self, key_length: int) -> None:
        # key size in bits, recorded for subclasses
        self.key_length = key_length

    def encrypt(self, number: int):
        """Encrypt a single integer."""

    def decrypt(self, cipher) -> int:
        """Decrypt a single ciphertext back to an integer."""

    def encrypt_batch(self, numbers: list) -> list:
        """Encrypt a sequence of integers."""

    def decrypt_batch(self, ciphers: list) -> list:
        """Decrypt a sequence of ciphertexts."""

    def encrypt_batch_parallel(self, numbers: list) -> list:
        """Encrypt a sequence using multiple workers."""

    def decrypt_batch_parallel(self, ciphers: list) -> list:
        """Decrypt a sequence using multiple workers."""
def build_phe(homo_algorithm=0, key_length=2048): + if homo_algorithm == 0: + return IhcCipher() + if homo_algorithm == 1: + return PaillierCipher(key_length) + else: + raise ValueError("Unsupported homo algorithm") + + @staticmethod + def build_codec(homo_algorithm=0): + if homo_algorithm == 0: + return IhcCodec() + if homo_algorithm == 1: + return PaillierCodec() + else: + raise ValueError("Unsupported homo algorithm") diff --git a/python/ppc_common/ppc_crypto/test/__init__.py b/python/ppc_common/ppc_crypto/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_crypto/test/phe_unittest.py b/python/ppc_common/ppc_crypto/test/phe_unittest.py new file mode 100644 index 00000000..5d036b75 --- /dev/null +++ b/python/ppc_common/ppc_crypto/test/phe_unittest.py @@ -0,0 +1,105 @@ +import time +import unittest + +import numpy as np + +from ppc_common.ppc_crypto.ihc_cipher import IhcCipher, IhcCiphertext +from ppc_common.ppc_crypto.ihc_codec import IhcCodec +from ppc_common.ppc_crypto.paillier_cipher import PaillierCipher + + +class PaillierUtilsTest(unittest.TestCase): + + def test_enc_and_dec_parallel(self): + paillier = PaillierCipher(key_length=1024) + inputs = np.random.randint(1, 10001, size=10) + + # start_time = time.time() + # paillier.encrypt_batch(inputs) + # end_time = time.time() + # print("enc:", end_time - start_time, "seconds") + + start_time = time.time() + ciphers = paillier.encrypt_batch_parallel(inputs) + end_time = time.time() + print("enc_p:", end_time - start_time, "seconds") + + start_time = time.time() + outputs = paillier.decrypt_batch_parallel(ciphers) + end_time = time.time() + print("dec_p:", end_time - start_time, "seconds") + + self.assertListEqual(list(inputs), list(outputs)) + + def test_ihc_enc_and_dec_parallel(self): + ihc = IhcCipher(key_length=256) + try_size = 100000 + inputs = np.random.randint(1, 10001, size=try_size) + expected = np.sum(inputs) + + start_time = time.time() + ciphers = 
ihc.encrypt_batch_parallel(inputs) + end_time = time.time() + print( + f"size:{try_size}, enc_p: {end_time - start_time} seconds, " + f"average times: {(end_time - start_time) / try_size * 1000 * 1000} us") + + start_time = time.time() + cipher_start = ciphers[0] + for i in range(1, len(ciphers)): + cipher_left = (cipher_start.c_left + ciphers[i].c_left) + cipher_right = (cipher_start.c_right + ciphers[i].c_right ) + # IhcCiphertext(cipher_left, cipher_right, cipher_start.max_mod) + IhcCiphertext(cipher_left, cipher_right) + end_time = time.time() + print(f"size:{try_size}, add_p raw with class: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us") + + start_time = time.time() + cipher_start = ciphers[0] + for i in range(1, len(ciphers)): + cipher_left = (cipher_start.c_left + ciphers[i].c_left) + cipher_right = (cipher_start.c_right + ciphers[i].c_right ) + # IhcCiphertext(cipher_left, cipher_right) + end_time = time.time() + print(f"size:{try_size}, add_p raw: {end_time - start_time} seconds, average times: {(end_time - start_time)/try_size * 1000 * 1000} us") + + start_time = time.time() + cipher_start = ciphers[0] + for i in range(1, len(ciphers)): + cipher_start = cipher_start + ciphers[i] + end_time = time.time() + print( + f"size:{try_size}, add_p: {end_time - start_time} seconds, " + f"average times: {(end_time - start_time) / try_size * 1000 * 1000} us") + + start_time = time.time() + outputs = ihc.decrypt_batch_parallel(ciphers) + end_time = time.time() + print( + f"size:{try_size}, dec_p: {end_time - start_time} seconds, " + f"average times: {(end_time - start_time) / try_size * 1000 * 1000} us") + + decrypted = ihc.decrypt(cipher_start) + self.assertListEqual(list(inputs), list(outputs)) + assert decrypted == expected + + def test_ihc_code(self): + ihc = IhcCipher(key_length=256) + try_size = 100000 + inputs = np.random.randint(1, 10001, size=try_size) + start_time = time.time() + ciphers = 
ihc.encrypt_batch_parallel(inputs) + end_time = time.time() + print( + f"size:{try_size}, enc_p: {end_time - start_time} seconds, " + f"average times: {(end_time - start_time) / try_size * 1000 * 1000} us") + for i in range(0, len(ciphers)): + cipher: IhcCiphertext = ciphers[i] + encoded, _ = IhcCodec.encode_cipher(cipher) + + decoded = IhcCodec.decode_cipher(None, encoded, None) + assert cipher == decoded + + +if __name__ == '__main__': + unittest.main() diff --git a/python/ppc_common/ppc_dataset/__init__.py b/python/ppc_common/ppc_dataset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_dataset/dataset_helper.py b/python/ppc_common/ppc_dataset/dataset_helper.py new file mode 100644 index 00000000..718134a0 --- /dev/null +++ b/python/ppc_common/ppc_dataset/dataset_helper.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- + +from ppc_common.deps_services.sharding_file_object import ShardingFileObject +from ppc_common.deps_services.file_object import SplitMode +from ppc_common.deps_services.storage_api import StorageType +import os + + +class DataSetHelper: + # default split granularity is 1G + def __init__(self, dataset_user, dataset_id, dataset_local_path, + sql_storage, remote_storage, chunk_config, logger): + self._split_mode = SplitMode.NONE + self._chunk_config = chunk_config + self._chunk_size = self._chunk_config.file_object_chunk_size + self._dataset_user = dataset_user + remote_path = os.path.join(dataset_user, dataset_id) + self._file_object = ShardingFileObject( + dataset_local_path, remote_path, remote_storage, sql_storage, logger) + + @property + def file_object(self): + return self._file_object + + def download_dataset(self, enforce_flush=False): + """download the dataset + """ + self._file_object.download(enforce_flush=enforce_flush) + + def upload_dataset(self): + """upload the dataset + """ + self._file_object.upload(self._split_mode, self._chunk_size) + + def save_data(self, data): + """save the dataset + """ 
class DataSetHelperFactory:
    """Builds DataSetHelper instances bound to shared storage backends."""

    def __init__(self, sql_storage, remote_storage, chunk_config, logger):
        self._sql_storage = sql_storage
        self._remote_storage = remote_storage
        self._chunk_config = chunk_config
        self._logger = logger

    def create(self, dataset_user, dataset_id, dataset_local_path):
        """Return a DataSetHelper for (dataset_user, dataset_id) at *dataset_local_path*."""
        return DataSetHelper(
            dataset_user=dataset_user,
            dataset_id=dataset_id,
            dataset_local_path=dataset_local_path,
            sql_storage=self._sql_storage,
            remote_storage=self._remote_storage,
            chunk_config=self._chunk_config,
            logger=self._logger)
def get_dataset(dataset_helper_factory, dataset_user, dataset_id, logger, log_keyword):
    """Fetch the dataset bytes for (dataset_user, dataset_id) from remote storage.

    :param log_keyword: tag prefixed to the log lines for traceability.
    :return: the raw dataset content returned by the file object.
    """
    helper = dataset_helper_factory.create(dataset_user, dataset_id, None)
    logger.info(
        f"{log_keyword}: Get dataset from: {helper.get_remote_path()}")
    data = helper.file_object.get_data()
    logger.info(
        f"{log_keyword}: Get dataset success from: {helper.get_remote_path()}")
    return data
class DataSetHandlerInitialize:
    """Wires up the storage stack (SQL + remote) and the dataset helper factory."""

    def __init__(self, config, logger):
        self._config = config
        self._logger = logger
        self._init_sql_storage()
        self._init_remote_storage()
        self._init_dataset_factory()

    def _init_sql_storage(self):
        # MySQL-backed metadata store configured from the application config
        self.sql_storage = MySQLStorage(
            SQLStorageConfigLoader.load(self._config))

    def _init_remote_storage(self):
        # remote file storage (e.g. HDFS) selected by the loader from config
        self.storage_client = storage_loader.load(self._config, self._logger)

    def _init_dataset_factory(self):
        self.file_chunk_config = FileChunkConfig(self._config)
        self.dataset_helper_factory = DataSetHelperFactory(
            sql_storage=self.sql_storage,
            remote_storage=self.storage_client,
            chunk_config=self.file_chunk_config,
            logger=self._logger)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + self.logger = logging.getLogger(__name__) + self.config = {} + self.config[SQLStorageConfigLoader.DATABASE_URL_KEY] = "mysql://root:12345678@127.0.0.1:3306/ppc?autocommit=true&charset=utf8mb4" + self.config[FileChunkConfig.ENABLE_ALL_CHUNCK_FILE_MGR_KEY] = True + self.config["STORAGE_TYPE"] = "HDFS" + self.config["HDFS_USER"] = "chenyujie" + self.config["HDFS_ENDPOINT"] = "http://127.0.0.1:9870" + self.config[FileChunkConfig.FILE_CHUNK_SIZE_MB_KEY] = file_chunk_size + self.dataset_handler_initializer = DataSetHandlerInitialize( + self.config, self.logger) + self.ut_obj = ut_obj + + def _download_and_check(self, dataset_user, dataset_id, dataset_local_path, expected_md5=None): + # download + download_local_path = dataset_local_path + ".download" + dataset_helper_factory.download_dataset( + dataset_helper_factory=self.dataset_handler_initializer.dataset_helper_factory, + dataset_id=dataset_id, dataset_user=dataset_user, logger=self.logger, + dataset_local_path=download_local_path, + log_keyword="testDownload", enforce_flush=True + ) + # check md5 + downloaded_file_md5 = utils.calculate_md5(download_local_path) + self.ut_obj.assertEquals(expected_md5, downloaded_file_md5) + # get the dataset + data = dataset_helper_factory.get_dataset( + dataset_helper_factory=self.dataset_handler_initializer.dataset_helper_factory, + dataset_id=dataset_id, dataset_user=dataset_user, logger=self.logger, + log_keyword="testGetDataSet" + ) + get_data_md5sum = utils.md5sum(data) + self.ut_obj.assertEquals(expected_md5, get_data_md5sum) + + def test_dataset_ops(self, dataset_user, dataset_id, dataset_local_path, updated_data): + # upload the dataset + dataset_helper_factory.upload_dataset( + dataset_helper_factory=self.dataset_handler_initializer.dataset_helper_factory, + dataset_id=dataset_id, dataset_user=dataset_user, + dataset_local_path=dataset_local_path, + logger=self.logger, log_keyword="testUpload") + origin_md5 = 
utils.calculate_md5(dataset_local_path) + self._download_and_check(dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=dataset_local_path, + expected_md5=origin_md5) + # update dataset + dataset_helper = self.dataset_handler_initializer.dataset_helper_factory.create( + dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=None) + dataset_helper.update_data(updated_data) + # check + self._download_and_check(dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=dataset_local_path, + expected_md5=utils.md5sum(updated_data)) + # rename + rename_path = dataset_id + ".rename" + dataset_helper.file_rename(rename_path, True) + # check + self._download_and_check(dataset_user=dataset_user, + dataset_id=rename_path, + dataset_local_path=dataset_local_path + ".rename", + expected_md5=utils.md5sum(updated_data)) + # delete + dataset_helper = dataset_helper_factory.delete_dataset( + dataset_helper_factory=self.dataset_handler_initializer.dataset_helper_factory, + dataset_id=rename_path, dataset_user=dataset_user, + logger=self.logger, log_keyword="testDelete") + self.ut_obj.assertEquals(dataset_helper.file_object.existed(), False) + + +class DatasetInitializerTest(unittest.TestCase): + def testNoFileSplit(self): + wrapper = DataSetInitializerWrapper(self, 100) + dataset_user = "yujiechen" + dataset_id = "test_file_no_split" + dataset_local_path = "tools/test" + uploaded_data = "abcdsfsdfsdfssdfsd" + wrapper.test_dataset_ops(dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=dataset_local_path, + updated_data=uploaded_data) + + def testFileSplit(self): + wrapper = DataSetInitializerWrapper(self, 1) + dataset_user = "yujiechen" + dataset_id = "test_file_split_file" + dataset_local_path = "tools/test" + uploaded_data = "abcdsfsdfsdfssdfsd" + wrapper.test_dataset_ops(dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=dataset_local_path, + updated_data=uploaded_data) + + def 
testFileSplitAndUpdateSplit(self): + wrapper = DataSetInitializerWrapper(self, 1) + dataset_user = "yujiechen" + dataset_id = "test_large_file_split_file" + dataset_local_path = "tools/test" + updated_file = "tools/test2" + uploaded_data = None + with open(updated_file, "rb") as fp: + uploaded_data = fp.read() + wrapper.test_dataset_ops(dataset_user=dataset_user, + dataset_id=dataset_id, + dataset_local_path=dataset_local_path, + updated_data=uploaded_data) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/ppc_common/ppc_ml/__init__.py b/python/ppc_common/ppc_ml/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_ml/feature/__init__.py b/python/ppc_common/ppc_ml/feature/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_ml/feature/feature_importance.py b/python/ppc_common/ppc_ml/feature/feature_importance.py new file mode 100644 index 00000000..2441a20e --- /dev/null +++ b/python/ppc_common/ppc_ml/feature/feature_importance.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from enum import Enum +import pandas as pd +from ppc_common.deps_services.serialize_type import SerializeType + + +class FeatureGainInfo: + def __init__(self, feature_idx, feature_gain): + self.feature_idx = feature_idx + self.feature_gain = feature_gain + +# TODO: cover + + +class FeatureImportanceType(Enum): + WEIGHT = 'weight' # The number of times the feature is used in all trees + GAIN = 'gain' # The gain of features in predictions across all trees + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ + + +class FeatureImportance: + """the feature importance + """ + + def __init__(self, importance_type: FeatureImportanceType, importance=0): + self.importance_type = importance_type + self.importance = importance + + def inc(self, value): + self.importance += value + + def desc(self): + return f"importance type: {self._importance_type}, importance: 
{self.importance}" + + def __cmp__(self, other): + if self.importance > other.importance: + return 1 + elif self.importance < other.importance: + return -1 + return 0 + + def __eq__(self, other): + return self.importance == other.importance + + def __lt__(self, other): + return self.importance < other.importance + + def __add__(self, other): + new_importance = FeatureImportance( + self.importance_type, importance=self.importance + other.importance) + return new_importance + + +class ReadOnlyFeatureImportanceStore: + RANK_PROPERTY = "score_rank" + DEAULT_SCORE_THRESHOLD = 0.95 + DEFAULT_TOPK_PROPERTY = "topk" + DEFAULT_EMPTY_TOPK_FLAG = " " + DEFAULT_TOPK_FLAG = "True" + + def __init__(self, feature_importance_dict, logger): + # feature_importance_type ==> fid ==> value + self.feature_importance_dict = feature_importance_dict + self.logger = logger + + def get_feature_importance(self, fid, type): + if type not in self.feature_importance_dict: + return 0 + if fid not in self.feature_importance_dict[type]: + return 0 + return self.feature_importance_dict[type][fid].importance + + def _get_sorting_columns(self): + sorting_columns = [] + ascending_list = [] + if FeatureImportanceType.GAIN in self.feature_importance_type_list: + sorting_columns.append(FeatureImportanceType.GAIN.name) + ascending_list.append(False) + if FeatureImportanceType.WEIGHT in self.feature_importance_type_list: + sorting_columns.append(FeatureImportanceType.WEIGHT.name) + ascending_list.append(False) + return (sorting_columns, ascending_list) + + @staticmethod + def load(df, logger): + """load the feature importance + + Args: + df (DataFrame): the feature importance + """ + logger.debug(f"load feature_importance data: {df}") + feature_importance_type_list = [] + for column in df.columns: + if not FeatureImportanceType.has_value(column.lower()): + continue + enum_feature_importance_type = FeatureImportanceType( + column.lower()) + feature_importance_type_list.append(enum_feature_importance_type) + 
logger.debug( + f"load feature_importance, feature_importance_type_list: {feature_importance_type_list}") + # feature_importance_type ==> fid_key ==> value + feature_importance_dict = dict() + for row in df.itertuples(): + fid_key = getattr( + row, FeatureImportanceStore.DEFAULT_FEATURE_PROPERTY) + for importance_type in feature_importance_type_list: + value = getattr(row, importance_type.name) + if importance_type not in feature_importance_dict: + feature_importance_dict.update({importance_type: dict()}) + feature_importance_dict[importance_type].update( + {fid_key: FeatureImportance(importance_type, value)}) + return ReadOnlyFeatureImportanceStore(feature_importance_dict, logger) + + def to_dataframe(self, topk_threshold=0.95): + """convert the feature importance into pd + the format: + | feature | score | score_rank | topk| + | x16 | 0.08234 | 1 | | + | x11 | 0.08134 | 2 | | + | x1 | 0.08034 | 3 | | + """ + if self.feature_importance_dict is None or len(self.feature_importance_dict) < 1: + return None + df = pd.DataFrame() + # the feature column + df.insert(df.shape[1], FeatureImportanceStore.DEFAULT_FEATURE_PROPERTY, + self.feature_importance_dict[self.feature_importance_type_list[0]].keys()) + # the importance columns + for importance_type in self.feature_importance_dict.keys(): + feature_importance_values = [] + for feature_importance in self.feature_importance_dict[importance_type].values(): + feature_importance_values.append(feature_importance.importance) + # calculate weight-average for the score + if importance_type == FeatureImportanceType.GAIN: + feature_importance_sum = sum(feature_importance_values) + if feature_importance_sum != 0: + for i in range(len(feature_importance_values)): + feature_importance_values[i] = float( + feature_importance_values[i]) / feature_importance_sum + + df.insert(df.shape[1], importance_type.name, + feature_importance_values) + # sort by the importance + (sorting_columns, ascending_list) = self._get_sorting_columns() + df = 
df.sort_values(by=sorting_columns, ascending=ascending_list) + # rank + df.insert(df.shape[1], + ReadOnlyFeatureImportanceStore.RANK_PROPERTY, + range(1, len(df) + 1)) + # top-k + if FeatureImportanceType.GAIN not in self.feature_importance_type_list: + return df + score_sum = float(0) + topk_list = [ + ReadOnlyFeatureImportanceStore.DEFAULT_EMPTY_TOPK_FLAG for _ in range(len(df))] + i = 0 + for score in df[FeatureImportanceType.GAIN.name].T: + score_sum += score + if score_sum >= topk_threshold: + topk_list[i] = ReadOnlyFeatureImportanceStore.DEFAULT_TOPK_FLAG + break + i += 1 + df.insert( + df.shape[1], ReadOnlyFeatureImportanceStore.DEFAULT_TOPK_PROPERTY, topk_list) + return df + + +class FeatureImportanceStore(ReadOnlyFeatureImportanceStore): + """store all the feature importance + """ + DEFAULT_FEATURE_PROPERTY = "feature" + DEFAULT_IMPORTANCE_LIST = [ + FeatureImportanceType.GAIN, FeatureImportanceType.WEIGHT] + + def __init__(self, feature_importance_type_list, feature_list, logger): + # feature_importance_type ==> fid ==> value + super().__init__(dict(), logger) + self.feature_list = feature_list + if self.feature_list is not None: + self.logger.info( + f"create FeatureImportanceStore, all features: {self.feature_list}") + self.feature_importance_type_list = feature_importance_type_list + for importance_type in self.feature_importance_type_list: + self.feature_importance_dict.update({importance_type: dict()}) + self._init() + + def _init(self): + if self.feature_list is None: + return + for i in range(len(self.feature_list)): + gain_list = {FeatureImportanceType.GAIN: 0, + FeatureImportanceType.WEIGHT: 0} + self.update_feature_importance(i, gain_list) + + def set_init(self, feature_list): + self.feature_list = feature_list + self._init() + + def generate_fid_key(self, fid): + if fid >= len(self.feature_list): + return None + return f"{self.feature_list[fid]}" + + def update_feature_importance(self, fid, gain_list): + """update the feature importance + + 
Args: + fid (int): the idx of the best feature(maxk) + gain_list (dict): the gain list for every importance_type + + Raises: + Exception: invalid gain_list + """ + fid_key = self.generate_fid_key(fid) + if fid_key is None: + return + for importance_type in gain_list.keys(): + # unknown importance_type + if importance_type not in self.feature_importance_dict: + continue + if fid_key not in self.feature_importance_dict[importance_type]: + self.feature_importance_dict[importance_type][fid_key] = FeatureImportance( + importance_type, 0) + self.feature_importance_dict[importance_type][fid_key].inc( + gain_list[importance_type]) + + def store(self, serialize_type: SerializeType, local_file_path, remote_file_path, storage_client): + """store the feature importance into file, upload using storage_client + + Args: + storage_client: the client used to upload the result + """ + df = self.to_dataframe() + if serialize_type is SerializeType.CSV: + df.to_csv(local_file_path, index=False) + else: + df.to_json(orient='split', path_or_buf=local_file_path) + self.logger.info( + f"Store feature_importance to {local_file_path}, file type: {serialize_type}") + if storage_client is not None: + storage_client.upload_file(local_file_path, remote_file_path) + self.logger.info( + f"Upload feature_importance to {local_file_path} success, file type: {serialize_type}") diff --git a/python/ppc_common/ppc_ml/feature/tests/feature_importance_test.py b/python/ppc_common/ppc_ml/feature/tests/feature_importance_test.py new file mode 100644 index 00000000..3b0f5661 --- /dev/null +++ b/python/ppc_common/ppc_ml/feature/tests/feature_importance_test.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +import unittest +from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceType +from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceStore +from ppc_common.ppc_ml.feature.feature_importance import ReadOnlyFeatureImportanceStore +from ppc_common.deps_services.serialize_type 
import SerializeType +import logging +import random +import sys +import pandas as pd + + +class FeatureImportanceWrapper: + def __init__(self, ut_obj, feature_size): + self.ut_obj = ut_obj + self.feature_size = feature_size + self.feature_list = [] + self._fake_features() + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + self.logger = logging.getLogger(__name__) + self.feature_importance_type_list = [ + FeatureImportanceType.GAIN, FeatureImportanceType.WEIGHT] + self.feature_importance_store = FeatureImportanceStore( + self.feature_importance_type_list, self.feature_list, self.logger) + self.epsilon = 0.000000001 + + def _fake_features(self): + for i in range(self.feature_size): + self.feature_list.append("feature_" + str(i)) + + def update_feature_importance_and_check(self, rounds, local_file_path): + for i in range(rounds): + selected_feature = random.randint(0, self.feature_size - 1) + key = self.feature_importance_store.generate_fid_key( + selected_feature) + gain = random.uniform(0, 101) + pre_gain = self.feature_importance_store.get_feature_importance( + key, FeatureImportanceType.GAIN) + pre_weight = self.feature_importance_store.get_feature_importance( + key, FeatureImportanceType.WEIGHT) + self.feature_importance_store.update_feature_importance( + selected_feature, {FeatureImportanceType.GAIN: gain, FeatureImportanceType.WEIGHT: 1}) + self.ut_obj.assertEqual( + pre_gain + gain, self.feature_importance_store.get_feature_importance(key, FeatureImportanceType.GAIN)) + self.ut_obj.assertEqual( + pre_weight + 1, self.feature_importance_store.get_feature_importance(key, FeatureImportanceType.WEIGHT)) + # store + self.feature_importance_store.store( + SerializeType.CSV, local_file_path, None, None) + # load + df = pd.read_csv(local_file_path) + loaded_feature_importance_store = ReadOnlyFeatureImportanceStore.load( + df, self.logger) + # check the dict + for importance_type in loaded_feature_importance_store.feature_importance_dict: + 
self.ut_obj.assertTrue( + importance_type in self.feature_importance_store.feature_importance_dict) + feature_importances = self.feature_importance_store.feature_importance_dict[ + importance_type] + loaded_feature_importances = loaded_feature_importance_store.feature_importance_dict[ + importance_type] + sum_data = 0 + for fid in loaded_feature_importances.keys(): + sum_data += loaded_feature_importances[fid].importance + self.ut_obj.assertTrue(abs(sum_data - 1) < self.epsilon) + local_json_file = local_file_path + ".json" + self.feature_importance_store.store( + SerializeType.JSON, local_json_file, None, None) + + +class TestFeatureImportance(unittest.TestCase): + def test_update_gain_and_store(self): + feature_size = 100 + self.wrapper = FeatureImportanceWrapper( + self, feature_size=feature_size) + rounds = 100000 + local_file_path = "feature_importance_case1.csv" + self.wrapper.update_feature_importance_and_check( + rounds, local_file_path=local_file_path) diff --git a/python/ppc_common/ppc_ml/model/algorithm_info.py b/python/ppc_common/ppc_ml/model/algorithm_info.py new file mode 100644 index 00000000..a117887d --- /dev/null +++ b/python/ppc_common/ppc_ml/model/algorithm_info.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +from enum import Enum + + +class ClassificationType(Enum): + TWO = 'two' + MULTI = 'multi' + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ + + +class EvaluationType(Enum): + ROC = "roc", + PR = "pr", + KS = "ks", + ACCURACY = "accuracy", + CONFUSION_MATRIX = "confusion_matrix" + + +class ModelRole(Enum): + ACTIVE = "active" + PASSIVE = "passive" diff --git a/python/ppc_common/ppc_mock/__init__.py b/python/ppc_common/ppc_mock/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_mock/mock_objects.py b/python/ppc_common/ppc_mock/mock_objects.py new file mode 100644 index 00000000..4ac711e0 --- /dev/null +++ b/python/ppc_common/ppc_mock/mock_objects.py @@ -0,0 +1,59 
@@ +from datetime import datetime +from typing import AnyStr + +from ppc_common.deps_services.storage_api import StorageApi + + +class MockLogger: + @staticmethod + def debug(msg): + current_time = datetime.now() + print(f"{current_time}, Debug: {msg}") + + @staticmethod + def info(msg): + current_time = datetime.now() + print(f"{current_time}, Info: {msg}") + + @staticmethod + def warn(msg): + current_time = datetime.now() + print(f"{current_time}, Warn: {msg}") + + @staticmethod + def error(msg): + current_time = datetime.now() + print(f"{current_time}, Error: {msg}") + + +class MockStorageClient(StorageApi): + + def download_file(self, storage_path: str, local_file_path: str, enable_cache=False): + print(f'download_file: {storage_path} -> {local_file_path}') + + def upload_file(self, local_file_path: str, storage_path: str): + print(f'upload_file: {storage_path} -> {local_file_path}') + + def make_file_path(self, storage_path: str): + pass + + def delete_file(self, storage_path: str): + pass + + def save_data(self, data: AnyStr, storage_path: str): + pass + + def get_data(self, storage_path: str) -> AnyStr: + pass + + def mkdir(self, storage_path: str): + pass + + def file_existed(self, storage_path: str) -> bool: + pass + + def file_rename(self, old_storage_path: str, storage_path: str): + pass + + def storage_type(self): + pass diff --git a/python/ppc_common/ppc_protos/__init__.py b/python/ppc_common/ppc_protos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_protos/generated/__init__.py b/python/ppc_common/ppc_protos/generated/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py b/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py new file mode 100644 index 00000000..57f7ca3a --- /dev/null +++ b/python/ppc_common/ppc_protos/generated/ppc_model_pb2.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. 
DO NOT EDIT! +# source: ppc_model.proto +# Protobuf Python Version: 4.25.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fppc_model.proto\x12\tppc.model\"|\n\x0cModelRequest\x12\x0e\n\x06sender\x18\x01 \x01(\t\x12\x10\n\x08receiver\x18\x02 \x01(\t\x12\x0f\n\x07task_id\x18\x03 \x01(\t\x12\x0b\n\x03key\x18\x04 \x01(\t\x12\x0b\n\x03seq\x18\x05 \x01(\x03\x12\x11\n\tslice_num\x18\x06 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x07 \x01(\x0c\"3\n\x0c\x42\x61seResponse\x12\x12\n\nerror_code\x18\x01 \x01(\x03\x12\x0f\n\x07message\x18\x02 \x01(\t\"M\n\rModelResponse\x12.\n\rbase_response\x18\x01 \x01(\x0b\x32\x17.ppc.model.BaseResponse\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"#\n\rPlainBoolList\x12\x12\n\nplain_list\x18\x01 \x03(\x08\"\xb1\x01\n\rBestSplitInfo\x12\x0f\n\x07tree_id\x18\x01 \x01(\x03\x12\x0f\n\x07leaf_id\x18\x02 \x01(\x03\x12\x0f\n\x07\x66\x65\x61ture\x18\x03 \x01(\x03\x12\r\n\x05value\x18\x04 \x01(\x03\x12\x12\n\nagency_idx\x18\x05 \x01(\x03\x12\x16\n\x0e\x61gency_feature\x18\x06 \x01(\x03\x12\x11\n\tbest_gain\x18\x07 \x01(\x02\x12\x0e\n\x06w_left\x18\x08 \x01(\x02\x12\x0f\n\x07w_right\x18\t \x01(\x02\"3\n\x0bModelCipher\x12\x12\n\nciphertext\x18\x01 \x01(\x0c\x12\x10\n\x08\x65xponent\x18\x02 \x01(\x0c\"M\n\nCipherList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12+\n\x0b\x63ipher_list\x18\x02 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"=\n\x0e\x43ipher1DimList\x12+\n\x0b\x63ipher_list\x18\x01 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"W\n\x0e\x43ipher2DimList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x31\n\x0e\x63ipher_1d_list\x18\x02 
\x03(\x0b\x32\x19.ppc.model.Cipher1DimList\"_\n\rEncAggrLabels\x12\r\n\x05\x66ield\x18\x01 \x01(\t\x12\x12\n\ncount_list\x18\x02 \x03(\x03\x12+\n\x0b\x63ipher_list\x18\x03 \x03(\x0b\x32\x16.ppc.model.ModelCipher\"_\n\x11\x45ncAggrLabelsList\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x36\n\x14\x65nc_aggr_labels_list\x18\x02 \x03(\x0b\x32\x18.ppc.model.EncAggrLabels\"/\n\x10IterationRequest\x12\r\n\x05\x65poch\x18\x01 \x01(\x03\x12\x0c\n\x04stop\x18\x02 \x01(\x08\x32Y\n\x0cModelService\x12I\n\x12MessageInteraction\x12\x17.ppc.model.ModelRequest\x1a\x18.ppc.model.ModelResponse\"\x00\x42\x08P\x01\xa2\x02\x03PPCb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ppc_model_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'P\001\242\002\003PPC' + _globals['_MODELREQUEST']._serialized_start=30 + _globals['_MODELREQUEST']._serialized_end=154 + _globals['_BASERESPONSE']._serialized_start=156 + _globals['_BASERESPONSE']._serialized_end=207 + _globals['_MODELRESPONSE']._serialized_start=209 + _globals['_MODELRESPONSE']._serialized_end=286 + _globals['_PLAINBOOLLIST']._serialized_start=288 + _globals['_PLAINBOOLLIST']._serialized_end=323 + _globals['_BESTSPLITINFO']._serialized_start=326 + _globals['_BESTSPLITINFO']._serialized_end=503 + _globals['_MODELCIPHER']._serialized_start=505 + _globals['_MODELCIPHER']._serialized_end=556 + _globals['_CIPHERLIST']._serialized_start=558 + _globals['_CIPHERLIST']._serialized_end=635 + _globals['_CIPHER1DIMLIST']._serialized_start=637 + _globals['_CIPHER1DIMLIST']._serialized_end=698 + _globals['_CIPHER2DIMLIST']._serialized_start=700 + _globals['_CIPHER2DIMLIST']._serialized_end=787 + _globals['_ENCAGGRLABELS']._serialized_start=789 + _globals['_ENCAGGRLABELS']._serialized_end=884 + _globals['_ENCAGGRLABELSLIST']._serialized_start=886 + 
_globals['_ENCAGGRLABELSLIST']._serialized_end=981 + _globals['_ITERATIONREQUEST']._serialized_start=983 + _globals['_ITERATIONREQUEST']._serialized_end=1030 + _globals['_MODELSERVICE']._serialized_start=1032 + _globals['_MODELSERVICE']._serialized_end=1121 +# @@protoc_insertion_point(module_scope) diff --git a/python/ppc_common/ppc_protos/generated/ppc_model_pb2_grpc.py b/python/ppc_common/ppc_protos/generated/ppc_model_pb2_grpc.py new file mode 100644 index 00000000..e56168ae --- /dev/null +++ b/python/ppc_common/ppc_protos/generated/ppc_model_pb2_grpc.py @@ -0,0 +1,67 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from ppc_common.ppc_protos.generated import ppc_model_pb2 as ppc__model__pb2 + + +class ModelServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.MessageInteraction = channel.unary_unary( + '/ppc.model.ModelService/MessageInteraction', + request_serializer=ppc__model__pb2.ModelRequest.SerializeToString, + response_deserializer=ppc__model__pb2.ModelResponse.FromString, + ) + + +class ModelServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def MessageInteraction(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ModelServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'MessageInteraction': grpc.unary_unary_rpc_method_handler( + servicer.MessageInteraction, + request_deserializer=ppc__model__pb2.ModelRequest.FromString, + response_serializer=ppc__model__pb2.ModelResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ppc.model.ModelService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + # This class is part of an EXPERIMENTAL API. 
+ + +class ModelService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def MessageInteraction(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ppc.model.ModelService/MessageInteraction', + ppc__model__pb2.ModelRequest.SerializeToString, + ppc__model__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/python/ppc_common/ppc_protos/generated/ppc_pb2.py b/python/ppc_common/ppc_protos/generated/ppc_pb2.py new file mode 100644 index 00000000..5c8200ef --- /dev/null +++ b/python/ppc_common/ppc_protos/generated/ppc_pb2.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: ppc.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tppc.proto\x12\tppc.proto\"\xe5\x01\n\x0f\x41lgorithmDetail\x12\x1d\n\x15\x61lgorithm_description\x18\x01 \x01(\t\x12\x16\n\x0e\x61lgorithm_type\x18\x02 \x01(\t\x12\x13\n\x0b\x63reate_time\x18\x03 \x01(\x03\x12\x13\n\x0bupdate_time\x18\x04 \x01(\x03\x12\x17\n\x0fowner_user_name\x18\x05 \x01(\t\x12\x17\n\x0fowner_agency_id\x18\x06 \x01(\t\x12\x19\n\x11owner_agency_name\x18\x07 \x01(\t\x12$\n\x08ppc_flow\x18\x08 \x01(\x0b\x32\x12.ppc.proto.PpcFlow\"\xc1\x01\n\x07PpcFlow\x12\x13\n\x0binput_count\x18\x01 \x01(\t\x12\x12\n\nsql_module\x18\x02 
\x01(\t\x12\x12\n\nmpc_module\x18\x03 \x01(\t\x12\x18\n\x10mpc_model_module\x18\x04 \x01(\t\x12\x14\n\x0cmatch_module\x18\x05 \x01(\t\x12\x19\n\x11\x61lgorithm_subtype\x18\x06 \x01(\t\x12\x1a\n\x12participant_agency\x18\x07 \x01(\t\x12\x12\n\nmodel_type\x18\x08 \x01(\t\"\xee\x01\n\rDatasetDetail\x12\x15\n\rdataset_title\x18\x01 \x01(\t\x12\x1b\n\x13\x64\x61taset_description\x18\x02 \x01(\t\x12\x14\n\x0c\x64\x61taset_hash\x18\x03 \x01(\t\x12\x14\n\x0c\x63olumn_count\x18\x04 \x01(\x03\x12\x11\n\trow_count\x18\x05 \x01(\x03\x12\x13\n\x0b\x63reate_time\x18\x06 \x01(\x03\x12\x11\n\tuser_name\x18\x07 \x01(\t\x12\x17\n\x0f\x64\x61ta_field_list\x18\x08 \x01(\t\x12\x14\n\x0c\x64\x61taset_size\x18\t \x01(\x03\x12\x13\n\x0bupdate_time\x18\n \x01(\x03\"K\n\x0f\x44\x61tasetAuthData\x12\x0f\n\x07is_auth\x18\x01 \x01(\x08\x12\x13\n\x0b\x63reate_time\x18\x02 \x01(\x03\x12\x12\n\nvalid_time\x18\x03 \x01(\x03\"\xde\x01\n\x12jobDatasetProvider\x12\x11\n\tagency_id\x18\x01 \x01(\t\x12\x13\n\x0b\x61gency_name\x18\x02 \x01(\t\x12\x12\n\ndataset_id\x18\x03 \x01(\t\x12\x15\n\rdataset_title\x18\x04 \x01(\t\x12\x15\n\rloading_index\x18\x05 \x01(\x03\x12\x1b\n\x13\x64\x61taset_description\x18\x06 \x01(\t\x12\x17\n\x0fowner_user_name\x18\x07 \x01(\t\x12\x14\n\x0c\x63olumns_size\x18\x08 \x01(\x03\x12\x12\n\nis_instant\x18\t \x01(\x03\"\x94\x01\n\x15\x64\x61tasourceInstantInfo\x12\r\n\x05\x64\x62_ip\x18\x01 \x01(\t\x12\x0f\n\x07\x64\x62_name\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x62_password\x18\x03 \x01(\t\x12\x0f\n\x07\x64\x62_port\x18\x04 \x01(\x03\x12\x0e\n\x06\x64\x62_sql\x18\x05 \x01(\t\x12\x14\n\x0c\x64\x62_user_name\x18\x06 \x01(\t\x12\x0f\n\x07\x64\x62_type\x18\x07 \x01(\t\"I\n\x16jobDatasetProviderList\x12/\n\x08provider\x18\x01 \x03(\x0b\x32\x1d.ppc.proto.jobDatasetProvider\"m\n\x16jobComputationProvider\x12\x11\n\tagency_id\x18\x01 \x01(\t\x12\x13\n\x0b\x61gency_name\x18\x02 \x01(\t\x12\x12\n\nagency_url\x18\x03 \x01(\t\x12\x17\n\x0f\x63omputing_index\x18\x04 
\x01(\x03\"Q\n\x1ajobComputationProviderList\x12\x33\n\x08provider\x18\x01 \x03(\x0b\x32!.ppc.proto.jobComputationProvider\")\n\x15jobResultReceiverList\x12\x10\n\x08receiver\x18\x01 \x03(\t\"\x8f\x05\n\x08JobEvent\x12\x0e\n\x06job_id\x18\x01 \x01(\t\x12\x11\n\tjob_title\x18\x02 \x01(\t\x12\x17\n\x0fjob_description\x18\x03 \x01(\t\x12\x14\n\x0cjob_priority\x18\x04 \x01(\x03\x12\x13\n\x0bjob_creator\x18\x05 \x01(\t\x12\x1b\n\x13initiator_agency_id\x18\x06 \x01(\t\x12\x1d\n\x15initiator_agency_name\x18\x07 \x01(\t\x12\x1b\n\x13initiator_signature\x18\x08 \x01(\t\x12\x18\n\x10job_algorithm_id\x18\t \x01(\t\x12\x1b\n\x13job_algorithm_title\x18\n \x01(\t\x12\x1d\n\x15job_algorithm_version\x18\x0b \x01(\t\x12\x1a\n\x12job_algorithm_type\x18\x0c \x01(\t\x12$\n\x1cjob_dataset_provider_list_pb\x18\r \x01(\t\x12#\n\x1bjob_result_receiver_list_pb\x18\x0e \x01(\t\x12\x13\n\x0b\x63reate_time\x18\x0f \x01(\x03\x12\x13\n\x0bupdate_time\x18\x10 \x01(\x03\x12\x12\n\npsi_fields\x18\x11 \x01(\t\x12\x1d\n\x15job_algorithm_subtype\x18\x12 \x01(\t\x12\x14\n\x0cmatch_fields\x18\x13 \x01(\t\x12\x16\n\x0eis_cem_encrypt\x18\x14 \x01(\x08\x12\x17\n\x0f\x64\x61taset_id_list\x18\x15 \x03(\t\x12\x30\n\x0e\x64\x61taset_detail\x18\x16 \x01(\x0b\x32\x18.ppc.proto.DatasetDetail\x12\x1e\n\x16tag_provider_agency_id\x18\x17 \x01(\t\x12\x10\n\x08job_type\x18\x18 \x01(\t\"/\n\tAuditItem\x12\x13\n\x0b\x64\x65scription\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"5\n\tAuditData\x12(\n\naudit_item\x18\x01 \x03(\x0b\x32\x14.ppc.proto.AuditItem\"E\n\x10JobOutputPreview\x12\x0e\n\x06header\x18\x01 \x03(\t\x12!\n\x04line\x18\x02 \x03(\x0b\x32\x13.ppc.proto.DataLine\"\x19\n\x08\x44\x61taLine\x12\r\n\x05value\x18\x01 \x03(\t\"=\n\x0eInputStatement\x12\x15\n\rupstream_unit\x18\x01 \x01(\t\x12\x14\n\x0coutput_index\x18\x02 \x01(\x03\"M\n\x16JobUnitInputsStatement\x12\x33\n\x10inputs_statement\x18\x01 \x03(\x0b\x32\x19.ppc.proto.InputStatement\"!\n\x0eJobUnitOutputs\x12\x0f\n\x07outputs\x18\x01 
\x03(\t\")\n\x0fJobUnitUpstream\x12\x16\n\x0eupstream_units\x18\x01 \x03(\t\"<\n\tAlgorithm\x12\x14\n\x0c\x61lgorithm_id\x18\x01 \x01(\t\x12\x19\n\x11\x61lgorithm_version\x18\x02 \x01(\t\"6\n\nAlgorithms\x12(\n\nalgorithms\x18\x01 \x03(\x0b\x32\x14.ppc.proto.Algorithmb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ppc_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_ALGORITHMDETAIL']._serialized_start = 25 + _globals['_ALGORITHMDETAIL']._serialized_end = 254 + _globals['_PPCFLOW']._serialized_start = 257 + _globals['_PPCFLOW']._serialized_end = 450 + _globals['_DATASETDETAIL']._serialized_start = 453 + _globals['_DATASETDETAIL']._serialized_end = 691 + _globals['_DATASETAUTHDATA']._serialized_start = 693 + _globals['_DATASETAUTHDATA']._serialized_end = 768 + _globals['_JOBDATASETPROVIDER']._serialized_start = 771 + _globals['_JOBDATASETPROVIDER']._serialized_end = 993 + _globals['_DATASOURCEINSTANTINFO']._serialized_start = 996 + _globals['_DATASOURCEINSTANTINFO']._serialized_end = 1144 + _globals['_JOBDATASETPROVIDERLIST']._serialized_start = 1146 + _globals['_JOBDATASETPROVIDERLIST']._serialized_end = 1219 + _globals['_JOBCOMPUTATIONPROVIDER']._serialized_start = 1221 + _globals['_JOBCOMPUTATIONPROVIDER']._serialized_end = 1330 + _globals['_JOBCOMPUTATIONPROVIDERLIST']._serialized_start = 1332 + _globals['_JOBCOMPUTATIONPROVIDERLIST']._serialized_end = 1413 + _globals['_JOBRESULTRECEIVERLIST']._serialized_start = 1415 + _globals['_JOBRESULTRECEIVERLIST']._serialized_end = 1456 + _globals['_JOBEVENT']._serialized_start = 1459 + _globals['_JOBEVENT']._serialized_end = 2114 + _globals['_AUDITITEM']._serialized_start = 2116 + _globals['_AUDITITEM']._serialized_end = 2163 + _globals['_AUDITDATA']._serialized_start = 2165 + _globals['_AUDITDATA']._serialized_end = 2218 + 
_globals['_JOBOUTPUTPREVIEW']._serialized_start = 2220 + _globals['_JOBOUTPUTPREVIEW']._serialized_end = 2289 + _globals['_DATALINE']._serialized_start = 2291 + _globals['_DATALINE']._serialized_end = 2316 + _globals['_INPUTSTATEMENT']._serialized_start = 2318 + _globals['_INPUTSTATEMENT']._serialized_end = 2379 + _globals['_JOBUNITINPUTSSTATEMENT']._serialized_start = 2381 + _globals['_JOBUNITINPUTSSTATEMENT']._serialized_end = 2458 + _globals['_JOBUNITOUTPUTS']._serialized_start = 2460 + _globals['_JOBUNITOUTPUTS']._serialized_end = 2493 + _globals['_JOBUNITUPSTREAM']._serialized_start = 2495 + _globals['_JOBUNITUPSTREAM']._serialized_end = 2536 + _globals['_ALGORITHM']._serialized_start = 2538 + _globals['_ALGORITHM']._serialized_end = 2598 + _globals['_ALGORITHMS']._serialized_start = 2600 + _globals['_ALGORITHMS']._serialized_end = 2654 +# @@protoc_insertion_point(module_scope) diff --git a/python/ppc_common/ppc_protos/ppc.proto b/python/ppc_common/ppc_protos/ppc.proto new file mode 100644 index 00000000..a9571c0f --- /dev/null +++ b/python/ppc_common/ppc_protos/ppc.proto @@ -0,0 +1,160 @@ +syntax = "proto3"; + +package ppc.proto; + +message AlgorithmDetail { + string algorithm_description = 1; + string algorithm_type = 2; + int64 create_time = 3; + int64 update_time = 4; + string owner_user_name = 5; + string owner_agency_id = 6; + string owner_agency_name = 7; + PpcFlow ppc_flow = 8; +} + +message PpcFlow { + // input_count need present '3+' + string input_count = 1; + string sql_module = 2; + string mpc_module = 3; + string mpc_model_module = 4; + string match_module = 5; + string algorithm_subtype = 6; + string participant_agency = 7; + string model_type = 8; +} + +message DatasetDetail { + string dataset_title = 1; + string dataset_description = 2; + string dataset_hash = 3; + int64 column_count = 4; + int64 row_count = 5; + int64 create_time = 6; + string user_name = 7; + string data_field_list = 8; + int64 dataset_size = 9; + int64 update_time = 10; +} 
+ +message DatasetAuthData { + bool is_auth = 1; + int64 create_time = 2; + int64 valid_time = 3; +} + + +message jobDatasetProvider { + string agency_id = 1; + string agency_name = 2; + string dataset_id = 3; + string dataset_title = 4; + int64 loading_index = 5; + string dataset_description = 6; + string owner_user_name = 7; + int64 columns_size = 8; + int64 is_instant = 9; +} + +message datasourceInstantInfo { + string db_ip = 1; + string db_name = 2; + string db_password = 3; + int64 db_port = 4; + string db_sql = 5; + string db_user_name = 6; + string db_type = 7; +} + +message jobDatasetProviderList { + repeated jobDatasetProvider provider = 1; +} + +message jobComputationProvider { + string agency_id = 1; + string agency_name = 2; + string agency_url = 3; + int64 computing_index = 4; +} + +message jobComputationProviderList { + repeated jobComputationProvider provider = 1; +} + +message jobResultReceiverList { + repeated string receiver = 1; +} + + +message JobEvent { + string job_id = 1; + string job_title = 2; + string job_description = 3; + int64 job_priority = 4; + string job_creator = 5; + string initiator_agency_id = 6; + string initiator_agency_name = 7; + string initiator_signature = 8; + string job_algorithm_id = 9; + string job_algorithm_title = 10; + string job_algorithm_version = 11; + string job_algorithm_type = 12; + string job_dataset_provider_list_pb = 13; // jobDatasetProviderList + string job_result_receiver_list_pb = 14; // jobResultReceiverList + int64 create_time = 15; + int64 update_time = 16; + string psi_fields = 17; + string job_algorithm_subtype = 18; + string match_fields = 19; + bool is_cem_encrypt = 20; + repeated string dataset_id_list = 21; // cem encrypt dataset id without username prefix + DatasetDetail dataset_detail = 22; // cem encrypt dataset transfer + string tag_provider_agency_id = 23; + string job_type = 24; +} + +message AuditItem { + string description = 1; + string value = 2; +} + +message AuditData { + repeated 
AuditItem audit_item = 1; +} + +message JobOutputPreview { + repeated string header = 1; + repeated DataLine line = 2; +} + +message DataLine { + repeated string value = 1; +} + +message InputStatement { + string upstream_unit = 1; + int64 output_index = 2; +} + +message JobUnitInputsStatement { + repeated InputStatement inputs_statement = 1; +} + +message JobUnitOutputs { + repeated string outputs = 1; +} + +message JobUnitUpstream { + repeated string upstream_units = 1; +} + +message Algorithm { + string algorithm_id = 1; + string algorithm_version = 2; +} + +message Algorithms { + repeated Algorithm algorithms = 1; +} + diff --git a/python/ppc_common/ppc_protos/ppc_model.proto b/python/ppc_common/ppc_protos/ppc_model.proto new file mode 100644 index 00000000..ac1d72f1 --- /dev/null +++ b/python/ppc_common/ppc_protos/ppc_model.proto @@ -0,0 +1,83 @@ +syntax = "proto3"; + +option java_multiple_files = true; +//option java_package = "unknown"; +//option java_outer_classname = "unknown"; +option objc_class_prefix = "PPC"; + +package ppc.model; + +service ModelService { + rpc MessageInteraction (ModelRequest) returns (ModelResponse) {} +} + +message ModelRequest { + string sender = 1; + string receiver = 2; + string task_id = 3; + string key = 4; + int64 seq = 5; + int64 slice_num = 6; + bytes data = 7; +} + +message BaseResponse { + int64 error_code = 1; + string message = 2; +} + +message ModelResponse { + BaseResponse base_response = 1; + bytes data = 2; +} + +message PlainBoolList{ + repeated bool plain_list = 1; +} + +message BestSplitInfo{ + int64 tree_id = 1; + int64 leaf_id = 2; + int64 feature = 3; + int64 value = 4; + int64 agency_idx = 5; + int64 agency_feature = 6; + float best_gain = 7; + float w_left = 8; + float w_right = 9; +} + +message ModelCipher { + bytes ciphertext = 1; + bytes exponent = 2; +} + +message CipherList { + bytes public_key = 1; + repeated ModelCipher cipher_list = 2; +} + +message Cipher1DimList { + repeated ModelCipher cipher_list 
= 1; +} + +message Cipher2DimList { + bytes public_key = 1; + repeated Cipher1DimList cipher_1d_list = 2; +} + +message EncAggrLabels { + string field = 1; + repeated int64 count_list = 2; + repeated ModelCipher cipher_list = 3; +} + +message EncAggrLabelsList { + bytes public_key = 1; + repeated EncAggrLabels enc_aggr_labels_list = 2; +} + +message IterationRequest { + int64 epoch = 1; + bool stop = 2; +} diff --git a/python/ppc_common/ppc_utils/__init__.py b/python/ppc_common/ppc_utils/__init__.py new file mode 100644 index 00000000..8d7ee87a --- /dev/null +++ b/python/ppc_common/ppc_utils/__init__.py @@ -0,0 +1 @@ +__all__ = ['exception', 'path', 'permission', 'utils', 'common_func'] diff --git a/python/ppc_common/ppc_utils/anonymous_search.py b/python/ppc_common/ppc_utils/anonymous_search.py new file mode 100644 index 00000000..80c3df61 --- /dev/null +++ b/python/ppc_common/ppc_utils/anonymous_search.py @@ -0,0 +1,420 @@ +import hashlib +import json +import logging +import os +import random +import string +import unittest +import uuid +from io import BytesIO + +import pandas as pd + +from ppc_common.ppc_utils import utils, http_utils +from ppc_common.ppc_crypto import crypto_utils +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode + +log = logging.getLogger(__name__) + + +def make_hash(data): + m = hashlib.sha3_256() + m.update(data) + return m.hexdigest() + + +def requester_gen_ot_cipher(id_list, obfuscation_order): + blinding_a = crypto_utils.get_random_int() + blinding_b = crypto_utils.get_random_int() + x_value = crypto_utils.ot_base_pown(blinding_a) + y_value = crypto_utils.ot_base_pown(blinding_b) + c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) + id_index_list = [] + send_hash_vec = [] + z_value_list = [] + for id_hash in id_list: + obs_list = id_obfuscation(obfuscation_order, None) + id_index = random.randint(0, obfuscation_order) + obs_list[id_index] = id_hash + id_index_list.append(id_index) + 
send_hash_vec.append(obs_list) + z_value = crypto_utils.ot_base_pown(c_blinding - id_index) + z_value_list.append(str(z_value)) + return id_index_list, blinding_b, x_value, y_value, send_hash_vec, z_value_list + + +def provider_gen_ot_cipher(x_value, y_value, send_hash_vec, z_value_list, data_map, is_contain_result=True): + """ + data_map = hashmap[id1: "message1", + id2: "message2"] + """ + if isinstance(x_value, str): + x_value = int(x_value) + if isinstance(y_value, str): + y_value = int(y_value) + if len(send_hash_vec) != len(z_value_list): + raise PpcException(PpcErrorCode.AYS_LENGTH_ERROR.get_code(), + PpcErrorCode.AYS_LENGTH_ERROR.get_msg()) + message_cipher_vec = [] + for idx, z_value in enumerate(z_value_list): + if isinstance(z_value, str): + z_value = int(z_value) + message_cipher_list = [] + message_int_len = 0 + for send_hash in send_hash_vec[idx]: + blinding_r = crypto_utils.get_random_int() + blinding_s = crypto_utils.get_random_int() + w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), + crypto_utils.ot_base_pown(blinding_r)) + key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), + crypto_utils.ot_pown(y_value, blinding_r)) + aes_key_bytes = os.urandom(16) + aes_default = utils.AESCipher(aes_key_bytes) + aes_key_base64str = utils.bytes_to_base64str(aes_key_bytes) + message_int, message_int_len = crypto_utils.ot_str_to_int( + aes_key_base64str) + + if send_hash not in data_map: + letters = string.ascii_lowercase + # random string + ot_message_str = 'message not found' + cipher = aes_default.encrypt(ot_message_str) + cipher_str = utils.bytes_to_base64str(cipher) + else: + if is_contain_result: + ot_message_str = json.dumps(data_map[send_hash]) + cipher = aes_default.encrypt(ot_message_str) + cipher_str = utils.bytes_to_base64str(cipher) + else: + ot_message_str = 'True' + cipher = aes_default.encrypt(ot_message_str) + cipher_str = utils.bytes_to_base64str(cipher) + message_cipher = key_value ^ 
message_int + z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) + message_cipher_list.append({ + "w": str(w_value), + "e": str(message_cipher), + "len": message_int_len, + 'aesCipher': cipher_str, + }) + message_cipher_vec.append(message_cipher_list) + return message_cipher_vec + + +def requester_ot_recover_result(id_index_list, blinding_b, message_cipher_vec): + if len(id_index_list) != len(message_cipher_vec): + raise PpcException(PpcErrorCode.AYS_LENGTH_ERROR.get_code(), + PpcErrorCode.AYS_LENGTH_ERROR.get_msg()) + result_list = [] + for idx, id_index in enumerate(id_index_list): + w_value = message_cipher_vec[idx][id_index]['w'] + e_value = message_cipher_vec[idx][id_index]['e'] + message_len = message_cipher_vec[idx][id_index]['len'] + cipher_str = message_cipher_vec[idx][id_index]['aesCipher'] + + if isinstance(w_value, str): + w_value = int(w_value) + if isinstance(e_value, str): + e_value = int(e_value) + w_1 = crypto_utils.ot_pown(w_value, blinding_b) + message_recover = w_1 ^ e_value + try: + aes_key = crypto_utils.ot_int_to_str(message_recover, message_len) + key_recover = utils.base64str_to_bytes(aes_key) + aes_recover = utils.AESCipher(key_recover) + cipher_recover = utils.base64str_to_bytes(cipher_str) + message_result = aes_recover.decrypt(cipher_recover) + result_list.append(message_result) + except Exception as be: + result_list.append(None) + return result_list + + +def id_obfuscation(obfuscation_order, rule=None): + # use rule extend different id order, such as driver card or name: + if rule is not None: + print('obfuscation is work in progress') + obs_list = [] + for i in range(0, obfuscation_order+1): + obs_list.append(make_hash(bytes(str(uuid.uuid4()), 'utf8'))) + return obs_list + + +def prepare_dataset_with_matrix(search_id_matrix, tmp_file_path, prepare_dataset_tmp_file_path): + search_reg = 'ppc-normal-prefix' + for search_list in search_id_matrix: + for search_id in search_list: + search_reg = '{}|{}'.format(search_reg, 
search_id) + exec_command = 'head -n 1 {} >> {}'.format( + tmp_file_path, prepare_dataset_tmp_file_path) + (status, result) = utils.getstatusoutput(exec_command) + if status != 0: + log.error( + f'[OnError]prepare_dataset_with_matrix! status is {status}, output is {result}') + else: + log.info( + f'prepare_dataset_with_matrix success! status is {status}, output is {result}') + exec_command = 'grep -E \'{}\' {} >> {}'.format( + search_reg, tmp_file_path, prepare_dataset_tmp_file_path) + (status, result) = utils.getstatusoutput(exec_command) + if status != 0: + log.error( + f'[OnError]prepare_dataset_with_matrix! status is {status}, output is {result}') + else: + log.info( + f'prepare_dataset_with_matrix success! status is {status}, output is {result}') + + +class TestOtMethods(unittest.TestCase): + + def test_choice_all_flow(self): + file_path = '/Users/asher/Downloads/test_file.csv' + # data_pd = pd.read_csv(file_path, index_col=0, header=0) + data_pd = pd.read_csv(file_path) + data_map = data_pd.set_index(data_pd.columns[0]).T.to_dict('list') + + choice = ['bob', '小鸡', '美丽'] + obs_order = 10 + _id_index_list, _blinding_b, _x_value, _y_value, _send_hash_vec, _z_value_list = requester_gen_ot_cipher( + choice, obs_order) + _message_cipher_vec = provider_gen_ot_cipher( + _x_value, _y_value, _send_hash_vec, _z_value_list, data_map) + result = requester_ot_recover_result( + _id_index_list, _blinding_b, _message_cipher_vec) + for idx, id_num in enumerate(choice): + print(f"found {id_num} value is {result[idx]}") + + # def test_choice_ot(self): + # choice_list = [1, 4] + # blinding_a = crypto_utils.get_random_int() + # blinding_b = crypto_utils.get_random_int() + # x_value = crypto_utils.ot_base_pown(blinding_a) + # y_value = crypto_utils.ot_base_pown(blinding_b) + # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) + # # c_value = crypto_utils.ot_base_pown(c_blinding) + # z_value_list = [] + # for choice in choice_list: + # z_value = 
crypto_utils.ot_base_pown(c_blinding - choice) + # z_value_list.append(z_value) + # # send x_value, y_value, z_value + # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] + # message_list = [] + # for message_str in message_str_list: + # message_int, message_int_len = crypto_utils.ot_str_to_int(message_str) + # message_list.append(message_int) + # # message_list = [111111, 222222, 333333, 444444, 555555] + # cipher_vec = [] + # for z_value in z_value_list: + # cipher_list = [] + # for message in message_list: + # blinding_r = crypto_utils.get_random_int() + # blinding_s = crypto_utils.get_random_int() + # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), + # crypto_utils.ot_base_pown(blinding_r)) + # key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), + # crypto_utils.ot_pown(y_value, blinding_r)) + # z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) + # e_cipher = key_value ^ message + # cipher_list.append({ + # "w": w_value, + # "e": e_cipher + # }) + # cipher_vec.append(cipher_list) + # + # for cipher_each in cipher_vec: + # for cipher in cipher_each: + # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) + # message_recover = w_1 ^ cipher['e'] + # print(message_recover) + # for idx, cipher_list in enumerate(cipher_vec): + # w_1 = crypto_utils.ot_pown(cipher_list[choice_list[idx]]['w'], blinding_b) + # message_recover = w_1 ^ cipher_list[choice_list[idx]]['e'] + # s = crypto_utils.ot_int_to_str(message_recover) + # print(s) + + # def test_id_ot(self): + # print("test_id_ot") + # choice_id_list = [crypto_utils.ot_str_to_int('小明'), crypto_utils.ot_str_to_int('张三')] + # blinding_a = crypto_utils.get_random_int() + # blinding_b = crypto_utils.get_random_int() + # x_value = crypto_utils.ot_base_pown(blinding_a) + # y_value = crypto_utils.ot_base_pown(blinding_b) + # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) + # # c_value = crypto_utils.ot_base_pown(c_blinding) + # 
z_value_list = [] + # for choice in choice_id_list: + # z_value = crypto_utils.ot_base_pown(c_blinding - choice) + # z_value_list.append(z_value) + # # z_value = crypto_utils.ot_base_pown(c_blinding - choice_id) + # # send x_value, y_value, z_value + # id_str_list = ['小往', '小明', 'asher', 'set', '张三'] + # id_list = [] + # for id_str in id_str_list: + # id_list.append(crypto_utils.ot_str_to_int(id_str)) + # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] + # message_list = [] + # for message_str in message_str_list: + # message_list.append(crypto_utils.ot_str_to_int(message_str)) + # # message_list = [111111, 222222, 333333, 444444, 555555] + # cipher_vec = [] + # for z_value in z_value_list: + # cipher_list = [] + # for idx, message in enumerate(message_list): + # blinding_r = crypto_utils.get_random_int() + # blinding_s = crypto_utils.get_random_int() + # z_value_use = crypto_utils.ot_mul_n(z_value, crypto_utils.ot_base_pown(id_list[idx])) + # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), + # crypto_utils.ot_base_pown(blinding_r)) + # key_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value_use, blinding_s), + # crypto_utils.ot_pown(y_value, blinding_r)) + # e_cipher = key_value ^ message + # cipher_list.append({ + # "w": w_value, + # "e": e_cipher + # }) + # cipher_vec.append(cipher_list) + # + # for idx, cipher_list in enumerate(cipher_vec): + # for now_idx, cipher in enumerate(cipher_list): + # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) + # message_recover = w_1 ^ cipher['e'] + # # print(message_recover) + # # print(idx) + # # print(now_idx) + # if (idx == 0 and now_idx == 1) or (idx == 1 and now_idx == 4): + # s = crypto_utils.ot_int_to_str(message_recover) + # print(s) + + def test_pd_with_multi_index(self): + print(True) + df = pd.DataFrame( + [[21, 'Amol', 72, 67], + [23, 'Lini', 78, 69], + [32, 'Kiku', 74, 56], + [52, 'Ajit', 54, 76], + [53, 'Ajit', 55, 78] + ], + columns=['rollno', 'name', 
'physics', 'botony']) + + print('DataFrame with default index\n', df) + # set multiple columns as index + # df_map = df.set_index('name').T.to_dict('list') + # df_map = df.set_index('name').groupby(level=0).apply(lambda x: x.to_dict('r')).to_dict() + df_map = df.set_index('name').groupby(level=0).apply( + lambda x: x.to_dict('r')).to_dict() + print(df_map) + print(json.dumps(df_map['Ajit'])) + print(type(json.dumps(df_map['Kiku']))) + + def test_prepare_dataset_with_matrix(self): + tmp_file_path = "/Users/asher/Desktop/数据集2021/8_1_100w.csv" + prepare_dataset_tmp_file_path = "/Users/asher/Desktop/数据集2021/pre-test_100.csv" + search_id_matrix = [['645515750175253924', '779808531920530393'], [ + '399352968694137676', '399352968694137676222', '399352968694137']] + prepare_dataset_with_matrix( + search_id_matrix, tmp_file_path, prepare_dataset_tmp_file_path) + + # + # def test_choice_ot_multi(self): + # choice_list = [1, 2, 4] + # blinding_a = crypto_utils.get_random_int() + # blinding_b = crypto_utils.get_random_int() + # x_value = crypto_utils.ot_base_pown(blinding_a) + # y_value = crypto_utils.ot_base_pown(blinding_b) + # c_blinding = crypto_utils.ot_mul_fi(blinding_a, blinding_b) + # # c_value = crypto_utils.ot_base_pown(c_blinding) + # choice_final = 0 + # for choice in choice_list: + # choice_final = choice_final + choice + # z_value = crypto_utils.ot_base_pown(c_blinding - choice_final) + # # send x_value, y_value, z_value + # message_str_list = ['hello', 'world', 'ot', 'cipher', 'test'] + # message_list = [] + # for message_str in message_str_list: + # message_list.append(crypto_utils.ot_str_to_int(message_str)) + # # message_list = [111111, 222222, 333333, 444444, 555555] + # cipher_list = [] + # for message in message_list: + # blinding_r = crypto_utils.get_random_int() + # blinding_s = crypto_utils.get_random_int() + # w_value = crypto_utils.ot_mul_n(crypto_utils.ot_pown(x_value, blinding_s), + # crypto_utils.ot_base_pown(blinding_r)) + # key_value = 
crypto_utils.ot_mul_n(crypto_utils.ot_pown(z_value, blinding_s), + # crypto_utils.ot_pown(y_value, blinding_r)) + # z_value = crypto_utils.ot_mul_n(z_value, crypto_utils.DEFAULT_G) + # e_cipher = key_value ^ message + # cipher_list.append({ + # "w": w_value, + # "e": e_cipher + # }) + # + # # for cipher in cipher_list: + # # w_1 = crypto_utils.ot_pown(cipher['w'], blinding_b) + # # message_recover = w_1 ^ cipher['e'] + # # print(message_recover) + # + # for choice in choice_list: + # w_1 = crypto_utils.ot_pown(cipher_list[choice]['w'], blinding_b) + # for item in choice_list: + # if choice == item: + # continue + # else: + # base_g = crypto_utils.ot_base_pown(-item) + # w_1 = crypto_utils.ot_mul_n(w_1, base_g) + # message_recover = w_1 ^ cipher_list[choice]['e'] + # print(message_recover) + # # s = crypto_utils.ot_int_to_str(message_recover) + # # print(s) + # + + +# if __name__ == '__main__': + + # print(True) + # df = pd.DataFrame( + # [[21, 'Amol', 72, 67], + # [23, 'Lini', 78, 69], + # [32, 'Kiku', 74, 56], + # [52, 'Ajit', 54, 76]], + # columns=['rollno', 'name', 'physics', 'botony']) + # + # print('DataFrame with default index\n', df) + # # set multiple columns as index + # df = df.set_index(['rollno', 'name']) + # + # print('\nDataFrame with MultiIndex\n', df) + # point1 = point.base() + # print(point.base(scalar1).hex()) + + # json_response = "{\"id\": {\"0\": \"67b176705b46206614219f47a05aee7ae6a3edbe850bbbe214c536b989aea4d2\", \"1\": \"b1b1bd1ed240b1496c81ccf19ceccf2af6fd24fac10ae42023628abbe2687310\"}, \"x0\": {\"0\": 10, \"1\": 20}, \"x1\": {\"0\": 11, \"1\": 22}}" + # json_dict = json.loads(json_response) + # # json_pd = pd.json_normalize(json_dict) + # print(json_dict) + + # hex_str = utils.make_hash(bytes(str(796443), 'utf8'), CryptoType.ECDSA, HashType.HEXSTR) + # print(hex_str) + # csv_path1 = '/Users/asher/Downloads/UseCase120/usecase120_party1.csv' + # csv_path2 = '/Users/asher/Downloads/UseCase120/usecase120_party2.csv' + # data = 
pd.read_csv(csv_path1) + # if 'id' in data.columns.values: + # duplicated_list = data.duplicated('id', False).tolist() + # if True in duplicated_list: + # log.error(f"[OnError]id duplicated, check csv file") + # raise PpcException(PpcErrorCode.DATASET_CSV_ERROR.get_code(), + # PpcErrorCode.DATASET_CSV_ERROR.get_msg()) + # start = time.time() + # get_pd_file_with_hash_requester(csv_path1, f'{csv_path1}-pre', 2) + # start2 = time.time() + # print(f"requester prepare time is {start2 - start}s") + # + # get_pd_file_with_hash_data_provider(csv_path2, f'{csv_path2}-pre') + # start3 = time.time() + # print(f"provider prepare time is {start3 - start2}s") + # provider_output_path = '/Users/asher/Downloads/UseCase120/output.csv' + # get_anonymous_data(f'{csv_path1}-pre-requester', f'{csv_path2}-pre', provider_output_path) + # start4 = time.time() + # print(f"provider compute time is {start4 - start3}s") + # result_path = '/Users/asher/Downloads/UseCase120/result.csv' + # recover_result_data(f'{csv_path1}-pre', '/Users/asher/Downloads/UseCase120/output.csv', result_path) + # end = time.time() + # print(f"requester get result time is {end - start4}s") diff --git a/python/ppc_common/ppc_utils/audit_utils.py b/python/ppc_common/ppc_utils/audit_utils.py new file mode 100644 index 00000000..913a3934 --- /dev/null +++ b/python/ppc_common/ppc_utils/audit_utils.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +import getopt +import sys + +from ppc_common.ppc_utils import utils +from ppc_common.ppc_utils.utils import CryptoType + +AUDIT_KEYS = ["agency_name", "input_dataset_hash", + "psi_input_hash", "psi_output_hash", + "mpc_result_hash"] + + +def parse_parameter(argv): + file_path = 0 + data_hash_value = 0 + crypto_type = None + try: + opts, args = getopt.getopt( + argv, "hf:v:c:", ["file_path=", "data_hash_value="]) + except getopt.GetoptError: + usage() + sys.exit(2) + if len(opts) == 0: + usage() + sys.exit(2) + for opt, arg in opts: + if opt == '-h': + usage() + sys.exit(0) + elif 
opt in ("-f", "--file_path"): + file_path = arg + elif opt in ("-v", "--data_hash_value"): + data_hash_value = arg + elif opt in ("-c", "--crypto_type"): + crypto_type = arg + else: + usage() + sys.exit(2) + return file_path, data_hash_value, crypto_type + + +def usage(): + print('audit.py -f -v -c crypto_type') + print('usage:') + print(' -f Notice:The file will be hashed to audit.') + print(' -v Notice:The dataset hash value will be compared to audit.') + print( + ' -c Notice:The crypto type[ECDSA or GM] will be used.') + + +if __name__ == '__main__': + file_path, data_hash_value, crypto_type = parse_parameter(sys.argv[1:]) + file_data_hash = utils.make_hash_from_file_path( + file_path, CryptoType[crypto_type]) + print(f'The file hash is:{file_data_hash}') + print(f'Audit result is:{file_data_hash == data_hash_value}') diff --git a/python/ppc_common/ppc_utils/cem_utils.py b/python/ppc_common/ppc_utils/cem_utils.py new file mode 100644 index 00000000..6e9d9d5f --- /dev/null +++ b/python/ppc_common/ppc_utils/cem_utils.py @@ -0,0 +1,160 @@ +import json +import unittest + +LEGAL_OPERATOR_LIST = ['>', '<', '>=', '<=', '==', '!='] +LEGAL_QUOTA_LIST = ['and', 'or'] + + +def get_rule_detail(match_module_dict): + ruleset = match_module_dict['ruleset'] + operator_set = set() + count_set = set() + quota_set = set() + for ruleset_item in ruleset: + sub_rule_set = ruleset_item['set'] + for sub_rule_item in sub_rule_set: + rule = sub_rule_item['rule'] + operator_set.add(rule['operator']) + count_set.add(rule['count']) + quota_set.add(rule['quota']) + return operator_set, count_set, quota_set + + +def check_dataset_id_has_duplicated(match_module_dict): + ruleset = match_module_dict['ruleset'] + for ruleset_item in ruleset: + sub_rule_set = ruleset_item['set'] + for sub_rule_item in sub_rule_set: + dataset_id_set = set() + dataset_id_list = sub_rule_item['dataset'] + for dataset_id in dataset_id_list: + dataset_id_set.add(dataset_id) + if len(dataset_id_set) != 
len(dataset_id_list): + return True + return False + + +def get_dataset_id_set(match_module_dict): + dataset_id_set = set() + ruleset = match_module_dict['ruleset'] + for ruleset_item in ruleset: + sub_rule_set = ruleset_item['set'] + for sub_rule_item in sub_rule_set: + dataset_id_list = sub_rule_item['dataset'] + for dataset_id in dataset_id_list: + dataset_id_set.add(dataset_id) + return dataset_id_set + + +# get field_dataset_map: {'x1':['d1', 'd2', 'd3'], 'x2':['d1', 'd2', 'd3']} +def parse_field_dataset_map(match_module_dict): + field_dataset_map = {} + ruleset = match_module_dict['ruleset'] + for ruleset_item in ruleset: + field = ruleset_item['field'] + if field in field_dataset_map: + field_dataset_id_set = field_dataset_map[field] + else: + field_dataset_id_set = set() + field_dataset_map[field] = field_dataset_id_set + sub_rule_set = ruleset_item['set'] + for sub_rule_item in sub_rule_set: + dataset_id_list = sub_rule_item['dataset'] + for dataset_id in dataset_id_list: + field_dataset_id_set.add(dataset_id) + return field_dataset_map + + +# get dataset_field_map: {'d1':['x1', 'x2'], 'd2':['x1', 'x2'], 'd3':['x1', 'x2']} +def parse_dataset_field_map(dataset_id_set, field_dataset_map): + dataset_field_map = {} + for dataset_id in dataset_id_set: + for field, field_dataset_id_set in field_dataset_map.items(): + if dataset_id in field_dataset_id_set: + if dataset_id in dataset_field_map: + dataset_field_set = dataset_field_map[dataset_id] + else: + dataset_field_set = set() + dataset_field_map[dataset_id] = dataset_field_set + dataset_field_set.add(field) + return dataset_field_map + + +def parse_match_param(match_fields, match_module_dict): + # step1 get field_dataset_map: {'x1':['d1', 'd2', 'd3'], 'x2':['d1', 'd2', 'd3']} + field_dataset_map = parse_field_dataset_map(match_module_dict) + # step2 get all dataset_id {'d1', 'd2', 'd3'} + dataset_id_set = set() + for field, field_dataset_id_set in field_dataset_map.items(): + 
dataset_id_set.update(field_dataset_id_set) + # step3 get dataset_field_map: {'d1':['x1', 'x2'], 'd2':['x1', 'x2'], 'd3':['x1', 'x2']} + dataset_field_map = parse_dataset_field_map( + dataset_id_set, field_dataset_map) + # step4 get match_param_list from dataset_field_map and match_field: + # [ + # {'dataset_id':'d1', 'match_field':{'x1':'xxx, 'x2':'xxx}, + # {'dataset_id':'d2', 'match_field':{'x1':'xxx, 'x2':'xxx}, + # {'dataset_id':'d3', 'match_field':{'x1':'xxx, 'x2':'xxx}, + # ] + match_param_list = parse_match_param_list(dataset_field_map, match_fields) + return dataset_id_set, match_param_list + + +# get match_param_list from dataset_field_map and match_field: +# [ +# {'dataset_id':'d1', 'match_field':{'x1':'xxx, 'x2':'xxx}, +# {'dataset_id':'d2', 'match_field':{'x1':'xxx, 'x2':'xxx}, +# {'dataset_id':'d3', 'match_field':{'x1':'xxx, 'x2':'xxx}, +# ] +def parse_match_param_list(dataset_field_map, match_fields): + match_param_list = [] + match_fields = match_fields.replace("'", '"') + match_fields_object = json.loads(match_fields) + for dataset_id, field_set in dataset_field_map.items(): + match_param = {'dataset_id': dataset_id} + field_value_map = {} + for field in field_set: + # allow some part field match + if field in match_fields_object.keys(): + field_value_map[field] = match_fields_object[field] + match_param['match_field'] = field_value_map + match_param_list.append(match_param) + return match_param_list + + +class TestCemUtils(unittest.TestCase): + def test_cem_match_algorithm_load(self): + match_module = '{"ruleset":[' \ + '{"field":"x1",' \ + '"set":[' \ + '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ + '{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ + '{"field":"x2","set":[' \ + '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d2-encrypted","d3-encrypted"]}]}]}' + match_module_dict = json.loads(match_module) + # match_module_dict = 
utils.json_loads(match_module) + print(match_module_dict) + + def test_check_dataset_id_has_duplicated(self): + match_module = '{"ruleset":[' \ + '{"field":"x1",' \ + '"set":[' \ + '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ + '{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ + '{"field":"x2","set":[' \ + '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d2-encrypted","d3-encrypted"]}]}]}' + match_module_dict = json.loads(match_module) + has_duplicated = check_dataset_id_has_duplicated(match_module_dict) + assert has_duplicated == False + + def test_check_dataset_id_has_duplicated(self): + match_module = '{"ruleset":[' \ + '{"field":"x1",' \ + '"set":[' \ + '{"rule":{"operator":"<","count":50,"quota":"and"},"dataset":["d1-encrypted","d2-encrypted"]},' \ + '{"rule":{"operator":">","count":3,"quota":"or"},"dataset":["d3-encrypted"]}]},' \ + '{"field":"x2","set":[' \ + '{"rule":{"operator":"<","count":2,"quota":"or"},"dataset":["d1-encrypted","d1-encrypted","d3-encrypted"]}]}]}' + match_module_dict = json.loads(match_module) + has_duplicated = check_dataset_id_has_duplicated(match_module_dict) + assert has_duplicated == True diff --git a/python/ppc_common/ppc_utils/common_func.py b/python/ppc_common/ppc_utils/common_func.py new file mode 100644 index 00000000..e907d6bd --- /dev/null +++ b/python/ppc_common/ppc_utils/common_func.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +from contextlib import contextmanager +import chardet + + +def get_config_value(key, default_value, config_value, required): + if required and config_value is None: + raise Exception(f"Invalid config for '{key}' for not set!") + value = config_value + if type(config_value) is dict: + if key in config_value: + value = config_value[key] + else: + value = default_value + if value is None: + return default_value + return value + + +def get_file_encoding(file_path): + encoding = None + with 
@unique
class PpcErrorCode(Enum):
    """Application error codes.

    Each member's value is a single-entry dict mapping the numeric error
    code to its human-readable message, e.g. ``{10001: 'network error'}``.
    Use :meth:`get_code` / :meth:`get_msg` to read them; ``get_error_code``
    and ``get_message`` are kept as backward-compatible aliases.
    """
    SUCCESS = {0: 'success'}
    INTERNAL_ERROR = {10000: "internal error"}

    NETWORK_ERROR = {10001: 'network error'}
    JOB_STATUS_ERROR = {10002: 'job status check error'}
    JOB_ROLE_ERROR = {10003: 'job role check error'}
    DATABASE_ERROR = {10004: 'database related operation error'}
    DATASET_CSV_ERROR = {10005: 'dataset csv format error'}
    DATASET_PATH_ERROR = {10006: 'dataset path permission check error'}
    PARAMETER_CHECK_ERROR = {10007: 'parameter check error'}
    CALL_SYNCS_SERVICE_ERROR = {10008: 'call syncs service error'}
    DATA_SET_ERROR = {10009: 'dataset operation error'}

    INSUFFICIENT_AUTHORITY = {10010: 'insufficient authority'}
    UNDEFINED_TYPE = {10011: 'undefined type'}
    UNDEFINED_STATUS = {10012: 'undefined status'}
    QUERY_USERNAME_ERROR = {10013: 'query username error'}
    DATASET_NOT_FOUND = {10014: 'dataset not found'}
    ALGORITHM_NOT_FOUND = {10015: 'algorithm queried not found'}
    JOB_NOT_FOUND = {10016: 'job not found'}
    AUTH_INFO_FOUND = {10017: 'authorization not found'}
    AGENCY_NOT_FOUND = {10018: 'agency not found'}
    AGENCY_MANAGEMENT_NOT_FOUND = {10019: 'agency management not found'}
    FIELD_NOT_FOUND = {10020: 'dataset field not found'}
    AUTH_ALREADY_EXISTED = {10021: 'authorization already existed'}
    DATABASE_TYPE_ERROR = {10022: 'database type error'}
    DATABASE_IP_ERROR = {10023: 'database ip not in the white list'}
    DATASET_FROM_DB_RETRY_OVER_LIMIT_ERROR = {
        10024: 'access the database has exceeded the allowed limit times'}

    DATASET_PARSE_ERROR = {10300: 'parse dataset failed'}
    DATASET_EXIST_ERROR = {10301: 'dataset already existed!'}
    DATASET_DELETE_ERROR = {10302: 'dataset already deleted!'}
    DATASET_PERMISSION_ERROR = {10303: 'dataset permission check failed!'}
    DATASET_UPLOAD_ERROR = {10304: 'dataset upload error'}

    ALGORITHM_PARSE_ERROR = {10400: 'parse algorithm failed'}
    ALGORITHM_EXIST_ERROR = {10401: 'algorithm already existed!'}
    ALGORITHM_DELETE_ERROR = {10402: 'algorithm already deleted!'}
    ALGORITHM_TYPE_ERROR = {10403: 'algorithm type is not existed!'}
    ALGORITHM_COMPILE_ERROR = {10410: 'compile mpc algorithm error'}
    ALGORITHM_BAD_SQL = {10411: 'bad sql'}
    ALGORITHM_PPC_CONFIG_ERROR = {10412: 'parse algorithm config error'}
    ALGORITHM_PPC_MODEL_ALGORITHM_NAME_ERROR = {
        10413: 'algorithm subtype not found'}
    ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR = {10414: 'output number error'}
    ALGORITHM_PPC_MODEL_LAYERS_ERROR = {
        10415: 'layers attribute should not be set'}
    ALGORITHM_MPC_SYNTAX_CHECK_ERROR = {10416: 'check ppc mpc syntax error'}
    ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR_TEMP = {
        10417: 'output number should be set 1'}
    ALGORITHM_PPC_MODEL_PARTICIPANTS_ERROR_TEMP = {
        10418: 'participants should be greater or equal to 2'}
    ALGORITHM_PPC_MODEL_TEST_DATASET_PERCENTAGE_ERROR = {
        10419: 'test_dataset_percentage should be set in (0, 0.5]'}
    ALGORITHM_PPC_MODEL_EPOCHS_ERROR = {
        10420: 'epochs should be set in [1, 5]'}
    ALGORITHM_PPC_MODEL_BATCH_SIZE_ERROR = {
        10421: 'batch_size should be set [1, min(128, max_train_dataset_size)]'}
    ALGORITHM_PPC_MODEL_THREADS_ERROR = {
        10422: 'threads should be set in [1,8]'}
    ALGORITHM_PPC_MODEL_OPTIMIZER_ERROR = {10423: 'optimizer not found'}
    ALGORITHM_PPC_MODEL_LEARNING_RATE_ERROR = {
        10424: 'learning rate should be set in (0, 1)'}
    ALGORITHM_PPC_MODEL_LAYERS_ERROR2 = {
        10425: 'Conv2d layer should not be the first layer in HeteroNN'}
    ALGORITHM_NOT_EXIST_ERROR = {10426: 'algorithm does not exist!'}
    ALGORITHM_PPC_MODEL_TREES_ERROR = {
        10427: 'num_trees should be set in [1, 300]'}
    ALGORITHM_PPC_MODEL_DEPTH_ERROR = {
        10428: 'max_depth should be set in [1, 10]'}

    JOB_CREATE_ERROR = {10500: 'job create failed'}
    JOB_COMPUTATION_EXISTED_ERROR = {10501: 'job computation not existed'}
    JOB_AYS_MODE_CHECK_ERROR = {10502: 'patch request need static token'}
    JOB_MANAGEMENT_RUN_ERROR = {10503: 'job run failed'}
    JOB_CEM_ERROR = {10504: 'at least one field need to be provided'}
    JOB_IS_RUNNING_ERROR = {10505: 'job is running'}
    NO_PARTICAPATING_IN_JOB_ERROR = {10506: 'not participating in the job'}
    JOB_DOWNLOAD_RESULT_EMPTY = {10507: '任务结果为空,无法下载'}

    HDFS_STORAGE_ERROR = {10601: 'hdfs storage error'}

    AYS_LENGTH_ERROR = {10701: 'base ot message length check error'}
    AYS_ORDER_ERROR = {10702: 'search obfuscation order must > 1'}
    AYS_RESULT_LENGTH_ERROR = {10703: 'message result length check error'}
    # NOTE(review): 'filed' (sic) kept byte-identical — callers may match on it.
    AYS_FIELD_ERROR = {10704: 'search filed not found in obfuscate file'}
    CALL_SCS_ERROR = {10705: 'computation node call error'}

    MERGE_FILE_CHECK_ERROR = {10801: 'merge check files failed'}
    MERGE_FILE_FORMAT_ERROR = {10802: 'file format is not csv'}
    FILE_SIZE_ERROR = {10803: 'cannot get file size'}
    FILE_SPLIT_ERROR = {10804: 'split file failed!'}
    FILE_NOT_EXIST_ERROR = {10805: 'share file not existed!'}
    DUPLICATED_MERGE_FILE_REQUEST = {10806: 'duplicated merge file request'}

    XGB_PREPROCESSING_ERROR = {10901: 'xgb preprocessing failed!'}

    FILE_OBJECT_UPLOAD_CHECK_FAILED = {20000: "upload file object failed!"}
    FILE_OBJECT_NOT_EXISTS = {20001: "the file not exists!"}

    TASK_EXISTS = {11000: "the task already exists!"}
    TASK_NOT_FOUND = {11001: "the task not found!"}
    TASK_IS_KILLED = {11002: "the task is killed!"}

    ROLE_TYPE_ERROR = {12000: "role type is illegal."}

    def get_code(self):
        """Return the numeric error code (the single dict key)."""
        # next(iter(...)) avoids building a throwaway list on every call
        return next(iter(self.value))

    def get_error_code(self):
        """Backward-compatible alias of :meth:`get_code`."""
        return self.get_code()

    def get_msg(self):
        """Return the human-readable message (the single dict value)."""
        return next(iter(self.value.values()))

    def get_message(self):
        """Backward-compatible alias of :meth:`get_msg`."""
        return self.get_msg()
# NOTE(review): the http helpers below rely on the module-header imports
# `requests`, `urllib3` and `SecurityWarning` (third-party, file top).

log = logging.getLogger(__name__)


class PpcException(Exception):
    """Application exception carrying a numeric error code and a message."""

    def __init__(self, code, message):
        Exception.__init__(self)
        self.code = code
        self.message = message

    def to_dict(self):
        """Wire format used when reporting errors over the API."""
        return {'code': self.code, 'message': self.message}

    def get_code(self):
        return self.code

    def __str__(self):
        return self.message

    @classmethod
    def by_ppc_error_code(cls, ppc_error_code):
        """Build a PpcException from a ``PpcErrorCode`` member.

        Bug fix: the original assigned ``cls.code`` / ``cls.message``
        (class attributes) and implicitly returned None, so
        ``raise PpcException.by_ppc_error_code(err)`` failed with
        ``TypeError: exceptions must derive from BaseException``.
        """
        return cls(ppc_error_code.get_code(), ppc_error_code.get_msg())


def check_response(response):
    """Raise :class:`PpcException` unless the HTTP status is 200 or 201."""
    if response.status_code not in (200, 201):
        # Bug fix: pass the numeric code (not the enum member) so
        # to_dict() stays JSON-serializable and matches every other
        # PpcException raise site in this package.
        raise PpcException(PpcErrorCode.NETWORK_ERROR.get_code(),
                           f"Call request failed, response message:{response.text}")


def _build_url(endpoint, uri):
    """Join endpoint and optional uri into a plain-http URL."""
    return f"http://{endpoint}{uri}" if uri else f"http://{endpoint}"


def send_get_request(endpoint, uri, params=None, headers=None):
    """GET ``http://{endpoint}{uri}``; raise on non-2xx; return parsed JSON."""
    urllib3.disable_warnings(SecurityWarning)
    if not headers:
        headers = {'content-type': 'application/json'}
    url = _build_url(endpoint, uri)
    log.debug(f"send a get request, url: {url}, params: {params}")
    response = requests.get(url=url, params=params,
                            headers=headers, timeout=30)
    log.debug(f"response: {response.text}")
    check_response(response)
    return json.loads(response.text)


def send_post_request(endpoint, uri, params=None, headers=None, data=None):
    """POST json; return parsed JSON, or raw text when the body is not JSON.

    NOTE(review): unlike the other helpers this one does not validate the
    status code (check_response was commented out in the original) and sets
    no timeout — presumably deliberate for long-running calls; confirm.
    """
    if not headers:
        headers = {'content-type': 'application/json'}
    url = _build_url(endpoint, uri)
    log.debug(f"send a post request, url: {url}, params: {params}")
    response = requests.post(url, data=data, json=params, headers=headers)
    log.debug(f"response: {response.text}")
    try:
        return json.loads(response.text)
    except JSONDecodeError:
        return response.text


def send_delete_request(endpoint, uri, params=None, headers=None):
    """DELETE with a json body; raise on non-2xx; return parsed JSON."""
    if not headers:
        headers = {'content-type': 'application/json'}
    url = _build_url(endpoint, uri)
    log.debug(f"send a delete request, url: {url}, params: {params}")
    response = requests.delete(url, json=params, headers=headers)
    check_response(response)
    log.debug(f"response: {response.text}")
    return json.loads(response.text)


def send_patch_request(endpoint, uri, params=None, headers=None, data=None):
    """PATCH with a json body; raise on non-2xx; return parsed JSON."""
    if not headers:
        headers = {'content-type': 'application/json'}
    url = f"http://{endpoint}{uri}"
    log.debug(f"send a patch request, url: {url}, params: {params}")
    response = requests.patch(url, data=data, json=params, headers=headers)
    check_response(response)
    log.debug(f"response: {response.text}")
    return json.loads(response.text)


def send_upload_request(endpoint, uri, params=None, headers=None, data=None):
    """POST an upload; raise on non-2xx; return parsed JSON or raw text.

    NOTE(review): when ``uri`` is falsy the endpoint is used verbatim (no
    ``http://`` prefix), unlike the other helpers — confirm intentional.
    """
    if not headers:
        headers = {'content-type': 'application/json'}
    url = f"http://{endpoint}{uri}" if uri else endpoint
    log.debug(f"send a post request, url: {url}, params: {params}")
    response = requests.post(url, data=data, json=params, headers=headers)
    log.debug(f"response: {response.text}")
    check_response(response)
    try:
        return json.loads(response.text)
    except JSONDecodeError:
        return response.text


class Path(object):
    '''fisco generator path configuration holder (one shared value).'''
    # shared path value, written via set_path()
    dir = ''

    def get_name(self):
        """Placeholder kept from the original scaffold; returns None."""

    def get_pylint(self):
        """Placeholder kept from the original scaffold; returns None."""


def set_path(_dir):
    """Store the shared generator path."""
    Path.dir = _dir


def get_path():
    """Return the shared generator path."""
    return Path.dir


def check_job_status(job_status):
    """True iff ``job_status`` names a member of ``JobStatus``."""
    return job_status in JobStatus.__members__


def check_job_role(job_role):
    """True iff ``job_role`` names a member of ``JobRole``."""
    return job_role in JobRole.__members__


# wildcard permission granting everything (see check_permission)
ADMIN_PERMISSIONS = 'ADMIN_PERMISSIONS'


class UserRole(Enum):
    ADMIN = 1
    DATA_PROVIDER = 2
    ALGO_PROVIDER = 3
    DATA_CONSUMER = 4


class PermissionGroup(Enum):
    AGENCY_GROUP = 1
    DATA_GROUP = 2
    ALGO_GROUP = 3
    JOB_GROUP = 4
    AUDIT_GROUP = 5


class AgencyGroup(Enum):
    LIST_AGENCY = 1
    WRITE_AGENCY = 2


class DataGroup(Enum):
    LIST_DATA = 1
    READ_DATA_PUBLIC_INFO = 2
    READ_DATA_PRIVATE_INFO = 3
    WRITE_DATA = 4


class AlgoGroup(Enum):
    LIST_ALGO = 1
    READ_ALGO_PUBLIC_INFO = 2
    READ_ALGO_PRIVATE_INFO = 3
    WRITE_ALGO = 4
class JobGroup(Enum):
    """Job-scoped permission bits."""
    LIST_JOB = 1
    READ_JOB_PUBLIC_INFO = 2
    READ_JOB_PRIVATE_INFO = 3
    WRITE_JOB = 4


class AuditGroup(Enum):
    """Audit-scoped permission bits."""
    READ_AUDIT = 1
    WRITE_AUDIT = 2


# admin wildcard; mirrors the value declared with the permission groups above
ADMIN_PERMISSIONS = 'ADMIN_PERMISSIONS'


def check_permission(permissions, needed_permission_group, *needed_permissions):
    """Check a '|'-separated permission string against the required rights.

    Returns 0 when the caller holds the admin wildcard, 1 when the needed
    group or any individual needed permission is present, and raises
    PpcException(INSUFFICIENT_AUTHORITY) otherwise.
    """
    granted = permissions.split('|')
    if ADMIN_PERMISSIONS in granted:
        return 0
    if needed_permission_group in granted:
        return 1
    if any(needed in granted for needed in needed_permissions):
        return 1
    raise PpcException(PpcErrorCode.INSUFFICIENT_AUTHORITY.get_code(
    ), PpcErrorCode.INSUFFICIENT_AUTHORITY.get_msg())


def plot_two_class_graph(job_context, y_scores=None, y_true=None):
    """Render ROC, KS, Precision/Recall and Accuracy curves for a binary
    classifier and save each figure to the paths held by ``job_context``.

    Returns the (ks_value, auc_value) pair.
    """
    y_pred_probs = job_context.y_pred_probs
    y_label_probs = job_context.y_label_probs
    if y_scores:
        y_pred_probs = y_scores
    if y_true:
        y_label_probs = y_true
    plt.rcParams['figure.figsize'] = (12.0, 8.0)

    # ROC curve
    fpr, tpr, thresholds = roc_curve(y_label_probs, y_pred_probs, pos_label=1)
    auc_value = auc(fpr, tpr)
    plt.figure(f'roc-{job_context.job_id}')
    plt.title('ROC Curve')
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.plot(fpr, tpr, label='area = {0:0.5f}'.format(auc_value))
    plt.legend(loc="lower right")
    plt.savefig(job_context.mpc_metric_roc_path, dpi=1000)
    plt.show()

    plt.close('all')
    gc.collect()

    # KS curve
    plt.figure(f'ks-{job_context.job_id}')
    threshold_x = np.sort(thresholds)
    threshold_x[-1] = 1
    ks_value = max(abs(fpr - tpr))
    plt.title('KS Curve')
    plt.xlabel('Threshold')
    plt.plot(threshold_x, tpr, label='True Positive Rate')
    plt.plot(threshold_x, fpr, label='False Positive Rate')
    # mark the point of maximal KS
    x_index = np.argwhere(abs(fpr - tpr) == ks_value)[0, 0]
    plt.plot((threshold_x[x_index], threshold_x[x_index]),
             (fpr[x_index], tpr[x_index]),
             label='ks = {:.3f}'.format(ks_value),
             color='r', marker='o', markerfacecolor='r', markersize=5)
    plt.legend(loc="lower right")
    plt.savefig(job_context.mpc_metric_ks_path, dpi=1000)
    plt.show()

    plt.close('all')
    gc.collect()

    # Precision/Recall curve
    plt.figure(f'pr-{job_context.job_id}')
    plt.title('Precision/Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.05)
    precision, recall, thresholds = precision_recall_curve(
        y_label_probs, y_pred_probs)
    plt.plot(recall, precision)
    plt.savefig(job_context.mpc_metric_pr_path, dpi=1000)
    plt.show()

    plt.close('all')
    gc.collect()

    # Accuracy-vs-threshold curve (100 thresholds in [0, 1])
    plt.figure(f'accuracy-{job_context.job_id}')
    thresholds = np.linspace(0, 1, num=100)
    accuracies = []
    for threshold in thresholds:
        predicted_labels = (y_pred_probs >= threshold).astype(int)
        accuracies.append(accuracy_score(y_label_probs, predicted_labels))
    plt.title('Accuracy Curve')
    plt.xlabel('Threshold')
    plt.ylabel('Accuracy')
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.05)
    plt.plot(thresholds, accuracies)
    plt.savefig(job_context.mpc_metric_accuracy_path, dpi=1000)
    plt.show()

    plt.close('all')
    gc.collect()
    return (ks_value, auc_value)
def plot_multi_class_graph(job_context, n_class=None, y_label_value=None, y_pred_value=None):
    """Render ROC, Precision/Recall and confusion-matrix figures for a
    multi-class model, saving to the paths held by ``job_context``."""
    if not n_class:
        n_class = job_context.n_class
    if not y_label_value:
        y_label_value = job_context.y_label_value
    if not y_pred_value:
        y_pred_value = job_context.y_pred_value
    y_label_probs = job_context.y_label_probs
    y_pred_probs = job_context.y_pred_probs

    class_names = [x for x in range(n_class)]
    plt.rcParams['figure.figsize'] = (12.0, 8.0)
    plt.figure(f'roc-{job_context.job_id}')
    multi_class_roc(job_context, plt, y_label_probs, y_pred_probs, class_names)
    plt.figure(f'pr-{job_context.job_id}')
    multi_class_precision_recall(
        job_context, plt, y_label_probs, y_pred_probs, class_names)
    plt.figure(f'cm-{job_context.job_id}')
    multi_class_confusion_matrix(
        job_context, plt, y_label_value, y_pred_value, class_names)


def sigmoid(x):
    """Element-wise logistic function."""
    return 1 / (1 + np.exp(-x))


def softmax(x):
    """Row-wise, numerically-stable softmax.

    Bug fix: the original used in-place ``x -= np.max(...)``, silently
    clobbering the caller's array; this version leaves the input untouched.
    """
    shifted = x - np.max(x, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)


# Converts [[0.3, 0.6, 0.1], [0.1, 0.2, 0.7], [0.8, 0.1, 0.1]] to [1, 2, 0]
def get_value_from_probs(probs):
    """Pick the argmax class index per probability row."""
    return [np.argmax(prob) for prob in probs]


# Converts [1, 2, 0] to [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
def get_probs_from_value(values, n_classes):
    """One-hot encode a list of class indices."""
    probs = np.zeros((len(values), n_classes), int)
    for p, v in zip(probs, values):
        p[v] = 1

    return probs


# This need one-hot encoding
def multi_class_roc(job_context, plt, y_label, y_pred, class_names):
    """Plot per-class plus micro/macro-averaged ROC curves."""
    n_classes = len(class_names)

    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_label[:, i], y_pred[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # micro average: flatten all classes into one curve
    fpr["micro"], tpr["micro"], _ = roc_curve(y_label.ravel(), y_pred.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # macro average: interpolate every per-class curve on a common fpr grid
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    lw = 2
    plt.plot(fpr["micro"], tpr["micro"],
             label='Micro-averaging (area = {0:0.5f})'.format(roc_auc["micro"]),
             linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],
             label='Macro-averaging (area = {0:0.5f})'.format(roc_auc["macro"]),
             linestyle=':', linewidth=4)
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], lw=lw,
                 label='Class {0} (area = {1:0.5f})'.format(
                     class_names[i], roc_auc[i]))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('Multi-class ROC Curve')
    plt.legend(loc="lower right")
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.savefig(job_context.mpc_metric_roc_path, dpi=1000)


# This need one-hot encoding
def multi_class_precision_recall(job_context, plt, y_label, y_pred, class_names):
    """Plot a Precision/Recall curve per class."""
    n_classes = len(class_names)
    precision = dict()
    recall = dict()
    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(
            y_label[:, i], y_pred[:, i])

    lw = 2
    for i in range(n_classes):
        plt.plot(recall[i], precision[i], lw=lw,
                 label='Class {0}'.format(class_names[i]))
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Multi-class Precision/Recall Curve')
    plt.legend(loc="lower right")
    plt.savefig(job_context.mpc_metric_pr_path, dpi=1000)


# This need value encoding
def multi_class_confusion_matrix(job_context, plt, y_label, y_pred, class_names):
    """Plot the confusion matrix as an annotated heatmap."""
    cm = confusion_matrix(y_label, y_pred)
    conf_matrix = pd.DataFrame(cm, index=class_names, columns=class_names)

    sns.heatmap(conf_matrix, annot=True, annot_kws={
                "size": 19}, cmap="Blues", fmt='d')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.title('Confusion Matrix')
    plt.savefig(job_context.mpc_metric_confusion_matrix_path, dpi=1000)
@unique
class ModelAlgorithmType(Enum):
    """Supported PPC model algorithm families."""
    HeteroLR = 1
    HomoLR = 2
    HeteroNN = 3
    HomoNN = 4
    HeteroXGB = 5


@unique
class OptimizerType(Enum):
    """Optimizers supported by the generated MPC training code."""
    sgd = 1
    adam = 2


algorithm_types = [ModelAlgorithmType.HeteroLR.name, ModelAlgorithmType.HomoLR.name, ModelAlgorithmType.HeteroNN.name,
                   ModelAlgorithmType.HomoNN.name, ModelAlgorithmType.HeteroXGB.name]

optimizer_types = [OptimizerType.sgd.name, OptimizerType.adam.name]

# fallbacks used by set_parameters when the config supplies non-positive values
default_epochs = 10
default_threads = 8


FILE_PATH = os.path.abspath(__file__)
CURRENT_PATH = os.path.abspath(os.path.dirname(FILE_PATH) + os.path.sep + ".")


def get_dir():
    """Return the directory holding the .mpc code templates."""
    return f'{CURRENT_PATH}{os.sep}..{os.sep}ppc_model_template{os.sep}'


def parse_read_hetero_dataset_loop(participants):
    """Build cumulative feature-count range expressions per participant.

    Returns (participants - 1, loop_start, loop_end) where loop_start[i] /
    loop_end[i] are the textual expressions delimiting participant i's
    feature columns.
    """
    loop_start = []
    loop_end = []
    running = ''
    for idx in range(participants):
        separator = '' if idx == 0 else ' + '
        running = f'{running}{separator}source{idx}_feature_count'
        loop_start.append(running)
        loop_end.append(f'{running} + source{idx + 1}_feature_count')
    last = participants - 1
    return last, loop_start[0:last], loop_end[0:last]


def parse_read_homo_dataset_loop(participants):
    """Same as parse_read_hetero_dataset_loop, but over record counts."""
    loop_start = []
    loop_end = []
    running = ''
    for idx in range(participants):
        separator = '' if idx == 0 else ' + '
        running = f'{running}{separator}source{idx}_record_count'
        loop_start.append(running)
        loop_end.append(f'{running} + source{idx + 1}_record_count')
    last = participants - 1
    return last, loop_start[0:last], loop_end[0:last]


def insert_train_record_count(layer, index, record_type):
    """Inject the record-count symbol into one layer declaration string.

    The first Dense layer additionally receives total_feature_count; Conv2d
    layers get the record count injected into their shape list as well.
    """
    splitter = '(' if 'Dense' in layer else '(['
    head, tail = layer.split(splitter, 1)
    if index == 0 and 'Dense' in layer:
        rebuilt = f'{head}({record_type}, total_feature_count, {tail}'
    elif 'Dense' in layer:
        rebuilt = f'{head}({record_type}, {tail}'
    else:
        rebuilt = f'{head}([{record_type}, {tail}'
    if 'Conv2d' in rebuilt:
        if ', [' in rebuilt:
            left, right = rebuilt.split(', [', 1)
            rebuilt = f'{left}, [{record_type}, {right}'
        if ',[' in rebuilt:
            left, right = rebuilt.split(',[', 1)
            rebuilt = f'{left},[{record_type}, {right}'
    return rebuilt
def set_nn_layers(mpc_train_algorithm, layers, record_type):
    """Append the NN layer declarations plus matching test tensors.

    ``layers`` are the user-provided layer strings (may be empty, in which
    case a default two-layer dense network is emitted); ``record_type`` is
    the symbol injected as the batch dimension.
    """
    mpc_train_algorithm = f"{mpc_train_algorithm}\n"
    if layers:
        layers_str = 'layers = ['
        for i in range(len(layers)):
            new_layer = insert_train_record_count(layers[i], i, record_type)
            layers_str = f"{layers_str}{new_layer},\n"
        # the last user layer determines binary vs multi-class output
        n_class = parse_n_class(layers[-1])
        mpc_train_algorithm = f"{mpc_train_algorithm}{layers_str}"

        if n_class == 1:
            mpc_train_algorithm = f"{mpc_train_algorithm}ml.Output({record_type}, approx=3)]\n\n"
            mpc_train_algorithm = f"{mpc_train_algorithm}test_Y = pfix.Array(test_record_count)\n"
            mpc_train_algorithm = f"{mpc_train_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n"
        elif n_class > 1:
            mpc_train_algorithm = f"{mpc_train_algorithm}ml.MultiOutput({record_type}, {n_class})]\n\n"
            mpc_train_algorithm = f"{mpc_train_algorithm}total_class_count = {n_class}\n"
            mpc_train_algorithm = f"{mpc_train_algorithm}test_Y = pint.Matrix(test_record_count, total_class_count)\n"
            mpc_train_algorithm = f"{mpc_train_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n"
        else:
            raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR.get_code(),
                               PpcErrorCode.ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR.get_msg())
    else:
        # no user layers: default 128-unit relu dense net with single output
        mpc_train_algorithm = f"{mpc_train_algorithm}test_Y = pfix.Array(test_record_count)\n"
        mpc_train_algorithm = f"{mpc_train_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n"
        mpc_train_algorithm = f"{mpc_train_algorithm}" \
                              f"layers = [pDense({record_type}, total_feature_count, 128, " \
                              f"activation='relu'), pDense({record_type}, 128, 1), " \
                              f"ml.Output({record_type}, approx=3)]\n"
    mpc_train_algorithm = f"{mpc_train_algorithm}\n"
    return mpc_train_algorithm


def set_logreg_train_layers(mpc_algorithm):
    """Append the fixed logistic-regression training layers + test tensors."""
    mpc_algorithm = f"{mpc_algorithm}\n"
    mpc_algorithm = f"{mpc_algorithm}layers = [pDense(train_record_count, total_feature_count, 1), ml.Output(train_record_count, approx=3)]\n\n"
    mpc_algorithm = f"{mpc_algorithm}test_Y = pfix.Array(test_record_count)\n"
    mpc_algorithm = f"{mpc_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n"
    return mpc_algorithm


def generate_set_logreg_predict_layers(mpc_algorithm):
    """Append the fixed logistic-regression prediction layers."""
    mpc_algorithm = f"{mpc_algorithm}\n"
    mpc_algorithm = f"{mpc_algorithm}layers = [pDense(test_record_count, total_feature_count, 1), ml.Output(test_record_count, approx=3)]\n"
    return mpc_algorithm


def generate_homo_predict_static_template(mpc_predict_algorithm):
    """Append the static homo-prediction template read from disk."""
    homo_nn_static_template = utils.read_content_from_file(
        f'{get_dir()}homo_predict_static_template.mpc')
    return f'{mpc_predict_algorithm}\n{homo_nn_static_template}'


def set_hetero_train_static_template(model_config_dict, mpc_train_algorithm):
    """Append the hetero training template for the configured optimizer,
    patching in the configured learning rate.

    Raises:
        PpcException: when the optimizer is neither 'sgd' nor 'adam'.
            (Bug fix: the original fell through with an UnboundLocalError.)
    """
    optimizer = model_config_dict['optimizer']
    learning_rate = model_config_dict['learning_rate']
    if optimizer == OptimizerType.sgd.name:
        template = utils.read_content_from_file(
            f'{get_dir()}hetero_train_sgd_static_template.mpc')
        template = template.replace('gamma = MemValue(cfix(.1))',
                                    f'gamma = MemValue(cfix({learning_rate}))')
    elif optimizer == OptimizerType.adam.name:
        template = utils.read_content_from_file(
            f'{get_dir()}hetero_train_adam_static_template.mpc')
        template = template.replace('gamma = MemValue(cfix(.001))',
                                    f'gamma = MemValue(cfix({learning_rate}))')
    else:
        raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_OPTIMIZER_ERROR.get_code(),
                           PpcErrorCode.ALGORITHM_PPC_MODEL_OPTIMIZER_ERROR.get_msg())
    return f'{mpc_train_algorithm}\n{template}'


def generate_hetero_predict_static_template(mpc_predict_algorithm):
    """Append the static hetero-prediction template read from disk."""
    hetero_logreg_static_template = utils.read_content_from_file(
        f'{get_dir()}hetero_predict_static_template.mpc')
    return f'{mpc_predict_algorithm}\n{hetero_logreg_static_template}'


def set_homo_train_static_template(model_config_dict, mpc_train_algorithm):
    """Homo counterpart of set_hetero_train_static_template.

    Raises:
        PpcException: when the optimizer is neither 'sgd' nor 'adam'.
            (Bug fix: the original fell through with an UnboundLocalError.)
    """
    optimizer = model_config_dict['optimizer']
    learning_rate = model_config_dict['learning_rate']
    if optimizer == OptimizerType.sgd.name:
        template = utils.read_content_from_file(
            f'{get_dir()}homo_train_sgd_static_template.mpc')
        template = template.replace('gamma = MemValue(cfix(.1))',
                                    f'gamma = MemValue(cfix({learning_rate}))')
    elif optimizer == OptimizerType.adam.name:
        template = utils.read_content_from_file(
            f'{get_dir()}homo_train_adam_static_template.mpc')
        template = template.replace('gamma = MemValue(cfix(.001))',
                                    f'gamma = MemValue(cfix({learning_rate}))')
    else:
        raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_OPTIMIZER_ERROR.get_code(),
                           PpcErrorCode.ALGORITHM_PPC_MODEL_OPTIMIZER_ERROR.get_msg())
    return f'{mpc_train_algorithm}\n{template}'


def generate_set_common_code(is_psi, participants):
    """Emit the shared header of every generated .mpc program."""
    mpc_algorithm = '#This file is auto generated by ppc. DO NOT EDIT!\n\n'
    if is_psi:
        mpc_algorithm = f'{mpc_algorithm}#PSI_OPTION=True'
    else:
        mpc_algorithm = f'{mpc_algorithm}#PSI_OPTION=False'
    mpc_algorithm = f'{mpc_algorithm}\n'
    mpc_algorithm = f'{mpc_algorithm}from ppc import *\n'
    mpc_algorithm = f'{mpc_algorithm}from Compiler import config\n'
    mpc_algorithm = f'{mpc_algorithm}import sys\n'
    mpc_algorithm = f'{mpc_algorithm}program.options_from_args()\n\n'
    mpc_algorithm = f'{mpc_algorithm}program.use_trunc_pr = True\n'
    mpc_algorithm = f'{mpc_algorithm}program.use_split(3)\n\n'
    for i in range(participants):
        mpc_algorithm = f'{mpc_algorithm}SOURCE{i}={i}\n'
    return mpc_algorithm
def set_hetero_feature_count(mpc_algorithm, participants):
    """Append per-source feature-count declarations plus their total."""
    declarations = []
    terms = []
    for idx in range(participants):
        declarations.append(
            f'source{idx}_feature_count=$(source{idx}_feature_count)\n')
        terms.append(f'source{idx}_feature_count')
    return f"{mpc_algorithm}{''.join(declarations)}total_feature_count={'+'.join(terms)}\n"


def set_homo_train_record_count(mpc_train_algorithm, participants):
    """Append per-source record-count declarations, their total, and the
    train/test split counts."""
    declarations = []
    terms = []
    for idx in range(participants):
        declarations.append(
            f'source{idx}_record_count=$(source{idx}_record_count)\n')
        terms.append(f'source{idx}_record_count')
    return (f"{mpc_train_algorithm}{''.join(declarations)}"
            f"total_record_count={'+'.join(terms)}\n"
            'train_record_count=$(train_record_count)\n'
            'test_record_count=$(test_record_count)\n')


def generate_set_homo_predict_record_count(mpc_predict_algorithm):
    """Append the counts and test matrix used by homo prediction."""
    return (f'{mpc_predict_algorithm}'
            'total_feature_count=$(total_feature_count)\n'
            'test_record_count=$(test_record_count)\n'
            'test_X = pfix.Matrix(test_record_count, total_feature_count)\n')


def set_hetero_train_record_count(mpc_train_algorithm):
    """Append the total/train/test record-count placeholders."""
    return (f'{mpc_train_algorithm}\n'
            'total_record_count=$(total_record_count)\n'
            'train_record_count=$(train_record_count)\n'
            'test_record_count=$(test_record_count)\n'
            '\n')


def generate_set_hetero_predict_record_count(mpc_predict_algorithm):
    """Append the test record count and test matrix for hetero prediction."""
    return (f'{mpc_predict_algorithm}\n'
            'test_record_count=$(test_record_count)\n'
            'test_X = pfix.Matrix(test_record_count, total_feature_count)\n')


def set_homo_feature_count(mpc_train_algorithm):
    """Append the total feature-count placeholder."""
    return (f'{mpc_train_algorithm}\n'
            'total_feature_count=$(total_feature_count)\n'
            '\n')


def read_hetero_train_dataset(mpc_train_algorithm, participants):
    """Emit the function reading each participant's hetero columns."""
    pieces = ['def read_hetero_dataset():\n',
              '\tprint(f"train_record_count:{train_record_count}")\n',
              '\tprint(f"test_record_count:{test_record_count}")\n',
              '\tdo_read_hetero_y_part(SOURCE0)\n']
    for idx in range(participants):
        pieces.append(
            f'\tdo_read_hetero_x_part(SOURCE{idx}, source{idx}_feature_count)\n')
    return f"{mpc_train_algorithm}{''.join(pieces)}"


def generate_read_hetero_predict_dataset(mpc_predict_algorithm, participants):
    """Emit the function reading each participant's hetero test columns."""
    pieces = ['\n', 'def read_hetero_test_dataset():\n']
    for idx in range(participants):
        pieces.append(
            f'\tdo_read_hetero_x_part(SOURCE{idx}, source{idx}_feature_count)\n')
    return f"{mpc_predict_algorithm}{''.join(pieces)}"


def read_homo_dataset(mpc_train_algorithm, participants):
    """Emit the function reading each participant's homo rows."""
    pieces = ['def read_homo_dataset():\n',
              '\tprint(f"total_feature_count:{total_feature_count}")\n',
              '\tprint(f"train_record_count:{train_record_count}")\n',
              '\tprint(f"test_record_count:{test_record_count}")\n']
    for idx in range(participants):
        pieces.append(
            f'\tdo_read_homo_dataset(SOURCE{idx}, source{idx}_record_count)\n')
    return f"{mpc_train_algorithm}{''.join(pieces)}"


def set_parameters(model_config_dict, mpc_train_algorithm):
    """Append epochs / batch_size / threads, applying defaults when the
    configured values are non-positive."""
    mpc_train_algorithm = f'{mpc_train_algorithm}\n'
    epochs = model_config_dict['epochs']
    batch_size = model_config_dict['batch_size']
    threads = model_config_dict['threads']
    if epochs <= 0:
        mpc_train_algorithm = f'{mpc_train_algorithm}epochs={default_epochs}\n'
    else:
        mpc_train_algorithm = f'{mpc_train_algorithm}epochs={epochs}\n'
    if int(batch_size) <= 0:
        # non-positive means "full batch"
        mpc_train_algorithm = f'{mpc_train_algorithm}batch_size=train_record_count\n'
    else:
        mpc_train_algorithm = f'{mpc_train_algorithm}user_batch_size={batch_size}\n'
        mpc_train_algorithm = f'{mpc_train_algorithm}batch_size=min(user_batch_size, min(128, train_record_count))\n'
    if threads <= 0:
        mpc_train_algorithm = f'{mpc_train_algorithm}threads={default_threads}\n'
    else:
        mpc_train_algorithm = f'{mpc_train_algorithm}threads={threads}\n'
    return mpc_train_algorithm
def generate_mpc_train_algorithm(model_config_dict, algorithm_name, is_psi):
    """Assemble the full MPC training program text for the given algorithm.

    Dispatches on the algorithm family (hetero vs homo, LR vs NN) and stitches
    together the shared header, count declarations, dataset readers, layer
    declarations and the optimizer-specific static template.
    """
    participants = model_config_dict['participants']
    mpc_train_algorithm = generate_set_common_code(is_psi, participants)
    hetero_names = (ModelAlgorithmType.HeteroLR.name,
                    ModelAlgorithmType.HeteroNN.name)
    homo_names = (ModelAlgorithmType.HomoLR.name,
                  ModelAlgorithmType.HomoNN.name)
    if algorithm_name in hetero_names:
        mpc_train_algorithm = set_hetero_feature_count(
            mpc_train_algorithm, participants)
        mpc_train_algorithm = set_hetero_train_record_count(
            mpc_train_algorithm)
        mpc_train_algorithm = read_hetero_train_dataset(
            mpc_train_algorithm, participants)
        mpc_train_algorithm = set_parameters(
            model_config_dict, mpc_train_algorithm)
        if algorithm_name == ModelAlgorithmType.HeteroLR.name:
            # LR has a fixed layer stack; user-supplied layers are rejected
            if 'layers' in model_config_dict.keys():
                raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_code(),
                                   PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_msg())
            mpc_train_algorithm = set_logreg_train_layers(mpc_train_algorithm)
            mpc_train_algorithm = set_hetero_train_static_template(
                model_config_dict, mpc_train_algorithm)
        else:
            layers = model_config_dict.get('layers', [])
            # a leading Conv2d cannot be split across hetero parties
            if layers and 'Conv2d' in layers[0]:
                raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR2.get_code(),
                                   PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR2.get_msg())
            mpc_train_algorithm = set_nn_layers(
                mpc_train_algorithm, layers, 'train_record_count')
            mpc_train_algorithm = set_hetero_train_static_template(
                model_config_dict, mpc_train_algorithm)
    elif algorithm_name in homo_names:
        mpc_train_algorithm = set_homo_train_record_count(
            mpc_train_algorithm, participants)
        mpc_train_algorithm = set_parameters(
            model_config_dict, mpc_train_algorithm)
        mpc_train_algorithm = set_homo_feature_count(mpc_train_algorithm)
        mpc_train_algorithm = read_homo_dataset(
            mpc_train_algorithm, participants)
        if algorithm_name == ModelAlgorithmType.HomoLR.name:
            if 'layers' in model_config_dict.keys():
                raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_code(),
                                   PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_msg())
            mpc_train_algorithm = set_logreg_train_layers(mpc_train_algorithm)
            mpc_train_algorithm = set_homo_train_static_template(
                model_config_dict, mpc_train_algorithm)
        else:
            layers = model_config_dict.get('layers', [])
            mpc_train_algorithm = set_nn_layers(
                mpc_train_algorithm, layers, 'train_record_count')
            mpc_train_algorithm = set_homo_train_static_template(
                model_config_dict, mpc_train_algorithm)
    # NOTE(review): HeteroXGB falls through and returns only the common
    # header, as in the original — presumably handled elsewhere; confirm.
    return mpc_train_algorithm


def generate_mpc_predict_algorithm(algorithm_name, layers, participants, is_psi):
    """Assemble the full MPC prediction program text for the given algorithm.

    Mirrors generate_mpc_train_algorithm but emits the prediction variants
    of each section (test tensors, prediction templates).
    """
    mpc_predict_algorithm = generate_set_common_code(is_psi, participants)
    hetero_names = (ModelAlgorithmType.HeteroLR.name,
                    ModelAlgorithmType.HeteroNN.name)
    homo_names = (ModelAlgorithmType.HomoLR.name,
                  ModelAlgorithmType.HomoNN.name)
    if algorithm_name in hetero_names:
        mpc_predict_algorithm = set_hetero_feature_count(
            mpc_predict_algorithm, participants)
        mpc_predict_algorithm = generate_set_hetero_predict_record_count(
            mpc_predict_algorithm)
        mpc_predict_algorithm = generate_read_hetero_predict_dataset(
            mpc_predict_algorithm, participants)
        if algorithm_name == ModelAlgorithmType.HeteroLR.name:
            mpc_predict_algorithm = generate_set_logreg_predict_layers(
                mpc_predict_algorithm)
        else:
            mpc_predict_algorithm = set_nn_layers(
                mpc_predict_algorithm, layers, 'test_record_count')
        mpc_predict_algorithm = generate_hetero_predict_static_template(
            mpc_predict_algorithm)
    elif algorithm_name in homo_names:
        mpc_predict_algorithm = generate_set_homo_predict_record_count(
            mpc_predict_algorithm)
        if algorithm_name == ModelAlgorithmType.HomoLR.name:
            mpc_predict_algorithm = generate_set_logreg_predict_layers(
                mpc_predict_algorithm)
        else:
            mpc_predict_algorithm = set_nn_layers(
                mpc_predict_algorithm, layers, 'test_record_count')
        mpc_predict_algorithm = generate_homo_predict_static_template(
            mpc_predict_algorithm)

    return mpc_predict_algorithm
PpcErrorCode.ALGORITHM_PPC_MODEL_LEARNING_RATE_ERROR.get_msg()) + # except BaseException as e: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LEARNING_RATE_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_LEARNING_RATE_ERROR.get_msg()) + + # try: + # num_trees = int(model_config_dict['num_trees']) + # model_config_dict['num_trees'] = num_trees + # if not 0 < num_trees <= 300: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_TREES_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_TREES_ERROR.get_msg()) + # except BaseException as e: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_TREES_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_TREES_ERROR.get_msg()) + + # try: + # max_depth = int(model_config_dict['max_depth']) + # model_config_dict['max_depth'] = max_depth + # if not 0 < max_depth <= 10: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_DEPTH_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_DEPTH_ERROR.get_msg()) + # except BaseException as e: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_DEPTH_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_DEPTH_ERROR.get_msg()) + + # try: + # threads = int(model_config_dict['threads']) + # model_config_dict['threads'] = threads + # if not (0 < threads <= 8): + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_msg()) + # except BaseException as e: + # raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_code(), + # PpcErrorCode.ALGORITHM_PPC_MODEL_THREADS_ERROR.get_msg()) \ No newline at end of file diff --git a/python/ppc_common/ppc_utils/ppc_model_config_parser_proxy.py b/python/ppc_common/ppc_utils/ppc_model_config_parser_proxy.py new file mode 100644 index 00000000..b5aa49e9 --- /dev/null +++ b/python/ppc_common/ppc_utils/ppc_model_config_parser_proxy.py @@ -0,0 +1,491 @@ +# coding:utf-8 +import os +import unittest +from enum import unique, Enum + 
def parse_read_hetero_dataset_loop(participants):
    """Build the textual column-offset expressions used when reading each
    participant's slice of the hetero dataset.

    For participant i (i >= 1) the feature columns span
    [sum(feature_count[0..i]), sum(feature_count[0..i+1])).

    :param participants: number of data providers
    :return: (loop_count, loop_start, loop_end) where loop_count is
        participants - 1 and the two lists hold the start/end offset
        expressions for participants 1 .. participants-1
    """
    loop_start = []
    loop_end = []
    # Running prefix-sum expression, e.g.
    # 'source0_feature_count + source1_feature_count'.
    start = ''
    # The original implementation had three branches that all executed the
    # same statements for i > 0, and computed one extra (discarded) entry
    # that referenced the non-existent source{participants}_feature_count.
    # This loop produces exactly the participants-1 entries that were kept.
    for i in range(participants - 1):
        if i == 0:
            start = f'source{i}_feature_count'
        else:
            start = f'{start} + source{i}_feature_count'
        end = f'{start} + source{i + 1}_feature_count'
        loop_start.append(start)
        loop_end.append(end)
    return participants - 1, loop_start, loop_end
def insert_train_record_count(layer, index, record_type):
    """Inject the record-count symbol into one user-written layer definition.

    Dense layers are split on '('; shape-list style layers on '(['.  The
    first Dense layer (index 0) additionally receives total_feature_count.
    Conv2d layers get the record count injected into their second shape
    list as well (both ', [' and ',[' spellings are handled).
    """
    is_dense = 'Dense' in layer
    parts = layer.split('(') if is_dense else layer.split('([')
    if is_dense:
        extra = 'total_feature_count, ' if index == 0 else ''
        rewritten = f'{parts[0]}({record_type}, {extra}{parts[1]}'
    else:
        rewritten = f'{parts[0]}([{record_type}, {parts[1]}'
    if 'Conv2d' in rewritten:
        # Rewrite the second shape list too, checking both separator
        # spellings in sequence (the second check sees the first's result).
        for sep in (', [', ',['):
            if sep in rewritten:
                pieces = rewritten.split(sep)
                rewritten = f'{pieces[0]}{sep}{record_type}, {pieces[1]}'
    return rewritten
mpc_algorithm = f"{mpc_algorithm}ml.MultiOutput({record_type}, {n_class})]\n\n" + if ppc_model_type == PPCModleType.Train: + mpc_algorithm = f"{mpc_algorithm}total_class_count = {n_class}\n" + mpc_algorithm = f"{mpc_algorithm}train_Y = layers[-1].Y\n" + mpc_algorithm = f"{mpc_algorithm}train_X = layers[0].X\n" + mpc_algorithm = f"{mpc_algorithm}test_Y = pint.Matrix(test_record_count, total_class_count)\n" + mpc_algorithm = f"{mpc_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n" + else: + raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR.get_code(), + PpcErrorCode.ALGORITHM_PPC_MODEL_OUTPUT_NUMBER_ERROR.get_msg()) + + if ppc_model_type == PPCModleType.Train and model_algorithm_type == ModelAlgorithmType.HeteroNN: + mpc_algorithm = read_hetero_train_dataset( + mpc_algorithm, participants, n_class) + if ppc_model_type == PPCModleType.Train and model_algorithm_type == ModelAlgorithmType.HomoNN: + mpc_algorithm = read_homo_train_dataset( + mpc_algorithm, participants, n_class) + else: + mpc_algorithm = set_nn_train_output(mpc_algorithm, 1) + mpc_algorithm = f"{mpc_algorithm}" \ + f"layers = [pDense({record_type}, total_feature_count, 128, " \ + f"activation='relu'), pDense({record_type}, 128, 1), " \ + f"ml.Output({record_type}, approx=3)]\n" + if ppc_model_type == PPCModleType.Train: + mpc_algorithm = f"{mpc_algorithm}train_Y = layers[-1].Y\n" + mpc_algorithm = f"{mpc_algorithm}train_X = layers[0].X\n" + mpc_algorithm = f"{mpc_algorithm}test_Y = pint.Array(test_record_count)\n" + mpc_algorithm = f"{mpc_algorithm}test_X = pfix.Matrix(test_record_count, total_feature_count)\n" + mpc_algorithm = f"{mpc_algorithm}\n" + mpc_algorithm = read_homo_train_dataset(mpc_algorithm, participants) + + return mpc_algorithm + + +def set_logreg_train_layers(mpc_algorithm): + mpc_algorithm = f"{mpc_algorithm}\n" + mpc_algorithm = f"{mpc_algorithm}layers = [pDense(train_record_count, total_feature_count, 1), ml.Output(train_record_count, 
def generate_set_logreg_predict_layers(mpc_algorithm):
    """Append the logistic-regression prediction layer and result-matrix
    declarations to the generated MPC text."""
    snippet = (
        '\n'
        'layers = [pDense(test_record_count, total_feature_count, 1), '
        'ml.Output(test_record_count, approx=3)]\n\n'
        'result_columns = 1\n'
        'result_matrix = Matrix(test_record_count, result_columns, pfix)\n\n'
    )
    return mpc_algorithm + snippet
def set_homo_train_static_template(model_config_dict, mpc_train_algorithm):
    """Append the static homo-training template (sgd or adam variant) with
    the configured learning rate substituted for the template default.

    :param model_config_dict: must contain 'optimizer' and 'learning_rate'
    :param mpc_train_algorithm: the MPC source text generated so far
    :return: the MPC source text with the template appended
    :raises PpcException: if 'optimizer' is neither 'sgd' nor 'adam'
        (the original code left the template variable unbound and died
        with an UnboundLocalError instead).
    """
    optimizer = model_config_dict['optimizer']
    learning_rate = model_config_dict['learning_rate']
    if optimizer == OptimizerType.sgd.name:
        template_file = 'homo_train_sgd_static_template_proxy.mpc'
        default_gamma = 'gamma = MemValue(cfix(.1))'
    elif optimizer == OptimizerType.adam.name:
        template_file = 'homo_train_adam_static_template_proxy.mpc'
        default_gamma = 'gamma = MemValue(cfix(.001))'
    else:
        # Fail fast with a meaningful error instead of UnboundLocalError.
        raise PpcException(PpcErrorCode.ALGORITHM_PPC_CONFIG_ERROR.get_code(),
                           PpcErrorCode.ALGORITHM_PPC_CONFIG_ERROR.get_msg())
    homo_train_static_template = utils.read_content_from_file(
        f'{get_dir()}{template_file}')
    homo_train_static_template = homo_train_static_template.replace(
        default_gamma, f'gamma = MemValue(cfix({learning_rate}))')
    return f'{mpc_train_algorithm}\n{homo_train_static_template}'
def set_hetero_feature_count(mpc_algorithm, participants):
    """Append the per-source feature-count placeholders and the summed
    total_feature_count expression."""
    declarations = ''.join(
        f'source{idx}_feature_count=$(source{idx}_feature_count)\n'
        for idx in range(participants))
    total_expr = '+'.join(
        f'source{idx}_feature_count' for idx in range(participants))
    return f'{mpc_algorithm}{declarations}total_feature_count={total_expr}\n'
def generate_set_homo_predict_record_count(mpc_predict_algorithm):
    """Append the homo-prediction record-count declarations and the
    statements that load test_X from the input file."""
    lines = (
        'total_feature_count=$(total_feature_count)\n',
        'test_record_count=$(test_record_count)\n',
        'test_X = pfix.Matrix(test_record_count, total_feature_count)\n\n',
        'file_offset = 0\n',
        'file_offset = test_X.read_from_file(file_offset)\n',
    )
    return mpc_predict_algorithm + ''.join(lines)
def generate_read_hetero_predict_dataset(mpc_predict_algorithm, participants):
    """Append per-participant test matrices, the file-read statements and
    the read_hetero_test_dataset() helper definition."""
    chunks = [mpc_predict_algorithm, '\nfile_offset = 0\n']
    for idx in range(participants):
        chunks.append(
            f'source{idx}_record_x = Matrix(test_record_count, source{idx}_feature_count, pfix)\n'
            f'file_offset = source{idx}_record_x.read_from_file(file_offset)\n\n')
    chunks.append('\ndef read_hetero_test_dataset():\n')
    for idx in range(participants):
        chunks.append(
            f'\tdo_read_hetero_x_part(source{idx}_record_x, source{idx}_feature_count)\n')
    return ''.join(chunks)
def read_homo_train_dataset(mpc_train_algorithm, participants, n_class=1):
    """Append the per-participant homo training label/feature containers,
    the file-read statements and the read_homo_dataset() helper.

    :param n_class: 1 -> label is an Array; >1 -> one-hot label Matrix
    """
    out = [mpc_train_algorithm, '\nfile_offset = 0\n']
    for idx in range(participants):
        if n_class == 1:
            out.append(f'source{idx}_record_y = Array(source{idx}_record_count, pint)\n')
        elif n_class > 1:
            out.append(f'source{idx}_record_y = Matrix(source{idx}_record_count, total_class_count, pint)\n')
        out.append(f'source{idx}_record_x = Matrix(source{idx}_record_count, total_feature_count, pfix)\n')
        out.append(f'file_offset = source{idx}_record_y.read_from_file(file_offset)\n')
        out.append(f'file_offset = source{idx}_record_x.read_from_file(file_offset)\n\n')
    out.append('def read_homo_dataset():\n')
    # Note: single braces below are literal text in the generated program.
    out.append('\tprint(f"total_feature_count:{total_feature_count}")\n')
    out.append('\tprint(f"train_record_count:{train_record_count}")\n')
    out.append('\tprint(f"test_record_count:{test_record_count}")\n')
    for idx in range(participants):
        out.append(f'\tdo_read_homo_dataset(source{idx}_record_y, source{idx}_record_x, source{idx}_record_count)\n')
    out.append('\n')
    return ''.join(out)
def set_nn_train_output(mpc_train_algorithm, n_class):
    """Append the training result-matrix declaration with
    2 * n_class result columns."""
    columns = 2 * n_class
    return (mpc_train_algorithm
            + f'result_columns = {columns}\n'
            + 'result_matrix = Matrix(test_record_count, result_columns, pfix)\n\n')
PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_msg()) + mpc_train_algorithm = read_hetero_train_dataset( + mpc_train_algorithm, participants) + mpc_train_algorithm = set_lr_train_output(mpc_train_algorithm) + mpc_train_algorithm = set_logreg_train_layers(mpc_train_algorithm) + mpc_train_algorithm = set_hetero_train_static_template( + model_config_dict, mpc_train_algorithm) + else: + layers = [] + if 'layers' in model_config_dict.keys(): + layers = model_config_dict['layers'] + if layers and 'Conv2d' in layers[0]: + raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR2.get_code(), + PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR2.get_msg()) + mpc_train_algorithm = set_nn_layers( + mpc_train_algorithm, layers, 'train_record_count', PPCModleType.Train, ModelAlgorithmType.HeteroNN, participants) + mpc_train_algorithm = set_hetero_train_static_template( + model_config_dict, mpc_train_algorithm) + elif algorithm_name == ModelAlgorithmType.HomoLR.name or algorithm_name == ModelAlgorithmType.HomoNN.name: + mpc_train_algorithm = set_homo_train_record_count( + mpc_train_algorithm, participants) + mpc_train_algorithm = set_parameters( + model_config_dict, mpc_train_algorithm) + if algorithm_name == ModelAlgorithmType.HomoLR.name: + if 'layers' in model_config_dict.keys(): + raise PpcException(PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_code(), + PpcErrorCode.ALGORITHM_PPC_MODEL_LAYERS_ERROR.get_msg()) + mpc_train_algorithm = read_homo_train_dataset( + mpc_train_algorithm, participants) + mpc_train_algorithm = set_lr_train_output(mpc_train_algorithm) + mpc_train_algorithm = set_logreg_train_layers(mpc_train_algorithm) + mpc_train_algorithm = set_homo_train_static_template( + model_config_dict, mpc_train_algorithm) + else: + layers = [] + if 'layers' in model_config_dict.keys(): + layers = model_config_dict['layers'] + mpc_train_algorithm = set_nn_layers( + mpc_train_algorithm, layers, 'train_record_count', PPCModleType.Train, ModelAlgorithmType.HomoNN, 
def generate_mpc_predict_algorithm(algorithm_name, layers, participants, is_psi):
    """Generate the complete MPC prediction algorithm source text.

    :param algorithm_name: one of the ModelAlgorithmType member names
    :param layers: user-defined NN layer definitions (unused for LR)
    :param participants: number of data providers
    :param is_psi: whether a PSI stage precedes the computation
    :return: the generated .mpc source as a string
    """
    mpc_predict_algorithm = generate_set_common_code(is_psi, participants)
    if algorithm_name in (ModelAlgorithmType.HeteroLR.name,
                          ModelAlgorithmType.HeteroNN.name):
        mpc_predict_algorithm = set_hetero_feature_count(
            mpc_predict_algorithm, participants)
        mpc_predict_algorithm = generate_set_hetero_predict_record_count(
            mpc_predict_algorithm)
        mpc_predict_algorithm = generate_read_hetero_predict_dataset(
            mpc_predict_algorithm, participants)
        if algorithm_name == ModelAlgorithmType.HeteroLR.name:
            mpc_predict_algorithm = generate_set_logreg_predict_layers(
                mpc_predict_algorithm)
        else:
            mpc_predict_algorithm = set_nn_layers(
                mpc_predict_algorithm, layers, 'test_record_count',
                ppc_model_type=PPCModleType.Predict)
        mpc_predict_algorithm = generate_hetero_predict_static_template(
            mpc_predict_algorithm)
    elif algorithm_name in (ModelAlgorithmType.HomoLR.name,
                            ModelAlgorithmType.HomoNN.name):
        mpc_predict_algorithm = generate_set_homo_predict_record_count(
            mpc_predict_algorithm)
        if algorithm_name == ModelAlgorithmType.HomoLR.name:
            mpc_predict_algorithm = generate_set_logreg_predict_layers(
                mpc_predict_algorithm)
        else:
            # Bug fix: the original call passed `participants` positionally,
            # which bound it to set_nn_layers' model_algorithm_type slot
            # instead of its participants parameter; bind by keyword.
            mpc_predict_algorithm = set_nn_layers(
                mpc_predict_algorithm, layers, 'test_record_count',
                ppc_model_type=PPCModleType.Predict, participants=participants)
        mpc_predict_algorithm = generate_homo_predict_static_template(
            mpc_predict_algorithm)

    return mpc_predict_algorithm
class TestThreadSafeList(unittest.TestCase):
    """Exercises ThreadSafeList from many concurrent threads."""

    def test_multi_thread(self):
        """Run 5 rounds of 20 worker threads, each appending/removing its
        own item via the module-level test() helper."""
        loops = 5
        thread_num = 20
        for j in range(loops):
            workers = []
            for i in range(thread_num):
                thread_name = f"t{i}_{j}"
                # Renamed from 'object' to avoid shadowing the builtin.
                item = f"job_{j}_{i}"
                workers.append(threading.Thread(
                    target=test, name=thread_name,
                    args=(item, thread_name, self)))
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()
class ThreadSafeList:
    """A list protected by a lock so it can be shared across threads."""

    def __init__(self):
        # Internal storage; only touched while holding self.lock.
        self.processing_list = []
        self.lock = threading.Lock()

    def append(self, object):
        """Append an element under the lock."""
        # 'with' releases the lock even if the operation raises,
        # replacing the original try/except-reraise/finally boilerplate.
        with self.lock:
            self.processing_list.append(object)

    def remove(self, object):
        """Remove the first occurrence of the element under the lock."""
        with self.lock:
            self.processing_list.remove(object)

    def contains(self, object):
        """Return True if the element is currently in the list."""
        with self.lock:
            return object in self.processing_list

    def get(self):
        """Return a shallow copy of the current elements.

        Bug fix: the original appended the underlying list itself to the
        copy (so callers received [[...]]) and thereby handed out a
        reference to the internal list that could be mutated without the
        lock; a real copy is returned instead.
        """
        with self.lock:
            return list(self.processing_list)

    def get_element(self, i):
        """Return the element at index i (IndexError if out of range)."""
        with self.lock:
            return self.processing_list[i]
get_element(self, i): + try: + self.lock.acquire() + return self.processing_list[i] + except Exception as e: + raise e + finally: + self.lock.release() diff --git a/python/ppc_common/ppc_utils/utils.py b/python/ppc_common/ppc_utils/utils.py new file mode 100644 index 00000000..d2564cce --- /dev/null +++ b/python/ppc_common/ppc_utils/utils.py @@ -0,0 +1,914 @@ +import base64 +import datetime +import hashlib +import io +import json +import logging +import os +import random +import re +import shutil +import string +import subprocess +import time +import uuid +from concurrent.futures.thread import ThreadPoolExecutor +from enum import Enum, IntEnum, unique + +import jwt +import pandas as pd +from ecdsa import SigningKey, SECP256k1, VerifyingKey +from gmssl import func, sm2, sm3 +from google.protobuf.descriptor import FieldDescriptor +from jsoncomment import JsonComment +from pysmx.SM3 import SM3 + +from ppc_common.ppc_protos.generated.ppc_pb2 import DatasetDetail +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode + +log = logging.getLogger(__name__) + +MAX_SUPPORTED_PARTIES = 5 + +SERVER_RUNNING_STATUS = 0 +DEFAULT_DATASET_RECORD_COUNT = 5 +DEFAULT_PAGE_OFFSET = 0 +DEFAULT_PAGE_SIZE = 5 +MIN_PARTICIPATE_NUMBER = 2 +TWO_PARTY_PSI_PARTICIPATE_NUMBER = 2 +MIN_MULTI_PARTY_PSI_PARTICIPATE_NUMBER = 3 +MPC_MAX_SOURCE_COUNT = 50 +MPC_MAX_FIELD_COUNT = 50 + +BASE_RESPONSE = {'errorCode': PpcErrorCode.SUCCESS.get_code( +), 'message': PpcErrorCode.SUCCESS.get_msg()} +LOG_NAME = 'ppcs-modeladm-scheduler.log' +LOG_CHARACTER_NUMBER = 100000 +CSV_SEP = ',' +BLANK_SEP = ' ' +NORMALIZED_NAMES = 'field{}' + +PPC_RESULT_FIELDS_FLAG = 'result_fields' +PPC_RESULT_VALUES_FLAG = 'result_values' + +ADMIN_USER = 'admin' + +MPC_RECORD_PLACE_HOLDER = '$(ppc_max_record_count)' +MPC_START_PLACE_HOLDER = '${ph_start}' +MPC_END_PLACE_HOLDER = '${ph_end}' + +HOMO_MODEL_ALGORITHM = 'homo' +MPC_TOTAL_RECORD_COUNT_PLACE_HOLDER = '$(total_record_count)' 
+MPC_TRAIN_RECORD_COUNT_PLACE_HOLDER = '$(train_record_count)' +MPC_TEST_RECORD_COUNT_PLACE_HOLDER = '$(test_record_count)' +MPC_TOTAL_FEATURE_COUNT_PLACE_HOLDER = '$(total_feature_count)' + +XGB_TREE_PERFIX = "xgb_tree" +MPC_TRAIN_METRIC_ROC_FILE = "mpc_metric_roc.svg" +MPC_TRAIN_METRIC_KS_FILE = "mpc_metric_ks.svg" +MPC_TRAIN_METRIC_PR_FILE = "mpc_metric_pr.svg" +MPC_TRAIN_METRIC_ACCURACY_FILE = "mpc_metric_accuracy.svg" +MPC_TRAIN_METRIC_KS_TABLE = "mpc_metric_ks.csv" +MPC_TRAIN_SET_METRIC_ROC_FILE = "mpc_train_metric_roc.svg" +MPC_TRAIN_SET_METRIC_KS_FILE = "mpc_train_metric_ks.svg" +MPC_TRAIN_SET_METRIC_PR_FILE = "mpc_train_metric_pr.svg" +MPC_TRAIN_SET_METRIC_ACCURACY_FILE = "mpc_train_metric_accuracy.svg" +MPC_TRAIN_SET_METRIC_KS_TABLE = "mpc_train_metric_ks.csv" +MPC_EVAL_METRIC_ROC_FILE = "mpc_eval_metric_roc.svg" +MPC_EVAL_METRIC_KS_FILE = "mpc_eval_metric_ks.svg" +MPC_EVAL_METRIC_PR_FILE = "mpc_eval_metric_pr.svg" +MPC_EVAL_METRIC_ACCURACY_FILE = "mpc_eval_metric_accuracy.svg" +MPC_EVAL_METRIC_KS_TABLE = "mpc_eval_metric_ks.csv" +MPC_TRAIN_METRIC_CONFUSION_MATRIX_FILE = "mpc_metric_confusion_matrix.svg" +METRICS_OVER_ITERATION_FILE = "metrics_over_iterations.svg" + +# the ks-auc table, e.g.: +# |总样本|正样本| KS | AUC | +# 训练集| 500 | 100 | 0.4161 | 0.7685 | +# 验证集| 154 | 37 | 0.2897 | 0.6376 | +MPC_XGB_EVALUATION_TABLE = "mpc_xgb_evaluation_table.csv" + +# the feature-importance table, e.g.: +# |特征 | score | weight | score_rank| topk | +# | x1 | 0.08 | 1000 | 1 | | +# | x2 | 0.07 | 900 | 2 | | +XGB_FEATURE_IMPORTANCE_TABLE = "xgb_result_feature_importance_table.csv" + +# png, jpeg, pdf etc. 
@unique
class PPCModleType(Enum):
    """Model job type (original "Modle" spelling kept: referenced elsewhere)."""
    Train = 1
    Predict = 2


@unique
class CryptoType(Enum):
    """Crypto suite: ECDSA (secp256k1 + SHA3-256) or Chinese GM (SM2/SM3)."""
    ECDSA = 1
    GM = 2


@unique
class HashType(Enum):
    """Requested digest representation: raw bytes or hex string."""
    BYTES = 1
    HEXSTR = 2


class DeployMode(IntEnum):
    """Deployment topology of the service."""
    PRIVATE_MODE = 1
    SAAS_MODE = 2
    PROXY_MODE = 3


class JobRole(Enum):
    """Role an agency plays in a job."""
    ALGORITHM_PROVIDER = 1
    COMPUTATION_PROVIDER = 2
    DATASET_PROVIDER = 3
    DATA_CONSUMER = 4


class JobStatus(Enum):
    """Lifecycle states of a job."""
    SUCCEED = 1
    FAILED = 2
    RUNNING = 3
    WAITING = 4
    NONE = 5
    RETRY = 6


class AlgorithmType(Enum):
    """Model algorithm phase.

    BUG FIX: the original declared ``Train = "Train",`` — the trailing comma
    made the member value the tuple ``("Train",)`` instead of the string.
    ``.name`` (what other modules compare against) is unaffected.
    """
    Train = "Train"
    Predict = "Predict"


class AlgorithmSubtype(Enum):
    """Supported model subtypes."""
    HeteroLR = 1
    HomoLR = 2
    HeteroNN = 3
    HomoNN = 4
    HeteroXGB = 5


class DataAlgorithmType(Enum):
    """Data-service algorithm families."""
    PIR = 'pir'
    CEM = 'cem'
    ALL = 'all'


class IdPrefixEnum(Enum):
    """Prefixes used when generating entity ids."""
    DATASET = "d-"
    ALGORITHM = "a-"
    JOB = "j-"


class JobCemResultSuffixEnum(Enum):
    """Suffixes distinguishing the two CEM result owners."""
    JOB_ALGORITHM_AGENCY = "-1"
    INITIAL_JOB_AGENCY = "-2"


class OriginAlgorithm(Enum):
    # [id, display name, party count, version]
    PPC_AYS = ["a-1001", "匿踪查询", "1", "1.0"]
    PPC_PSI = ["a-1002", "隐私求交", "2+", "1.0"]


# shared pool backing the @async_fn decorator
executor = ThreadPoolExecutor(max_workers=10)


def async_fn(func):
    """Decorator: run the wrapped callable on the shared thread pool.

    Fire-and-forget — the return value / exception of ``func`` is dropped.
    """
    def wrapper(*args, **kwargs):
        executor.submit(func, *args, **kwargs)

    return wrapper


def json_loads(json_config):
    """Parse JSON that may contain comments.

    :raises PpcException: on any parse failure.
    """
    try:
        json_comment = JsonComment(json)
        return json_comment.loads(json_config)
    except BaseException:
        raise PpcException(PpcErrorCode.ALGORITHM_PPC_CONFIG_ERROR.get_code(),
                           PpcErrorCode.ALGORITHM_PPC_CONFIG_ERROR.get_msg())


def parse_n_class(layer_str):
    """Return the output class count encoded in a layer spec string.

    "[]" means binary classification (1 output); otherwise the last number
    in the string is the output width, e.g. "[32, 5]" -> 5.
    """
    if layer_str == '[]':
        return 1
    return int(re.findall(r"\d+\.?\d*", layer_str)[-1])


def check_ppc_model_algorithm_is_homo(algorithm_name):
    """True when the algorithm name starts with the homo prefix (case-insensitive)."""
    prefix_len = len(HOMO_MODEL_ALGORITHM)
    return algorithm_name[0:prefix_len].lower() == HOMO_MODEL_ALGORITHM


def get_log_file_path(app_dir):
    """Path of the scheduler log file under ``app_dir/logs``."""
    return os.sep.join([app_dir, "logs", LOG_NAME])


def get_log_temp_file_path(app_dir, job_id):
    """Per-job temporary log path under ``app_dir/logs``."""
    return os.sep.join([app_dir, "logs", f"{job_id}.log"])


def df_to_dict(df, orient='split'):
    """Convert a DataFrame to a dict via its JSON form, dropping 'index'.

    Uses ``pop`` so orients that produce no 'index' key no longer raise
    (the original ``del`` raised KeyError for such orients).
    """
    data_json_str = df.to_json(orient=orient, force_ascii=False)
    data_dict = json.loads(data_json_str)
    data_dict.pop('index', None)
    return data_dict


def file_exists(_file):
    """True when ``_file`` exists and is a regular file."""
    # os.path.isfile already implies existence
    return os.path.isfile(_file)


def decode_jwt(token):
    """Decode (without signature verification) a "Bearer <jwt>" header value.

    :param token: authorization header value, e.g. "Bearer eyJ..."
    :return: {"data": payload-or-None, "error": message-or-None}
    """
    result = {"data": None, "error": None}
    try:
        payload = jwt.decode(token.split(" ")[1], options={
            "verify_signature": False})
        result["data"] = payload
    except (IndexError, jwt.DecodeError):
        result["error"] = "JWT token is decoded fail"
    return result


def make_timestamp():
    """Current time in integer milliseconds since the epoch."""
    return int(round(time.time() * 1000))


def encode(data_bytes):
    """base64-encode bytes -> bytes."""
    return base64.b64encode(data_bytes)


def decode(data_str):
    """base64-decode str/bytes -> bytes."""
    return base64.b64decode(data_str)


def pb_to_str(data_pb):
    """Serialize a protobuf message to a base64 text string."""
    return encode(data_pb.SerializeToString()).decode("utf-8")


def str_to_pb(data_pb, data_str):
    """Fill protobuf message ``data_pb`` from a base64 string; returns it."""
    data_pb.ParseFromString(decode(data_str))
    return data_pb


def pb_to_bytes(data_pb):
    """Serialize a protobuf message to raw bytes."""
    return data_pb.SerializeToString()


def bytes_to_pb(data_pb, data_bytes):
    """Parse raw bytes into protobuf message ``data_pb``."""
    return data_pb.ParseFromString(data_bytes)


def str_to_base64str(data_str):
    """utf-8 string -> base64 string."""
    return base64.b64encode(data_str.encode('utf-8')).decode('utf-8')


def base64str_to_str(base64_str):
    """base64 string -> utf-8 string."""
    return base64.b64decode(base64_str.encode('utf-8')).decode('utf-8')
def make_response(code, message, data=None):
    """Standard API response envelope."""
    return {'errorCode': code, 'message': message, 'data': data}


def base64str_to_bytes(base64_str):
    """base64 string -> raw bytes."""
    return base64.b64decode(base64_str.encode('utf-8'))


def make_hash_from_file_path(file_path, crypto_type):
    """Hex digest of the whole file's raw bytes (loads it into memory)."""
    file_data = read_content_from_file_by_binary(file_path)
    return make_hash(file_data, crypto_type, HashType.HEXSTR)


def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
    """Yield successive chunks from an open file handle, always as bytes.

    Generalized: accepts either a text or a binary handle (the original
    only worked for text handles, re-encoding each chunk to utf-8).
    """
    while True:
        chunk = file.read(size)
        if not chunk:
            break
        yield chunk if isinstance(chunk, bytes) else chunk.encode('utf-8')


def make_hash_from_file_path_by_chunks(file_path, crypto_type, block_size=1 << 10):
    """Stream-hash a file; return (hex digest, size in bytes).

    BUG FIX: the original opened the file in text mode and re-encoded it to
    utf-8, so the digest/size differed from the on-disk bytes (e.g. CRLF
    files) and disagreed with make_hash_from_file_path, which hashes raw
    bytes. It also instantiated SM3() unconditionally, pulling in the pysmx
    dependency even for ECDSA — now built lazily per suite.
    """
    if crypto_type == CryptoType.ECDSA:
        m = hashlib.sha3_256()
    else:
        m = SM3()
    dataset_size = 0
    with open(file_path, 'rb') as f:
        for block in read_chunks(f, size=block_size):
            dataset_size += len(block)
            m.update(block)
    return m.hexdigest(), dataset_size


def read_content_from_file_by_binary(file_path):
    """Read the whole file as raw bytes."""
    with open(file_path, 'rb') as file:
        return file.read()


def make_hash(data, crypto_type, hash_type=None):
    """Digest ``data`` (bytes) with the suite's hash function.

    ECDSA -> sha3-256 as hex string or raw bytes per ``hash_type`` (returns
    None for an unrecognized hash_type — kept for compatibility).
    GM -> SM3 hex string; ``hash_type`` is ignored by the GM branch.
    """
    if crypto_type == CryptoType.ECDSA:
        digest = hashlib.sha3_256(data)
        if hash_type == HashType.HEXSTR:
            return digest.hexdigest()
        if hash_type == HashType.BYTES:
            return digest.digest()
    if crypto_type == CryptoType.GM:
        return sm3.sm3_hash(func.bytes_to_list(data))


def pb2dict(obj):
    """Recursively convert a protobuf message into a plain dict.

    Unset/falsy fields are dropped; repeated message fields become lists of
    dicts. Returns None when the message is not fully initialized.
    """
    adict = {}
    if not obj.IsInitialized():
        return None
    for field in obj.DESCRIPTOR.fields:
        if not getattr(obj, field.name):
            continue
        if not field.label == FieldDescriptor.LABEL_REPEATED:
            if not field.type == FieldDescriptor.TYPE_MESSAGE:
                adict[field.name] = getattr(obj, field.name)
            else:
                value = pb2dict(getattr(obj, field.name))
                if value:
                    adict[field.name] = value
        else:
            if field.type == FieldDescriptor.TYPE_MESSAGE:
                adict[field.name] = \
                    [pb2dict(v) for v in getattr(obj, field.name)]
            else:
                adict[field.name] = list(getattr(obj, field.name))
    return adict


def write_content_to_file(content, file_path):
    """Overwrite ``file_path`` with ``content`` (utf-8 text)."""
    with open(file_path, 'w', encoding="utf-8") as file:
        file.write(content)


def write_content_to_file_by_append(content, file_path):
    """Append ``content`` (utf-8 text) to ``file_path``."""
    with open(file_path, 'a', encoding="utf-8") as file:
        file.write(content)


def read_content_from_file(file_path):
    """Read ``file_path`` as utf-8 text."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()


def make_job_event_message(job_id, job_priority, initiator_agency_id, receiver_agency_id, job_algorithm_id,
                           job_dataset):
    """Build the '|'-joined job event payload as utf-8 bytes.

    BUG FIX: the original format string had five placeholders for six
    arguments, so ``job_dataset`` was silently dropped (str.format ignores
    surplus arguments).
    """
    message = '{}|{}|{}|{}|{}|{}'.format(job_id, job_priority, initiator_agency_id,
                                         receiver_agency_id, job_algorithm_id, job_dataset)
    return message.encode('utf-8')


def sign_with_secp256k1(message, private_key_str):
    """secp256k1-sign sha3-256(message).

    ``private_key_str`` may be a hex string or raw key bytes; the signature
    is returned base64-encoded as a utf-8 string.
    """
    if isinstance(private_key_str, str):
        private_key = SigningKey.from_string(
            bytes().fromhex(private_key_str), curve=SECP256k1)
    else:
        private_key = SigningKey.from_string(private_key_str, curve=SECP256k1)
    signature_bytes = private_key.sign(
        make_hash(message, CryptoType.ECDSA, HashType.BYTES))
    return str(encode(signature_bytes), 'utf-8')
curve=SECP256k1) + return verify_key.verify(decode(signature_str), make_hash(message, CryptoType.ECDSA, HashType.BYTES)) + + +def sign_with_sm2(message, private_key_str): + sm2_crypt = sm2.CryptSM2(private_key_str, "") + random_hex_str = func.random_hex(sm2_crypt.para_len) + message_hash = make_hash(message, CryptoType.GM).encode(encoding='utf-8') + signature_str = sm2_crypt.sign(message_hash, random_hex_str) + return signature_str + + +def verify_with_sm2(message, signature_str, public_key_str): + sm2_crypt = sm2.CryptSM2("", public_key_str) + message_hash = make_hash(message, CryptoType.GM).encode(encoding='utf-8') + return sm2_crypt.verify(signature_str, message_hash) + + +def make_signature(message, private_key_str, crypto_type): + if crypto_type == CryptoType.ECDSA: + return sign_with_secp256k1(message, private_key_str) + + if crypto_type == CryptoType.GM: + return sign_with_sm2(message, private_key_str) + + +def verify_signature(message, signature, public_key_str, crypto_type): + if crypto_type == CryptoType.ECDSA: + return verify_with_secp256k1(message, signature, public_key_str) + + if crypto_type == CryptoType.GM: + return verify_with_sm2(message, signature, public_key_str) + + +def exec_bash_command(cmd): + """replace commands.get_status_output + + Arguments: + cmd {[string]} + """ + + get_cmd = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ret = get_cmd.communicate() + out = ret[0] + err = ret[1] + output = '' + if out is not None: + output = output + out.decode('utf-8') + if err is not None: + output = output + err.decode('utf-8') + + return get_cmd.returncode, output + + +def delete_file(path): + """[delete data_dir] + + Arguments: + path {[get_dir]} -- [description] + """ + + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + else: + raise (Exception(' path not exisited ! 
def make_dir(_dir):
    """Create ``_dir`` if missing.

    Generalized from os.mkdir to os.makedirs so missing parent directories
    are created too (backward compatible: the single-level case behaves the
    same).
    """
    if not os.path.exists(_dir):
        os.makedirs(_dir)


def load_credential_from_file(filepath):
    """Read a credential file as bytes, resolved relative to this module."""
    real_path = os.path.join(os.path.dirname(__file__), filepath)
    with open(real_path, 'rb') as f:
        return f.read()


def getstatusoutput(cmd):
    """commands.getstatusoutput replacement: (returncode, stdout+stderr).

    NOTE(review): duplicates exec_bash_command except for the debug log —
    kept for backward compatibility with existing callers.
    """
    proc = subprocess.Popen(
        cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    output = ''
    if out is not None:
        output += out.decode('utf-8')
    if err is not None:
        output += err.decode('utf-8')
    log.debug(' cmd is %s, status is %s, output is %s',
              cmd, str(proc.returncode), output)
    return (proc.returncode, output)


def make_dataset_metadata(data_field_list, dataset_description, dataset_size, row_count, column_count,
                          version_hash, user_name):
    """Build a DatasetDetail protobuf and return it base64-serialized.

    create_time and update_time are both stamped with the current time (ms).
    """
    dataset_metadata = DatasetDetail()
    dataset_metadata.data_field_list = data_field_list
    dataset_metadata.dataset_description = dataset_description
    dataset_metadata.dataset_size = dataset_size
    dataset_metadata.row_count = row_count
    dataset_metadata.column_count = column_count
    dataset_metadata.dataset_hash = version_hash
    dataset_metadata.user_name = user_name

    date_time = make_timestamp()
    dataset_metadata.create_time = date_time
    dataset_metadata.update_time = date_time
    return pb_to_str(dataset_metadata)


def _model_param(label, value_type, value, description, min_value=None, max_value=None):
    """Build one UI parameter descriptor; min/max keys are omitted when absent."""
    param = {'label': label, 'type': value_type, 'value': value}
    if min_value is not None:
        param['min_value'] = min_value
    if max_value is not None:
        param['max_value'] = max_value
    param['description'] = description
    return param


def _bool_param(label, description):
    """A 0/1 switch parameter descriptor."""
    return _model_param(label, 'bool', 0, description, 0, 1)


def default_model_module(subtype):
    """Default tunable-parameter descriptors shown in the UI.

    Decomposed via _model_param/_bool_param: the original was a ~200-line
    literal of structurally identical dicts.

    :param subtype: algorithm subtype name; 'HeteroXGB' appends the XGBoost
        parameter set, anything else appends the generic NN/LR set.
    :return: list of parameter-descriptor dicts
    """
    default_model_module_list = [
        _bool_param('use_psi', '是否进行隐私求交,只能为0或1'),
        _bool_param('fillna', '是否进行缺失值填充,只能为0或1'),
        _model_param('na_select', 'float', 1,
                     '缺失值筛选阈值,取值范围为0~1之间(0表示只要有缺失值就移除,1表示移除全为缺失值的列)', 0, 1),
        _bool_param('filloutlier', '是否进行异常值处理,只能为0或1'),
        _bool_param('normalized', '是否归一化,只能为0或1'),
        _bool_param('standardized', '是否标准化,只能为0或1'),
        _model_param('categorical', 'string', '',
                     '标记所有分类特征字段,格式:"x1,x12" (空代表无分类特征)'),
        _model_param('psi_select_col', 'string', '',
                     'PSI稳定性筛选时间列名(空代表不进行PSI筛选)'),
        _model_param('psi_select_base', 'string', '',
                     'PSI稳定性筛选的基期(空代表不进行PSI筛选)'),
        _model_param('psi_select_thresh', 'float', 0.3,
                     'PSI筛选阈值,取值范围为0~1之间', 0, 1),
        _model_param('psi_select_bins', 'int', 4,
                     '计算PSI时分箱数, 取值范围为3~100之间', 3, 100),
        _model_param('corr_select', 'float', 0,
                     '特征相关性筛选阈值,取值范围为0~1之间(值为0时不进行相关性筛选)', 0, 1),
        _bool_param('use_iv', '是否使用iv进行特征筛选,只能为0或1'),
        _model_param('group_num', 'int', 4,
                     'woe计算分箱数,取值范围为3~100之间的整数', 3, 100),
        _model_param('iv_thresh', 'float', 0.1,
                     'iv特征筛选的阈值,取值范围为0.01~1之间', 0.01, 1),
    ]

    if subtype == 'HeteroXGB':
        default_model_module_list.extend([
            _bool_param('use_goss', '是否采用goss抽样加速训练, 只能为0或1'),
            _model_param('test_dataset_percentage', 'float', 0.3,
                         '测试集比例, 取值范围为0.1~0.5之间', 0.1, 0.5),
            _model_param('learning_rate', 'float', 0.1,
                         '学习率, 取值范围为0.01~1之间', 0.01, 1),
            _model_param('num_trees', 'int', 6,
                         'XGBoost迭代树棵树, 取值范围为1~300之间的整数', 1, 300),
            _model_param('max_depth', 'int', 3,
                         'XGBoost树深度, 取值范围为1~6之间的整数', 1, 6),
            _model_param('max_bin', 'int', 4,
                         '特征分箱数, 取值范围为3~100之间的整数', 3, 100),
            _bool_param('silent', '值为1时不打印运行信息,只能为0或1'),
            _model_param('subsample', 'float', 1,
                         '训练每棵树使用的样本比例,取值范围为0.1~1之间', 0.1, 1),
            _model_param('colsample_bytree', 'float', 1,
                         '训练每棵树使用的特征比例,取值范围为0.1~1之间', 0.1, 1),
            _model_param('colsample_bylevel', 'float', 1,
                         '训练每一层使用的特征比例,取值范围为0.1~1之间', 0.1, 1),
            _model_param('reg_alpha', 'float', 0,
                         'L1正则化项,用于控制模型复杂度,取值范围为大于等于0的数值', 0),
            _model_param('reg_lambda', 'float', 1,
                         'L2正则化项,用于控制模型复杂度,取值范围为大于等于0的数值', 0),
            _model_param('gamma', 'float', 0,
                         '最优分割点所需的最小损失函数下降值,取值范围为大于等于0的数值', 0),
            _model_param('min_child_weight', 'float', 0,
                         '最优分割点所需的最小叶子节点权重,取值范围为大于等于0的数值', 0),
            _model_param('min_child_samples', 'int', 10,
                         '最优分割点所需的最小叶子节点样本数量,取值范围为1~1000之间的整数', 1, 1000),
            _model_param('seed', 'int', 2024,
                         '分割训练集/测试集时随机数种子,取值范围为0~10000之间的整数', 0, 10000),
            _model_param('early_stopping_rounds', 'int', 5,
                         '指定迭代多少次没有提升则停止训练, 值为0时不执行, 取值范围为0~100之间的整数', 0, 100),
            _model_param('eval_metric', 'string', 'auc',
                         '早停的评估指标,支持:auc, acc, recall, precision'),
            _model_param('verbose_eval', 'int', 1,
                         '按传入的间隔输出训练过程中的评估信息,0表示不打印', 0, 100),
            _model_param('eval_set_column', 'string', '', '指定训练集测试集标记字段名称'),
            _model_param('train_set_value', 'string', '', '指定训练集标记值'),
            _model_param('eval_set_value', 'string', '', '指定测试集标记值'),
            _model_param('train_features', 'string', '', '指定入模特征'),
            # NOTE: a 'threads' parameter existed here but was commented out
            # in the original XGB list (it remains in the generic branch).
        ])
    else:
        default_model_module_list.extend([
            _model_param('test_dataset_percentage', 'float', 0.3,
                         '取值范围为0.1~0.5之间', 0.1, 0.5),
            _model_param('learning_rate', 'float', 0.3,
                         '取值范围为0.001~0.999之间', 0.001, 0.999),
            _model_param('epochs', 'int', 3, '取值范围为1~5之间的整数', 1, 5),
            _model_param('batch_size', 'int', 16, '取值范围为1~100之间的整数', 1, 100),
            _model_param('threads', 'int', 8, '取值范围为1~8之间的整数', 1, 8),
        ])

    return default_model_module_list
open(file_name, 'rb') as f: + outfile.write(f.read()) + except Exception as pe: + log.info(f"merge files failed: {pe}") + raise PpcException(-1, f"merge files failed for: {pe}") + + +def md5sum(data_content): + md5_hash = hashlib.md5() + data = data_content if type( + data_content) is bytes else bytes(data_content, "utf-8") + md5_hash.update(data) + return md5_hash.hexdigest() + + +def calculate_md5(file_path, granularity=2 * 1024 * 1024): + md5_hash = hashlib.md5() + + with open(file_path, 'rb') as file: + # 逐块读取文件内容,以提高性能 + for chunk in iter(lambda: file.read(granularity), b''): + md5_hash.update(chunk) + + return md5_hash.hexdigest() diff --git a/python/ppc_model/__init__.py b/python/ppc_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/common/__init__.py b/python/ppc_model/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/common/base_context.py b/python/ppc_model/common/base_context.py new file mode 100644 index 00000000..a6c1b582 --- /dev/null +++ b/python/ppc_model/common/base_context.py @@ -0,0 +1,83 @@ +import os +from ppc_common.ppc_utils import utils + + +class BaseContext: + PSI_RESULT_FILE = "psi_result.csv" + MODEL_PREPARE_FILE = "model_prepare.csv" + PREPROCESSING_RESULT_FILE = "preprocessing_result.csv" + EVAL_COLUMN_FILE = "model_eval_column.csv" + WOE_IV_FILE = 'woe_iv.csv' + IV_SELECTED_FILE = 'iv_selected.csv' + SELECTED_COL_FILE = "xgb_result_column_info_selected.csv" + + # TODO: rename xgb filename + FEATURE_BIN_FILE = "feature_bin.json" + # MODEL_DATA_FILE = "model_data.json" + MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json' + TEST_MODEL_RESULT_FILE = "model_result.csv" + # TEST_MODEL_OUTPUT_FILE = "model_output.csv" + TEST_MODEL_OUTPUT_FILE = "xgb_output.csv" + TRAIN_MODEL_RESULT_FILE = "train_model_result.csv" + # TRAIN_MODEL_OUTPUT_FILE = "train_model_output.csv" + TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv" + + def __init__(self, job_id: str, 
class BaseContext:
    """Every local-workspace and remote-storage artifact path for one job.

    Local artifacts live under <job_temp_dir>/<job_id>/ (created on demand);
    remote artifacts are namespaced under <job_id>/ in the storage backend.
    """

    PSI_RESULT_FILE = "psi_result.csv"
    MODEL_PREPARE_FILE = "model_prepare.csv"
    PREPROCESSING_RESULT_FILE = "preprocessing_result.csv"
    EVAL_COLUMN_FILE = "model_eval_column.csv"
    WOE_IV_FILE = 'woe_iv.csv'
    IV_SELECTED_FILE = 'iv_selected.csv'
    SELECTED_COL_FILE = "xgb_result_column_info_selected.csv"

    # TODO: rename xgb filename
    FEATURE_BIN_FILE = "feature_bin.json"
    MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json'
    TEST_MODEL_RESULT_FILE = "model_result.csv"
    TEST_MODEL_OUTPUT_FILE = "xgb_output.csv"
    TRAIN_MODEL_RESULT_FILE = "train_model_result.csv"
    TRAIN_MODEL_OUTPUT_FILE = "xgb_train_output.csv"

    def __init__(self, job_id: str, job_temp_dir: str):
        """Create (if needed) the job workspace and precompute all paths."""
        self.job_id = job_id
        self.workspace = os.path.join(job_temp_dir, self.job_id)
        if not os.path.exists(self.workspace):
            os.makedirs(self.workspace)

        def local(file_name):
            # artifact inside this job's local workspace
            return os.path.join(self.workspace, file_name)

        def remote(file_name):
            # artifact in remote storage, namespaced by job id
            return os.path.join(self.job_id, file_name)

        self.psi_result_path = local(self.PSI_RESULT_FILE)
        self.model_prepare_file = local(self.MODEL_PREPARE_FILE)
        self.preprocessing_result_file = local(self.PREPROCESSING_RESULT_FILE)
        self.eval_column_file = local(self.EVAL_COLUMN_FILE)
        self.woe_iv_file = local(self.WOE_IV_FILE)
        self.iv_selected_file = local(self.IV_SELECTED_FILE)
        self.selected_col_file = local(self.SELECTED_COL_FILE)
        self.remote_selected_col_file = remote(self.SELECTED_COL_FILE)

        self.summary_evaluation_file = local(utils.MPC_XGB_EVALUATION_TABLE)
        self.feature_importance_file = local(utils.XGB_FEATURE_IMPORTANCE_TABLE)
        self.feature_bin_file = local(self.FEATURE_BIN_FILE)
        self.model_data_file = local(self.MODEL_DATA_FILE)
        self.test_model_result_file = local(self.TEST_MODEL_RESULT_FILE)
        self.test_model_output_file = local(self.TEST_MODEL_OUTPUT_FILE)
        self.train_model_result_file = local(self.TRAIN_MODEL_RESULT_FILE)
        self.train_model_output_file = local(self.TRAIN_MODEL_OUTPUT_FILE)

        # "train_*" attrs use the MPC_TRAIN_SET_* names; "test_*" attrs reuse
        # the generic MPC_TRAIN_METRIC_* names (legacy naming, kept as-is)
        self.train_metric_roc_file = local(utils.MPC_TRAIN_SET_METRIC_ROC_FILE)
        self.train_metric_ks_file = local(utils.MPC_TRAIN_SET_METRIC_KS_FILE)
        self.train_metric_pr_file = local(utils.MPC_TRAIN_SET_METRIC_PR_FILE)
        self.train_metric_acc_file = local(utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE)
        self.test_metric_roc_file = local(utils.MPC_TRAIN_METRIC_ROC_FILE)
        self.test_metric_ks_file = local(utils.MPC_TRAIN_METRIC_KS_FILE)
        self.test_metric_pr_file = local(utils.MPC_TRAIN_METRIC_PR_FILE)
        self.test_metric_acc_file = local(utils.MPC_TRAIN_METRIC_ACCURACY_FILE)
        self.train_metric_ks_table = local(utils.MPC_TRAIN_SET_METRIC_KS_TABLE)
        self.test_metric_ks_table = local(utils.MPC_TRAIN_METRIC_KS_TABLE)
        self.model_tree_prefix = local(utils.XGB_TREE_PERFIX)
        self.metrics_iteration_file = local(utils.METRICS_OVER_ITERATION_FILE)

        self.remote_summary_evaluation_file = remote(utils.MPC_XGB_EVALUATION_TABLE)
        self.remote_feature_importance_file = remote(utils.XGB_FEATURE_IMPORTANCE_TABLE)
        self.remote_feature_bin_file = remote(self.FEATURE_BIN_FILE)
        self.remote_model_data_file = remote(self.MODEL_DATA_FILE)
        self.remote_test_model_output_file = remote(self.TEST_MODEL_OUTPUT_FILE)
        self.remote_train_model_output_file = remote(self.TRAIN_MODEL_OUTPUT_FILE)

        self.remote_train_metric_roc_file = remote(utils.MPC_TRAIN_SET_METRIC_ROC_FILE)
        self.remote_train_metric_ks_file = remote(utils.MPC_TRAIN_SET_METRIC_KS_FILE)
        self.remote_train_metric_pr_file = remote(utils.MPC_TRAIN_SET_METRIC_PR_FILE)
        self.remote_train_metric_acc_file = remote(utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE)
        self.remote_test_metric_roc_file = remote(utils.MPC_TRAIN_METRIC_ROC_FILE)
        self.remote_test_metric_ks_file = remote(utils.MPC_TRAIN_METRIC_KS_FILE)
        self.remote_test_metric_pr_file = remote(utils.MPC_TRAIN_METRIC_PR_FILE)
        self.remote_test_metric_acc_file = remote(utils.MPC_TRAIN_METRIC_ACCURACY_FILE)
        self.remote_train_metric_ks_table = remote(utils.MPC_TRAIN_SET_METRIC_KS_TABLE)
        self.remote_test_metric_ks_table = remote(utils.MPC_TRAIN_METRIC_KS_TABLE)
        self.remote_model_tree_prefix = remote(utils.XGB_TREE_PERFIX)
        self.remote_metrics_iteration_file = remote(utils.METRICS_OVER_ITERATION_FILE)

    @staticmethod
    def feature_engineering_input_path(job_id: str, job_temp_dir: str):
        """Path of the prepared model input for a job (no instance needed)."""
        return os.path.join(job_temp_dir, job_id, BaseContext.MODEL_PREPARE_FILE)


class Context(BaseContext):
    """Task-scoped context: a BaseContext plus runtime components and role."""

    def __init__(self, job_id: str, task_id: str, components: Initializer, role: TaskRole = None):
        super().__init__(job_id, components.config_data['JOB_TEMP_DIR'])
        self.task_id = task_id
        self.components = components
        self.role = role
class Initializer:
    """Wires up the model service: config, logging, grpc stub, task manager
    and storage client. Construction is side-effect free; call init_all()
    to actually initialize the subsystems.
    """

    def __init__(self, log_config_path, config_path):
        self.log_config_path = log_config_path
        self.config_path = config_path
        self.config_data = None
        self.grpc_options = None
        self.stub = None
        self.task_manager = None
        self.thread_event_manager = None
        self.storage_client = None
        # test hook: when set, logger() returns this instead of a real logger
        self.mock_logger = None
        self.public_key_length = 2048
        self.homo_algorithm = 0

    def init_all(self):
        """Initialize every subsystem in dependency order."""
        self.init_log()
        self.init_config()
        self.init_stub()
        self.init_task_manager()
        self.init_storage_client()
        self.init_cache()

    def init_log(self):
        """Configure logging from the fileConfig-format config file."""
        logging.config.fileConfig(self.log_config_path)

    def init_cache(self):
        # local directory used to cache downloaded job artifacts
        self.job_cache_dir = common_func.get_config_value(
            "JOB_TEMP_DIR", "/tmp", self.config_data, False)

    def init_config(self):
        """Load the YAML config file into self.config_data.

        :raises KeyError: when PUBLIC_KEY_LENGTH is missing from the config.
        """
        with open(self.config_path, 'rb') as f:
            self.config_data = yaml.safe_load(f.read())
        self.public_key_length = self.config_data['PUBLIC_KEY_LENGTH']
        # NOTE(review): the original also read STORAGE_TYPE here into an
        # unused local; storage selection happens in init_storage_client,
        # so the dead store was removed.
        if 'HOMO_ALGORITHM' in self.config_data:
            self.homo_algorithm = self.config_data['HOMO_ALGORITHM']

    def init_stub(self):
        """Build the grpc channel options and the gateway stub."""
        self.thread_event_manager = ThreadEventManager()
        max_msg_bytes = self.config_data['MAX_MESSAGE_LENGTH_MB'] * 1024 * 1024
        self.grpc_options = [
            ('grpc.ssl_target_name_override', 'PPCS MODEL GATEWAY'),
            ('grpc.max_send_message_length', max_msg_bytes),
            ('grpc.max_receive_message_length', max_msg_bytes),
            ('grpc.keepalive_time_ms', 15000),  # send a heartbeat every 15s
            ('grpc.keepalive_timeout_ms', 5000),  # wait 5s for a heartbeat ack
            ('grpc.keepalive_permit_without_calls', True),  # ping even when idle
            ('grpc.http2.min_time_between_pings_ms', 15000),  # >= 15s between pings
            ('grpc.http2.max_pings_without_data', 0),  # unlimited pings before data
            # keep at least 20s between pings when no data is flowing
            ('grpc.http2.min_ping_interval_without_data_ms', 20000),
            ("grpc.so_reuseport", 1),
            ("grpc.use_local_subchannel_pool", 1),
            ('grpc.enable_retries', 1),
            # BUG FIX: "backoffMutiplier" -> "backoffMultiplier"; gRPC's
            # retry-policy parser silently ignored the misspelled key.
            ('grpc.service_config',
             '{ "retryPolicy":{ "maxAttempts": 4, "initialBackoff": "0.1s", "maxBackoff": "1s", "backoffMultiplier": '
             '2, "retryableStatusCodes": [ "UNAVAILABLE" ] } }')
        ]
        rpc_client = GrpcClient(
            logger=self.logger(),
            endpoint=self.config_data['GATEWAY_ENDPOINT'],
            grpc_options=self.grpc_options,
            ssl_switch=self.config_data['SSL_SWITCH'],
            ca_path=self.config_data['CA_CRT'],
            ssl_key_path=self.config_data['SSL_KEY'],
            ssl_crt_path=self.config_data['SSL_CRT']
        )
        self.stub = ModelStub(
            agency_id=self.config_data['AGENCY_ID'],
            thread_event_manager=self.thread_event_manager,
            rpc_client=rpc_client
        )

    def init_task_manager(self):
        """Create the task manager bound to the stub and event manager."""
        self.task_manager = TaskManager(
            logger=self.logger(),
            thread_event_manager=self.thread_event_manager,
            stub=self.stub,
            task_timeout_h=self.config_data['TASK_TIMEOUT_H']
        )

    def init_storage_client(self):
        """Create the storage client from the loaded configuration."""
        self.storage_client = storage_loader.load(
            self.config_data, self.logger())

    def logger(self, name=None):
        """Return the real logger, or the injected mock when set (tests)."""
        if self.mock_logger is None:
            return logging.getLogger(name)
        return self.mock_logger
b/python/ppc_model/common/mock/rpc_client_mock.py @@ -0,0 +1,31 @@ +import threading + +from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelRequest, ModelResponse +from ppc_model.interface.rpc_client import RpcClient + + +class RpcClientMock(RpcClient): + def __init__(self, need_failed=False): + self._need_failed = need_failed + self._bad_guy = 0 + self._lock = threading.Lock() + self._on_message_received = None + + def set_message_handler(self, on_message_received): + self._on_message_received = on_message_received + + def send(self, request: ModelRequest): + # print( + # f"send data to {request.receiver}, task_id: {request.task_id}, " + # f"key: {request.key}, seq: {request.seq}") + self._on_message_received(request) + response = ModelResponse() + if self._need_failed: + # 模拟网络断连 + with self._lock: + self._bad_guy += 1 + response.base_response.error_code = self._bad_guy % 2 + else: + response.base_response.error_code = 0 + response.base_response.message = "success" + return response diff --git a/python/ppc_model/common/model_result.py b/python/ppc_model/common/model_result.py new file mode 100644 index 00000000..5b156948 --- /dev/null +++ b/python/ppc_model/common/model_result.py @@ -0,0 +1,212 @@ +import os +import shutil +import pandas as pd +import time +from enum import Enum +import base64 +from ppc_common.ppc_utils import utils +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.common.context import Context +from ppc_model.common.protocol import TaskRole +from ppc_model.network.stub import PushRequest, PullRequest + + +class ResultFileHandling: + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + self.log = ctx.components.logger() + + if ctx.algorithm_type == AlgorithmType.Train.name: + self._process_fe_result() + + # remove job workspace + # self._remove_workspace() + + # Synchronization result file + if (len(ctx.result_receiver_id_list) == 1 and ctx.participant_id_list[0] != ctx.result_receiver_id_list[0]) \ + or 
def _process_fe_result(self):
    """Assemble the feature-engineering column summary and publish it.

    Merges the preprocessing column info with the IV-selection result (when
    present), fills gaps with the literal string "None", writes the merged
    table to ``selected_col_file`` and uploads it to remote storage so the
    front-end can display it.
    """
    if not os.path.exists(self.ctx.preprocessing_result_file):
        return
    base_info = pd.read_csv(self.ctx.preprocessing_result_file, index_col=0)
    if os.path.exists(self.ctx.iv_selected_file):
        iv_info = pd.read_csv(self.ctx.iv_selected_file, index_col=0)
        merged = self.union_column_info(base_info, iv_info)
    else:
        merged = base_info

    merged.fillna("None", inplace=True)
    merged.to_csv(self.ctx.selected_col_file, sep=utils.CSV_SEP,
                  header=True, index_label='id')
    # Persist the column info remotely for front-end display.
    self._upload_file(self.ctx.components.storage_client,
                      self.ctx.selected_col_file,
                      self.ctx.remote_selected_col_file)
+ """ + # 将column_info1和column_info2按照left_index=True, right_index=True的方式进行合并 如果列有缺失则赋值为None 行的顺序按照column_info1 + column_info_conbine = column_info1.merge(column_info2, how='outer', left_index=True, right_index=True, sort=False) + col1_index_list = column_info1.index.to_list() + col2_index_list = column_info2.index.to_list() + merged_list = col1_index_list + [item for item in col2_index_list if item not in col1_index_list] + column_info_conbine = column_info_conbine.reindex(merged_list) + return column_info_conbine + + @staticmethod + def _upload_file(storage_client, local_file, remote_file): + if storage_client is not None: + storage_client.upload_file(local_file, remote_file) + + @staticmethod + def _download_file(storage_client, local_file, remote_file): + if storage_client is not None and not os.path.exists(local_file): + storage_client.download_file(remote_file, local_file) + @staticmethod + def make_graph_data(components, job_id, graph_file_name): + graph_format = 'svg+xml' + # download with cache + remote_file_path = os.path.join(job_id, graph_file_name) + local_file_path = os.path.join( + components.job_cache_dir, remote_file_path) + components.storage_client.download_file( + remote_file_path, local_file_path, True) + file_bytes = None + with open(local_file_path, 'r') as file: + file_content = file.read() + file_bytes = file_content.encode('utf-8') + encoded_data = "" + if file_bytes is not None: + encoded_data = base64.b64encode(file_bytes).decode('ascii') + time.sleep(0.1) + return f"data:image/{graph_format};base64,{encoded_data}" + + def get_remote_path(components, job_id, csv_file_name): + if components.storage_client.get_home_path() is None: + return os.path.join(job_id, csv_file_name) + return os.path.join(components.storage_client.get_home_path(), job_id, csv_file_name) + + @staticmethod + def make_csv_data(components, job_id, csv_file_name): + import pandas as pd + from io import StringIO + remote_file_path = os.path.join(job_id, csv_file_name) + 
local_file_path = os.path.join( + components.job_cache_dir, remote_file_path) + components.storage_client.download_file( + remote_file_path, local_file_path, True) + file_bytes = None + with open(local_file_path, 'r') as file: + file_content = file.read() + file_bytes = file_content.encode('utf-8') + # encoded_data = base64.b64encode(data).decode('ascii') + # csv_data = np.genfromtxt(data.decode(), delimiter=',') + # csv_data = np.genfromtxt(StringIO(data.decode()), delimiter=',') + csv_data = "" + if file_bytes is not None: + csv_data = pd.read_csv(StringIO(file_bytes.decode())).astype('str') + return csv_data + + def _remove_workspace(self): + if os.path.exists(self.ctx.workspace): + shutil.rmtree(self.ctx.workspace) + self.log.info(f'job {self.ctx.job_id}: {self.ctx.workspace} has been removed.') + else: + self.log.info(f'job {self.ctx.job_id}: {self.ctx.workspace} does not exist.') + + def _sync_result_files(self): + if self.ctx.algorithm_type == AlgorithmType.Train.name: + self.sync_result_file(self.ctx, self.ctx.metrics_iteration_file, + self.ctx.remote_metrics_iteration_file, 'f1') + self.sync_result_file(self.ctx, self.ctx.feature_importance_file, + self.ctx.remote_feature_importance_file, 'f2') + self.sync_result_file(self.ctx, self.ctx.summary_evaluation_file, + self.ctx.remote_summary_evaluation_file, 'f3') + self.sync_result_file(self.ctx, self.ctx.train_metric_ks_table, + self.ctx.remote_train_metric_ks_table, 'f4') + self.sync_result_file(self.ctx, self.ctx.train_metric_roc_file, + self.ctx.remote_train_metric_roc_file, 'f5') + self.sync_result_file(self.ctx, self.ctx.train_metric_ks_file, + self.ctx.remote_train_metric_ks_file, 'f6') + self.sync_result_file(self.ctx, self.ctx.train_metric_pr_file, + self.ctx.remote_train_metric_pr_file, 'f7') + self.sync_result_file(self.ctx, self.ctx.train_metric_acc_file, + self.ctx.remote_train_metric_acc_file, 'f8') + self.sync_result_file(self.ctx, self.ctx.test_metric_ks_table, + 
@staticmethod
def sync_result_file(ctx, local_file, remote_file, key_file):
    """Distribute one result file from the active party to all receivers.

    Active party: reads *local_file* and pushes its bytes to every partner
    that appears in ``result_receiver_id_list``. Passive party: when this
    agency is a receiver, pulls the bytes from the active party, writes
    *local_file* and uploads it to *remote_file* on remote storage.
    ``key_file`` disambiguates the message key per file.
    """
    msg_key = f'{CommonMessage.SYNC_FILE.value}_{key_file}'
    if ctx.role == TaskRole.ACTIVE_PARTY:
        with open(local_file, 'rb') as fp:
            payload = fp.read()
        for idx in range(1, len(ctx.participant_id_list)):
            if ctx.participant_id_list[idx] in ctx.result_receiver_id_list:
                SendMessage._send_byte_data(
                    ctx.components.stub, ctx, msg_key, payload, idx)
    else:
        if ctx.components.config_data['AGENCY_ID'] in ctx.result_receiver_id_list:
            payload = SendMessage._receive_byte_data(
                ctx.components.stub, ctx, msg_key, 0)
            with open(local_file, 'wb') as fp:
                fp.write(payload)
            ResultFileHandling._upload_file(
                ctx.components.storage_client, local_file, remote_file)
class ModelSetting:
    """Typed container for model hyper-parameters.

    Reads every known setting out of ``model_dict``, applying a default when
    a key is absent and coercing numeric fields to int/float. All lookups go
    through :func:`common_func.get_config_value` with ``required=False``.

    Bug fixed: the ``psi_select_thresh`` value was previously stored into
    ``self.psi_select_base``, silently clobbering the base value read just
    before it; it now lives under its own ``psi_select_thresh`` attribute.
    """

    def __init__(self, model_dict):
        def _get(key, default):
            # Optional lookup: fall back to `default` when the key is absent.
            return common_func.get_config_value(key, default, model_dict, False)

        # --- preprocessing ---
        self.use_psi = _get("use_psi", False)
        self.fillna = _get("fillna", False)
        self.na_select = float(_get("na_select", 1.0))
        self.filloutlier = _get("filloutlier", False)
        self.normalized = _get("normalized", False)
        self.standardized = _get("standardized", False)
        self.categorical = _get("categorical", '')
        # --- PSI-based feature selection ---
        self.psi_select_col = _get("psi_select_col", '')
        self.psi_select_base = _get("psi_select_base", '')
        # BUGFIX: this was assigned to self.psi_select_base, overwriting it.
        self.psi_select_thresh = float(_get("psi_select_thresh", 0.3))
        self.psi_select_bins = int(_get("psi_select_bins", 4))
        # --- correlation / IV selection ---
        self.corr_select = float(_get("corr_select", 0))
        self.use_iv = _get("use_iv", False)
        self.group_num = int(_get("group_num", 4))
        self.iv_thresh = float(_get("iv_thresh", 0.1))
        # --- tree training ---
        self.use_goss = _get("use_goss", False)
        self.test_dataset_percentage = float(_get("test_dataset_percentage", 0.3))
        self.learning_rate = float(_get("learning_rate", 0.1))
        self.num_trees = int(_get("num_trees", 6))
        self.max_depth = int(_get("max_depth", 3))
        self.max_bin = int(_get("max_bin", 4))
        self.silent = _get("silent", False)
        self.subsample = float(_get("subsample", 1))
        self.colsample_bytree = float(_get("colsample_bytree", 1))
        self.colsample_bylevel = float(_get("colsample_bylevel", 1))
        self.reg_alpha = float(_get("reg_alpha", 0))
        self.reg_lambda = float(_get("reg_lambda", 1))
        self.gamma = float(_get("gamma", 0))
        self.min_child_weight = float(_get("min_child_weight", 0.0))
        self.min_child_samples = int(_get("min_child_samples", 10))
        self.seed = int(_get("seed", 2024))
        self.early_stopping_rounds = int(_get("early_stopping_rounds", 5))
        self.eval_metric = _get("eval_metric", "auc")
        self.verbose_eval = int(_get("verbose_eval", 1))
        # --- customized evaluation-set split ---
        self.eval_set_column = _get("eval_set_column", "")
        self.train_set_value = _get("train_set_value", "")
        self.eval_set_value = _get("eval_set_value", "")
        self.train_features = _get("train_features", "")
        # --- NN-related / misc ---
        self.epochs = int(_get("epochs", 3))
        self.batch_size = int(_get("batch_size", 16))
        self.threads = int(_get("threads", 8))
        self.one_hot = _get("one_hot", 0)
enc_data_pb.cipher_list.append(model_cipher) + + return utils.pb_to_bytes(enc_data_pb) + + @staticmethod + def unpacking_data(codec, data): + enc_data_pb = CipherList() + utils.bytes_to_pb(enc_data_pb, data) + public_key = codec.decode_enc_key(enc_data_pb.public_key) + enc_data = [codec.decode_cipher(public_key, + cipher.ciphertext, + cipher.exponent + ) for cipher in enc_data_pb.cipher_list] + return public_key, enc_data + + @staticmethod + def packing_2dim_data(codec, public_key, cipher_2d_list): + enc_data_pb = Cipher2DimList() + enc_data_pb.public_key = codec.encode_enc_key(public_key) + + for cipher_list in cipher_2d_list: + enc_1d_pb = Cipher1DimList() + for cipher in cipher_list: + model_cipher = ModelCipher() + model_cipher.ciphertext, model_cipher.exponent = \ + codec.encode_cipher(cipher, be_secure=False) + enc_1d_pb.cipher_list.append(model_cipher) + enc_data_pb.cipher_1d_list.append(enc_1d_pb) + + return utils.pb_to_bytes(enc_data_pb) + + @staticmethod + def unpacking_2dim_data(codec, data): + enc_data_pb = Cipher2DimList() + utils.bytes_to_pb(enc_data_pb, data) + public_key = codec.decode_enc_key(enc_data_pb.public_key) + enc_data = [] + for enc_1d_pb in enc_data_pb.cipher_1d_list: + enc_1d_data = [codec.decode_cipher(public_key, + cipher.ciphertext, + cipher.exponent + ) for cipher in enc_1d_pb.cipher_list] + enc_data.append(enc_1d_data) + return public_key, enc_data + + +LOG_START_FLAG_FORMATTER = "$$$StartModelJob:{job_id}" +LOG_END_FLAG_FORMATTER = "$$$EndModelJob:{job_id}" diff --git a/python/ppc_model/conf/application-sample.yml b/python/ppc_model/conf/application-sample.yml new file mode 100644 index 00000000..2fecc8eb --- /dev/null +++ b/python/ppc_model/conf/application-sample.yml @@ -0,0 +1,43 @@ +HOST: "0.0.0.0" +HTTP_PORT: 43471 +RPC_PORT: 43472 + +AGENCY_ID: 'WeBank' +GATEWAY_ENDPOINT: "127.0.0.1:43454" + +PUBLIC_KEY_LENGTH: 2048 + +MAX_MESSAGE_LENGTH_MB: 100 +TASK_TIMEOUT_H: 1800 + +SSL_SWITCH: 0 +CA_CRT: "./ca.crt" +SSL_CRT: "./ssl.crt" 
+SSL_KEY: "./ssl.key" + + +PEM_PATH: "/data/app/ppcs-model4ef/wedpr-model-node/ppc_model_service/server.pem" +SHARE_PATH: "/data/app/ppcs-model4ef/wedpr-model-node/ppc_model_service/dataset_share/" + +DB_TYPE: "mysql" +SQLALCHEMY_DATABASE_URI: "mysql://[*user_ppcsmodeladm]:[*pass_ppcsmodeladm]@[@4346-TDSQL_VIP]:[@4346-TDSQL_PORT]/ppcsmodeladm?autocommit=true&charset=utf8mb4" + +# interagency services +# HDFS_ENDPOINT: "http://127.0.0.1:50070"  # superseded by the active HDFS_ENDPOINT below (duplicate YAML key: the last value wins) +# HDFS storage settings +STORAGE_TYPE: "HDFS" +HDFS_ENDPOINT: "http://127.0.0.1:9870" +HDFS_USER: "ppc" +HDFS_HOME: "/user/ppc/model/webank" + +# ECDSA or GM +CRYPTO_TYPE: "ECDSA" +private_key: "" +public_key: "" +gm_private_key: "" +gm_public_key: "" + +UPLOAD_FOLDER: "./upload_data_folder" +JOB_TEMP_DIR: ".cache/job" + +FE_TIMEOUT_S: 5400 diff --git a/python/ppc_model/conf/logging.conf b/python/ppc_model/conf/logging.conf new file mode 100644 index 00000000..31f30acb --- /dev/null +++ b/python/ppc_model/conf/logging.conf @@ -0,0 +1,40 @@ +[loggers] +keys=root,wsgi + +[logger_root] +level=INFO +handlers=consoleHandler,fileHandler + +[logger_wsgi] +level = INFO +handlers = accessHandler +qualname = wsgi +propagate = 0 + +[handlers] +keys=fileHandler,consoleHandler,accessHandler + +[handler_accessHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/logs/ppcs-model4ef/appmonitor.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_fileHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/logs/ppcs-model4ef/ppcs-model4ef-node.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_consoleHandler] +class=StreamHandler +args=(sys.stdout,) +level=ERROR +formatter=simpleFormatter + +[formatters] +keys=simpleFormatter + +[formatter_simpleFormatter] +format=[%(levelname)s][%(asctime)s %(msecs)03d][%(process)d][%(filename)s:%(lineno)d] %(message)s +datefmt=%Y-%m-%d %H:%M:%S diff --git a/python/ppc_model/datasets/__init__.py b/python/ppc_model/datasets/__init__.py new file mode 
class FeatureSelection:
    """Pick the feature indices used for model training."""

    @staticmethod
    def feature_selecting(feature_name: list, train_feats: list, fr: float):
        """Return selected feature indices.

        Priority: an explicit ``train_feats`` list wins; otherwise a random
        fraction ``fr`` (when 0 < fr < 1); otherwise every feature.
        """
        if train_feats:
            return FeatureSelection._get_train_feature(feature_name, train_feats)
        if 0 < fr < 1:
            return FeatureSelection._get_feature_rate(feature_name, fr)
        return list(range(len(feature_name)))

    @staticmethod
    def _get_train_feature(feature_name: list, train_feats: list):
        """Indices of the features explicitly requested, in feature order."""
        wanted = set(train_feats)
        return [idx for idx, name in enumerate(feature_name) if name in wanted]

    @staticmethod
    def _get_feature_rate(feature_name: list, fr: float):
        """A sorted random sample of ``int(len * fr)`` feature indices."""
        sample_size = int(len(feature_name) * fr)
        return sorted(np.random.choice(
            range(len(feature_name)), size=sample_size, replace=False))
other_rate) + elif subsample > 0 and subsample < 1: + instance, used_glist, used_hlist = Sampling._get_subsample_sampling( + g_list, h_list, subsample) + else: + instance, used_glist, used_hlist = Sampling._get_sampling( + g_list, h_list) + + return instance, used_glist, used_hlist + + def _get_goss_sampling(g_list, h_list, top_rate, other_rate): + + n = len(g_list) + instance, used_glist, used_hlist = Sampling._goss_sampleing( + n, top_rate, other_rate, g_list, h_list) + + return instance, used_glist, used_hlist + + def _get_subsample_sampling(g_list, h_list, subsample): + + rand_size = int(len(g_list) * subsample) + rand_idx = np.array(sorted( + np.random.choice(list(range(len(g_list))), size=rand_size, replace=False))) + used_glist = np.array(g_list)[(rand_idx)] + used_hlist = np.array(h_list)[(rand_idx)] + + # used_idx = {} + # for i in range(rand_size): + # used_idx[rand_idx[i]] = i + # curr_instance = np.array(list(used_idx.keys())) + + return rand_idx, used_glist, used_hlist + + def _get_sampling(g_list, h_list): + + used_glist, used_hlist = g_list, h_list + instance = np.array(list(range(len(g_list)))) + + return instance, used_glist, used_hlist + + @staticmethod + def _goss_sampleing(n, a, b, g_list, h_list): + top_size = int(n * a) + rand_size = int(n * b) + abs_g = np.abs(g_list) + + top_idx = np.argsort(abs_g)[-top_size:] + rand_idx = np.random.choice(np.argsort( + abs_g)[:-top_size], size=rand_size, replace=False) + used_idx = np.append(top_idx, rand_idx) + + fact = (1 - a) / b + rand_glist = np.array(g_list)[(rand_idx)] * fact + used_glist = np.append(np.array(g_list)[(top_idx)], rand_glist) + rand_hlist = np.array(h_list)[(rand_idx)] * fact + used_hlist = np.append(np.array(h_list)[(top_idx)], rand_hlist) + + return Sampling._sort_instance(used_idx, used_glist, used_hlist) + + @staticmethod + def _sort_instance(instance, g_list, h_list): + # 获取排序索引 + sorted_indices = np.argsort(instance) + + # 对所有数组进行排序 + sorted_idx = instance[sorted_indices] + 
class TestFeatureSelection(unittest.TestCase):
    """Unit tests for FeatureSelection.feature_selecting."""

    feature_name = [f'x{i + 1}' for i in range(30)]

    def test_fr_feature_select(self):
        # A fraction fr selects int(len * fr) features.
        selected = FeatureSelection.feature_selecting(self.feature_name, [], 0.8)
        self.assertEqual(len(selected), len(self.feature_name) * 0.8)

    def test_customized_feature_select(self):
        # Only names present in both lists are kept ('x33' is unknown).
        train_feats = ['x1', 'x3', 'x15', 'x27', 'x33']
        selected = FeatureSelection.feature_selecting(
            self.feature_name, train_feats, 0.8)
        expected = set(self.feature_name).intersection(set(train_feats))
        self.assertEqual(len(selected), len(expected))
        self.assertEqual(sorted([f'x{i + 1}' for i in selected]),
                         sorted(expected))

    def test_feature_select(self):
        # Neither an explicit list nor a valid fraction: keep everything.
        selected = FeatureSelection.feature_selecting(self.feature_name, [], 0)
        self.assertEqual(len(selected), len(self.feature_name))
        self.assertEqual(selected, list(range(len(self.feature_name))))
Sampling.sample_selecting( + self.g_list, self.h_list, use_goss=True) + self.assertEqual(len(instance), int( + len(self.g_list) * 0.2) + int(len(self.g_list) * 0.1)) + assert max(self.g_list) in used_glist + assert np.argmax(self.g_list) in instance + + def test_subsample_sampling(self): + instance, used_glist, used_hlist = Sampling.sample_selecting( + self.g_list, self.h_list, subsample=0.6) + self.assertEqual(len(instance), int(len(self.g_list) * 0.6)) + + def test_sampling(self): + instance, used_glist, used_hlist = Sampling.sample_selecting( + self.g_list, self.h_list) + self.assertEqual(len(instance), len(self.g_list)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/ppc_model/datasets/dataset.py b/python/ppc_model/datasets/dataset.py new file mode 100644 index 00000000..0278c51c --- /dev/null +++ b/python/ppc_model/datasets/dataset.py @@ -0,0 +1,232 @@ +import os +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split + +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.common.protocol import TaskRole +from ppc_model.common.model_result import ResultFileHandling, CommonMessage, SendMessage +from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext + + +class SecureDataset: + def __init__(self, ctx: SecureLGBMContext, model_data=None, delimiter: str = ' '): + + self.eval_column_file = ctx.eval_column_file + self.iv_selected_file = ctx.iv_selected_file + self.selected_col_file = ctx.selected_col_file + self.is_label_holder = ctx.is_label_holder + self.algorithm_type = ctx.algorithm_type + self.test_size = ctx.lgbm_params.test_size + self.random_state = ctx.lgbm_params.random_state + self.eval_set_column = ctx.lgbm_params.eval_set_column + self.train_set_value = ctx.lgbm_params.train_set_value + self.eval_set_value = ctx.lgbm_params.eval_set_value + + self.ctx = ctx + self.train_X = None + self.test_X 
@staticmethod
def hetero_split_dataset(df: pd.DataFrame, split_point: int = None):
    """Vertically split *df* for a two-party hetero setting.

    Returns ``(df_with_y, df_without_y)``: the first keeps 'id', 'y' and the
    leading feature columns, the second keeps 'id' and the remaining ones.
    When *split_point* is falsy the feature columns are split in half.
    """
    # Feature columns in original order, excluding the bookkeeping columns.
    feature_cols = [c for c in df.columns.tolist() if c not in ('id', 'y')]

    if not split_point:
        split_point = (df.shape[1] - 2) // 2

    df_with_y = df[['id', 'y'] + feature_cols[:split_point]]
    df_without_y = df[['id'] + feature_cols[split_point:]]
    return df_with_y, df_without_y
test_data = train_test_split( + self.model_data, test_size=self.test_size, random_state=self.random_state) + + return train_data, test_data + + def _customized_split_dataset(self): + if self.ctx.role == TaskRole.ACTIVE_PARTY: + for partner_index in range(1, len(self.ctx.participant_id_list)): + byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + f'{CommonMessage.EVAL_SET_FILE.value}', partner_index) + if not os.path.exists(self.eval_column_file) and byte_data != bytes(): + with open(self.eval_column_file, 'wb') as f: + f.write(byte_data) + with open(self.eval_column_file, 'rb') as f: + byte_data = f.read() + for partner_index in range(1, len(self.ctx.participant_id_list)): + SendMessage._send_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', + byte_data, partner_index) + else: + if not os.path.exists(self.eval_column_file): + byte_data = bytes() + else: + with open(self.eval_column_file, 'rb') as f: + byte_data = f.read() + SendMessage._send_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', + byte_data, 0) + byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + f'{CommonMessage.EVAL_SET_FILE.value}', 0) + if not os.path.exists(self.eval_column_file): + with open(self.eval_column_file, 'wb') as f: + f.write(byte_data) + + eval_set_df = pd.read_csv(self.eval_column_file, header=0) + train_data = self.model_data[eval_set_df[self.eval_set_column] == self.train_set_value] + test_data = self.model_data[eval_set_df[self.eval_set_column] == self.eval_set_value] + + return train_data, test_data + + def _construct_model_dataset(self, train_data, test_data): + + self.train_idx = train_data['id'].values + self.test_idx = test_data['id'].values + + if self.is_label_holder and 'y' in train_data.columns: + self.train_y = train_data['y'].values + self.test_y = test_data['y'].values + self.train_X = train_data.drop(columns=['id', 'y']).values + self.test_X 
= test_data.drop(columns=['id', 'y']).values + self.feature_name = train_data.drop( + columns=['id', 'y']).columns.tolist() + else: + self.train_X = train_data.drop(columns=['id']).values + self.test_X = test_data.drop(columns=['id']).values + self.feature_name = train_data.drop( + columns=['id']).columns.tolist() + + def _construct_predict_dataset(self, test_data): + self.test_idx = test_data['id'].values + if self.is_label_holder and 'y' in test_data.columns: + self.test_y = test_data['y'].values + self.test_X = test_data.drop(columns=['id', 'y']).values + self.feature_name = test_data.drop( + columns=['id', 'y']).columns.tolist() + else: + self.test_X = test_data.drop(columns=['id']).values + self.feature_name = test_data.drop(columns=['id']).columns.tolist() + + def _dataset_fe_selected(self, file_path, feature_name): + iv_selected = pd.read_csv(file_path, header=0) + selected_list = iv_selected[feature_name][iv_selected['iv_selected'] == 1].tolist( + ) + + drop_columns = [] + for column in self.model_data.columns: + if column == 'id' or column == 'y': + continue + if column not in selected_list: + drop_columns.append(column) + + if len(drop_columns) > 0: + self.model_data = self.model_data.drop(columns=drop_columns) + + def _construct_dataset(self): + + if os.path.exists(self.iv_selected_file): + self._dataset_fe_selected(self.iv_selected_file, 'feature') + + if self.algorithm_type == AlgorithmType.Predict.name \ + and not os.path.exists(self.selected_col_file): + try: + self.ctx.remote_selected_col_file = os.path.join( + self.ctx.lgbm_params.training_job_id, self.ctx.SELECTED_COL_FILE) + ResultFileHandling._download_file(self.ctx.components.storage_client, + self.selected_col_file, self.ctx.remote_selected_col_file) + self._dataset_fe_selected(self.selected_col_file, 'id') + except: + pass + + if 'id' not in self.model_data.columns: + if self.ctx.dataset_file_path is None: + import glob + pattern = os.path.join(self.ctx.workspace, 'd-*') + dataset_file_path = 
glob.glob(pattern)[0] + else: + dataset_file_path = self.ctx.dataset_file_path + dataset_id = pd.read_csv( + dataset_file_path, header=0, usecols=['id']) + if os.path.exists(self.ctx.psi_result_path): + psi_data = pd.read_csv(self.ctx.psi_result_path, header=0) + dataset_id = pd.merge(dataset_id, psi_data, on=[ + 'id']).sort_values(by='id', ascending=True) + self.model_data = pd.concat([dataset_id, self.model_data], axis=1) + + if self.algorithm_type == AlgorithmType.Train.name: + if self.eval_set_column: + train_data, test_data = self._customized_split_dataset() + else: + train_data, test_data = self._random_split_dataset() + self._construct_model_dataset(train_data, test_data) + + elif self.algorithm_type == AlgorithmType.Predict.name: + test_data = self.model_data + self._construct_predict_dataset(test_data) + else: + raise PpcException(PpcErrorCode.ALGORITHM_TYPE_ERROR.get_code(), + PpcErrorCode.ALGORITHM_TYPE_ERROR.get_message()) diff --git a/python/ppc_model/datasets/feature_binning/__init__.py b/python/ppc_model/datasets/feature_binning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/datasets/feature_binning/feature_binning.py b/python/ppc_model/datasets/feature_binning/feature_binning.py new file mode 100644 index 00000000..91c5ab3d --- /dev/null +++ b/python/ppc_model/datasets/feature_binning/feature_binning.py @@ -0,0 +1,131 @@ +import numpy as np +import pandas as pd + +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.common.context import Context + + +class FeatureBinning: + def __init__(self, ctx: Context): + self.ctx = ctx + self.params = ctx.lgbm_params + self.data = None + self.data_bin = None + self.data_split = None + + @staticmethod + def binning_continuous_feature(feature: np.ndarray, max_bin: int, is_equal_freq: bool = True): + try: + if is_equal_freq: + # 等频分箱,不替换缺失值 + Xk_bin, Xk_split = pd.qcut( + feature, q=max_bin, retbins=True, labels=False) + else: + # 等距分箱,不替换缺失值 + Xk_bin, 
class FeatureBinning:
    """Discretizes features into bins (equal-frequency or equal-width).

    Training generates split points from the data; prediction/test reuses
    split points computed during training. Missing values map to bin -1.
    """

    def __init__(self, ctx):
        # ctx: ppc_model Context carrying lgbm_params and algorithm_type.
        self.ctx = ctx
        self.params = ctx.lgbm_params
        self.data = None
        self.data_bin = None
        self.data_split = None

    @staticmethod
    def binning_continuous_feature(feature: np.ndarray, max_bin: int, is_equal_freq: bool = True):
        """Bin one continuous feature.

        Returns (bin ids, split points); NaN entries get bin id -1.
        Falls back to special handling when pandas raises ValueError
        (non-unique bin edges from too few distinct values).
        """
        try:
            if is_equal_freq:
                # Equal-frequency binning; qcut leaves NaN in place.
                Xk_bin, Xk_split = pd.qcut(
                    feature, q=max_bin, retbins=True, labels=False)
            else:
                # Equal-width binning.
                Xk_bin, Xk_split = pd.cut(
                    feature, max_bin, retbins=True, labels=False)
            # Mark missing values as -1.
            Xk_bin[np.isnan(feature)] = -1
        except ValueError:
            unique_values = sorted(set(feature[~np.isnan(feature)]))
            if len(unique_values) > 1000:
                raise Exception(
                    'Features with more than 1000 groups are not supported.')
            if len(unique_values) == 2 and 0 in unique_values and 1 in unique_values:
                # A 0/1 feature is already binary. Copy so the NaN -> -1
                # rewrite below cannot mutate the caller's array (bug fix:
                # the original aliased `feature` here and clobbered it).
                Xk_bin = feature.copy()
                Xk_split = [min(unique_values) - 0.01, 0.5, max(unique_values)]
            elif len(unique_values) > max_bin:
                Xk_bin, Xk_split = pd.cut(
                    feature, max_bin, labels=False, retbins=True)
            else:
                # Few distinct values: re-number them 0..k-1 via a mapping.
                mapping_dict = {value: i for i,
                                value in enumerate(unique_values)}
                Xk_bin = pd.DataFrame(feature)[0].map(mapping_dict).values
                Xk_split = [min(unique_values) - 0.01] + list(unique_values)
            Xk_bin[np.isnan(feature)] = -1

        return Xk_bin, Xk_split

    @staticmethod
    def binning_categorical_feature(feature: np.ndarray):
        """Re-number categorical values to 0..k-1; NaN entries become -1."""
        unique_values = sorted(set(feature[~np.isnan(feature)]))
        if len(unique_values) > 1000:
            raise Exception(
                'Features with more than 1000 groups are not supported.')
        mapping_dict = {value: i for i, value in enumerate(unique_values)}
        # map() produces a float array (NaN survives), then NaN -> -1.
        Xk_bin = pd.DataFrame(feature)[0].map(mapping_dict).values
        Xk_split = [min(unique_values) - 0.01] + list(unique_values)
        Xk_bin[np.isnan(feature)] = -1

        return Xk_bin, Xk_split

    def data_binning(self, data: np.ndarray, data_split=None):
        """Bin `data`; generate splits when training, else reuse `data_split`."""
        self.data = data
        self.data_split = data_split

        if self.ctx.algorithm_type == AlgorithmType.Train.name and self.data_split is None:
            self._generate_data_binning()
        else:
            self._reuse_data_binning(data_split)

        return self.data_bin, self.data_split

    def _generate_data_binning(self):
        """Compute bins and split points column-by-column from training data."""
        n = self.data.shape[0]
        d = self.data.shape[1]

        X_bin = np.zeros((d, n), dtype='int16')
        X_split = []

        for idx, feature in enumerate(self.data.T):
            if idx in self.params.my_categorical_idx:
                Xk_bin, Xk_split = FeatureBinning.binning_categorical_feature(
                    feature)
            else:
                Xk_bin, Xk_split = FeatureBinning.binning_continuous_feature(
                    feature, self.params.max_bin)

            X_bin[idx] = Xk_bin
            if isinstance(Xk_split, np.ndarray):
                Xk_split = Xk_split.tolist()
            X_split.append(Xk_split)

        self.data_bin = X_bin.T
        self.data_split = X_split

    def _reuse_data_binning(self, data_split):
        """Bin test data against previously computed split points."""
        self.data_split = data_split

        n = self.data.shape[0]
        d = self.data.shape[1]

        test_X_bin = np.zeros((d, n), dtype='int16')

        for k in range(d):
            # Widen the outer edges so out-of-range test values still land
            # in the first/last bin instead of becoming NaN.
            bin_min = min(self.data[:, k]) - 1
            bin_max = max(self.data[:, k]) + 1
            if np.isnan(bin_min):
                bin_min = min(self.data_split[k])
            if np.isnan(bin_max):
                bin_max = max(self.data_split[k])

            bin_min = min(bin_min, min(self.data_split[k])) - 1
            bin_max = max(bin_max, max(self.data_split[k])) + 1

            if len(self.data_split[k]) > 2:
                bins = np.concatenate(
                    ([bin_min], self.data_split[k][1:-1], [bin_max]), axis=0)
            else:
                bins = np.concatenate(([bin_min], [bin_max]), axis=0)
            test_Xk_bin = pd.cut(self.data[:, k], bins, labels=False)
            test_Xk_bin[np.isnan(test_Xk_bin)] = -1

            test_X_bin[k] = test_Xk_bin

        self.data_bin = test_X_bin.T
data_size = 1000
feature_dim = 100
has_label = True


class TestFeatureBinning(unittest.TestCase):
    """Runs FeatureBinning end-to-end on a simulated training dataset."""

    components = Initializer(log_config_path='', config_path='')
    components.config_data = {'JOB_TEMP_DIR': '/tmp'}
    components.mock_logger = MockLogger()

    @staticmethod
    def _make_args(model_dict):
        # Argument skeleton for the active (label-holding) party.
        return {
            'job_id': 'j-123',
            'task_id': '1',
            'is_label_holder': True,
            'result_receiver_id_list': [],
            'participant_id_list': [],
            'model_predict_algorithm': None,
            'algorithm_type': 'Train',
            'algorithm_subtype': None,
            'model_dict': model_dict,
        }

    def _run_binning(self, model_dict):
        # Build the task context, simulate a dataset, bin the train matrix
        # and check the binned matrix keeps the original shape.
        task_info = SecureLGBMContext(
            self._make_args(model_dict), self.components)
        model_data = SecureDataset.simulate_dataset(
            data_size, feature_dim, has_label)
        secure_dataset = SecureDataset(task_info, model_data)
        print(secure_dataset.train_idx.shape)
        print(secure_dataset.train_X.shape)
        print(secure_dataset.train_y.shape)
        print(secure_dataset.test_idx.shape)
        print(secure_dataset.test_X.shape)
        print(secure_dataset.test_y.shape)

        feat_bin = FeatureBinning(task_info)
        data_bin, data_split = feat_bin.data_binning(
            secure_dataset.train_X, None)

        self.assertEqual(data_bin.shape, secure_dataset.train_X.shape)

    def test_train_feature_binning(self):
        self._run_binning({
            'objective': 'regression',
            'max_bin': 10,
            'n_estimators': 6,
            'max_depth': 3,
            'use_goss': 1
        })

    def test_test_feature_binning(self):
        # Same as above but with an explicit (empty) categorical list.
        self._run_binning({
            'objective': 'regression',
            'categorical_feature': [],
            'max_bin': 10,
            'n_estimators': 6,
            'max_depth': 3,
            'use_goss': 1
        })


if __name__ == "__main__":
    unittest.main()
            # 569 rows: 415 labelled 'INS' (train) and 154 'OOS' (eval),
            # shuffled so group membership is random per id.
            group_set = np.concatenate([['INS'] * 415, ['OOS'] * 154])
            np.random.shuffle(group_set)
            eval_set_df = pd.DataFrame(
                {'id': np.arange(1, 570), 'group': group_set})
            eval_set_df.to_csv(eval_column_file, index=None)

        # Persist both parties' datasets (space-separated) for the
        # customized-split and predict test cases below.
        df_with_y_file = './df_with_y.csv'
        if not os.path.exists(df_with_y_file):
            df_with_y.to_csv(df_with_y_file, index=None, sep=' ')

        df_without_y_file = './df_without_y.csv'
        if not os.path.exists(df_without_y_file):
            df_without_y.to_csv(df_without_y_file, index=None, sep=' ')

        # Random iv-selection flags, one per feature column.
        iv_selected_file = './iv_selected.csv'
        if not os.path.exists(iv_selected_file):
            iv_selected = pd.DataFrame(
                {'feature': [f'x{i + 1}' for i in range(30)],
                 'iv_selected': np.random.binomial(n=1, p=0.5, size=30)})
            iv_selected.to_csv(iv_selected_file, index=None)

        components = Initializer(log_config_path='', config_path='')
        components.config_data = {'JOB_TEMP_DIR': '/tmp', 'MAX_THREAD_WORKERS': 10}
        components.mock_logger = MockLogger()

        def test_random_split_dataset(self):
            """Random 70/30 split must match sklearn's with the same seed."""

            # Active-party (label holder) configuration
            args = {
                'job_id': 'j-123',
                'task_id': '1',
                'is_label_holder': True,
                'result_receiver_id_list': [],
                'participant_id_list': [],
                'model_predict_algorithm': None,
                'algorithm_type': 'Train',
                'algorithm_subtype': None,
                'model_dict': {
                    'random_state': 2024
                }
            }
            task_info = SecureLGBMContext(args, self.components)
            print(task_info.lgbm_params.get_all_params())

            # Simulate the active-party dataset
            dataset_with_y = SecureDataset(task_info, self.df_with_y)
            assert (dataset_with_y.train_idx ==
                    train_test_split(np.array(range(1, 570)), test_size=0.3, random_state=2024)[0]).all()
            self.assertEqual(dataset_with_y.train_X.shape, (398, 15))
            self.assertEqual(dataset_with_y.test_X.shape, (171, 15))
            self.assertEqual(dataset_with_y.train_y.shape, (398,))
            self.assertEqual(dataset_with_y.test_y.shape, (171,))

            # Passive-party configuration
            args = {
                'job_id': 'j-123',
                'task_id': '1',
                'is_label_holder': False,
                'result_receiver_id_list': [],
                'participant_id_list': [],
                'model_predict_algorithm': None,
                'algorithm_type': 'Train',
                'algorithm_subtype': None,
                'model_dict': {
                    'random_state': 2024
                }
            }
            task_info = SecureLGBMContext(args, self.components)
            print(task_info.lgbm_params.get_all_params())

            # Simulate the passive-party dataset: same split indices as the
            # active party, but no label arrays.
            dataset_without_y = SecureDataset(task_info, self.df_without_y)
            assert (dataset_without_y.train_idx == dataset_with_y.train_idx).all()
            self.assertEqual(dataset_without_y.train_X.shape, (398, 15))
            self.assertEqual(dataset_without_y.test_X.shape, (171, 15))
            self.assertEqual(dataset_without_y.train_y, None)
            self.assertEqual(dataset_without_y.test_y, None)

        def test_customized_split_dataset(self):
            """User-provided group column drives the train/eval split."""

            # Active-party configuration
            args = {
                'job_id': 'j-123',
                'task_id': '1',
                'is_label_holder': True,
                'result_receiver_id_list': [],
                'participant_id_list': [],
                'model_predict_algorithm': None,
                'algorithm_type': 'Train',
                'algorithm_subtype': None,
                'model_dict': {
                    'eval_set_column': 'group',
                    'train_set_value': 'INS',
                    'eval_set_value': 'OOS'
                }
            }
            task_info = SecureLGBMContext(args, self.components)
            print(task_info.lgbm_params.get_all_params())

            # Simulate the active-party dataset, read from files this time
            task_info.eval_column_file = self.eval_column_file
            task_info.model_prepare_file = self.df_with_y_file
            eval_set_df = pd.read_csv(task_info.eval_column_file, header=0)

            dataset_with_y = SecureDataset(task_info)
            assert (dataset_with_y.train_idx ==
                    eval_set_df['id'][eval_set_df['group'] == 'INS']).all()
            self.assertEqual(dataset_with_y.train_X.shape, (415, 15))
            self.assertEqual(dataset_with_y.test_X.shape, (154, 15))
            self.assertEqual(dataset_with_y.train_y.shape, (415,))
            self.assertEqual(dataset_with_y.test_y.shape, (154,))

        def test_predict_dataset(self):
            """Predict mode: everything becomes test data, no train arrays."""

            # Active-party configuration
            args = {
                'job_id': 'j-123',
                'task_id': '1',
                'is_label_holder': True,
                'result_receiver_id_list': [],
                'participant_id_list': [],
                'model_predict_algorithm': None,
                'algorithm_type': 'Predict',
                'algorithm_subtype': None,
                'model_dict': {}
            }
            task_info = SecureLGBMContext(args, self.components)
            print(task_info.lgbm_params.get_all_params())

            # Simulate the active-party dataset
            task_info.model_prepare_file = self.df_with_y_file
            dataset_with_y = SecureDataset(task_info)

            self.assertEqual(dataset_with_y.train_X, None)
            self.assertEqual(dataset_with_y.test_X.shape, (569, 15))
            self.assertEqual(dataset_with_y.train_y, None)
            self.assertEqual(dataset_with_y.test_y.shape, (569,))

        def test_iv_selected_dataset(self):
            """iv_selected file should shrink the feature set (15 -> 9 here)."""

            # Active-party configuration
            args = {
                'job_id': 'j-123',
                'task_id': '1',
                'is_label_holder': True,
                'result_receiver_id_list': [],
                'participant_id_list': [],
                'model_predict_algorithm': None,
                'algorithm_type': 'Predict',
                'algorithm_subtype': None,
                'model_dict': {}
            }
            task_info = SecureLGBMContext(args, self.components)
            print(task_info.lgbm_params.get_all_params())

            # Simulate the active-party dataset with iv selection applied
            task_info.model_prepare_file = self.df_with_y_file
            task_info.iv_selected_file = './iv_selected.csv'
            dataset_with_y = SecureDataset(task_info)

            self.assertEqual(dataset_with_y.train_X, None)
            # NOTE(review): 9 surviving columns depends on the random
            # binomial draw in the class-level setup — confirm seeding.
            self.assertEqual(dataset_with_y.test_X.shape, (569, 9))
            self.assertEqual(dataset_with_y.train_y, None)
            self.assertEqual(dataset_with_y.test_y.shape, (569,))

        def test_read_dataset(self):
            """read_dataset returns field names minus 'id' for both modes."""
            np.random.seed(0)
            origin_data = np.random.randint(0, 100, size=(100, 10))
            columns = ['id'] + [f"x{i}" for i in range(2, 11)]
            df = pd.DataFrame(origin_data, columns=columns)
            csv_file = '/tmp/data_x1_to_x10.csv'
            df.to_csv(csv_file, index=False)
            field_list, label, feature = SecureDataset.read_dataset(csv_file, False, delimiter=',')
            self.assertEqual(['id'] + field_list, columns)
            field_list, label, feature = SecureDataset.read_dataset(csv_file, True, delimiter=',')
            self.assertEqual(['id'] + field_list, columns)


if __name__ == "__main__":
    unittest.main()
class FeMessage(Enum):
    """Message keys used by the feature-engineering exchange protocol."""
    ENC_LABELS = "ENC_LABELS"
    AGGR_LABELS = "AGGR_LABELS"
    WOE_FILE = "WOE_FILE"
    IV_SELECTED_FILE = "IV_SELECTED_FILE"


class FeatureEngineeringContext(Context):
    """Per-task context shared by the active and passive FE parties."""

    def __init__(self,
                 args,
                 components: Initializer,
                 role: TaskRole,
                 feature: np.ndarray,
                 feature_name_list: list,
                 label: np.ndarray = None):
        super().__init__(args['job_id'],
                         args['task_id'],
                         components,
                         role)
        # Raw task arguments.
        self.participant_id_list = args['participant_id_list']
        self.result_receiver_id_list = args['result_receiver_id_list']
        self.model_dict = args['model_dict']
        # This party's data slice (label is None for the passive party).
        self.feature = feature
        self.feature_name_list = feature_name_list
        self.label = label
        # Homomorphic-encryption helpers configured from the components.
        self.phe = PheCipherFactory.build_phe(
            components.homo_algorithm, components.public_key_length)
        self.codec = PheCipherFactory.build_codec(components.homo_algorithm)
        self.model_setting = ModelSetting(self.model_dict)
        self._parse_model_dict()

    def _parse_model_dict(self):
        # Copy the frequently used knobs out of the model setting.
        setting = self.model_setting
        self.use_iv = setting.use_iv
        self.iv_thresh = setting.iv_thresh
        self.categorical = setting.categorical
        self.group_num = setting.group_num
class FeatureEngineeringEngine(TaskEngine):
    """Entry point: builds the party context and runs vertical FE."""
    task_type = ModelTask.FEATURE_ENGINEERING

    @staticmethod
    def run(args):
        input_path = BaseContext.feature_engineering_input_path(
            args['job_id'], components.config_data['JOB_TEMP_DIR'])

        is_label_holder = args['is_label_holder']
        # The label holder's file carries a label column; read_dataset's
        # second argument tells it whether to split that column out.
        field_list, label, feature = SecureDataset.read_dataset(
            input_path, is_label_holder)

        if is_label_holder:
            # Active party: first field is the label name, skip it.
            context = FeatureEngineeringContext(
                args=args,
                components=components,
                role=TaskRole.ACTIVE_PARTY,
                feature=feature,
                feature_name_list=field_list[1:],
                label=label
            )
            runner = VerticalFeatureEngineeringActiveParty(context)
        else:
            context = FeatureEngineeringContext(
                args=args,
                components=components,
                role=TaskRole.PASSIVE_PARTY,
                feature=feature,
                feature_name_list=field_list,
                label=None
            )
            runner = VerticalFeatureEngineeringPassiveParty(context)
        runner.fit()
ACTIVE_PARTY = 'ACTIVE_PARTY'

PASSIVE_PARTY = 'PASSIVE_PARTY'


def construct_dataset(num_samples, num_features):
    """Build a reproducible random dataset: binary labels + uniform features."""
    np.random.seed(0)
    # Draw labels first, then features: the draw order is part of the
    # fixture's reproducibility contract.
    labels = np.random.choice([0, 1], size=num_samples)
    features = np.random.rand(num_samples, num_features)
    return labels, features


def mock_args(num_features, iv_thresh):
    """Return (active, passive) argument dicts sharing one model config."""
    job_id = '0x12345678'

    model_config_dict = {
        'use_iv': True,
        'iv_thresh': iv_thresh,
        'categorical': '0',
        'group_num': 100,
    }

    def build(prefix):
        # The two parties differ only in their feature-name prefix.
        return {
            'job_id': job_id,
            'task_id': job_id,
            'feature_name_list': [prefix + str(i) for i in range(num_features)],
            'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'model_dict': model_config_dict,
        }

    return build('a'), build('b')
class TestFeatureEngineering(unittest.TestCase):
    """Wires two in-process stubs together and runs active + passive FE."""

    def setUp(self):
        self._active_rpc_client = RpcClientMock()
        self._passive_rpc_client = RpcClientMock()
        self._thread_event_manager = ThreadEventManager()
        self._active_stub = self._make_stub(
            ACTIVE_PARTY, self._active_rpc_client)
        self._passive_stub = self._make_stub(
            PASSIVE_PARTY, self._passive_rpc_client)
        # Cross-wire: each client's sends land on the peer's stub.
        self._active_rpc_client.set_message_handler(
            self._passive_stub.on_message_received)
        self._passive_rpc_client.set_message_handler(
            self._active_stub.on_message_received)

    def _make_stub(self, agency_id, rpc_client):
        # One ModelStub with the shared event manager and retry policy.
        return ModelStub(
            agency_id=agency_id,
            thread_event_manager=self._thread_event_manager,
            rpc_client=rpc_client,
            send_retry_times=3,
            retry_interval_s=0.1
        )

    @staticmethod
    def _make_components(stub):
        # Initializer configured for in-memory testing (no real storage).
        components = Initializer(log_config_path='', config_path='')
        components.homo_algorithm = 0
        components.stub = stub
        components.config_data = {'JOB_TEMP_DIR': '/tmp'}
        components.mock_logger = MockLogger()
        components.storage_client = MockStorageClient()
        return components

    def test_fit(self):
        """Run both parties' fit() concurrently against a shared dataset."""
        num_samples = 100000
        num_features = 100
        iv_thresh = 0.05
        labels, features = construct_dataset(num_samples, num_features)
        args_a, args_b = mock_args(num_features, iv_thresh)

        active_components = self._make_components(self._active_stub)
        active_context = FeatureEngineeringContext(
            args=args_a,
            components=active_components,
            role=TaskRole.ACTIVE_PARTY,
            feature=features,
            feature_name_list=args_a['feature_name_list'],
            label=labels
        )
        active_vfe = VerticalFeatureEngineeringActiveParty(active_context)

        passive_components = self._make_components(self._passive_stub)
        passive_context = FeatureEngineeringContext(
            args=args_b,
            components=passive_components,
            role=TaskRole.PASSIVE_PARTY,
            feature=features,
            feature_name_list=args_b['feature_name_list'],
            label=None
        )
        passive_vfe = VerticalFeatureEngineeringPassiveParty(passive_context)

        def active_worker():
            # Log instead of raising so a failure can't hang the join().
            try:
                active_vfe.fit()
            except Exception:
                active_components.logger().info(traceback.format_exc())

        def passive_worker():
            try:
                passive_vfe.fit()
            except Exception:
                # Fix: report through the passive party's own logger — the
                # original logged via active_components (copy-paste bug).
                passive_components.logger().info(traceback.format_exc())

        thread_fe_active = threading.Thread(target=active_worker, args=())
        thread_fe_active.start()

        thread_fe_passive = threading.Thread(target=passive_worker, args=())
        thread_fe_passive.start()

        thread_fe_active.join()
        thread_fe_passive.join()


if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')
    unittest.main()
    def __init__(self, ctx: FeatureEngineeringContext):
        super().__init__(ctx)
        # Per-bin WOE/IV details for every feature (own and partners').
        self.woe_iv_df = pd.DataFrame(columns=['feature', 'bins', 'count', 'pos_event', 'pos_event_rate',
                                               'neg_event', 'neg_event_rate', 'woe', 'iv', 'iv_total'])
        # One row per feature: 1 when iv_total passed the threshold, else 0.
        self.iv_selected_df = pd.DataFrame(columns=['feature', 'iv_selected'])

    def fit(self, *args, **kwargs) -> None:
        """Run the active party's WOE/IV computation.

        Own features are computed in plaintext; partners' features are
        computed through an encrypted-label exchange, one partner at a time.
        """
        log = self.ctx.components.logger()
        task_id = self.ctx.task_id
        start_time = time.time()
        if self.ctx.use_iv:
            log.info(f"Start feature engineering, task_id: {task_id}, shape: {self.ctx.feature.shape}, "
                     f"feature_name_list: {self.ctx.feature_name_list}")

            if len(self.ctx.feature_name_list) != 0:
                # Plaintext WOE/IV for this party's own features.
                self._compute_self_woe_iv()

            # Encrypted interaction for each partner's features.
            log.info(f"Start enc labels, task_id: {task_id}")
            enc_start_time = time.time()
            enc_labels = self.ctx.phe.encrypt_batch_parallel(self.ctx.label)
            log.info(
                f"Enc labels finished, task_id: {task_id}, count: {len(enc_labels)}, "
                f"time_costs: {time.time() - enc_start_time}s")

            for i in range(1, len(self.ctx.participant_id_list)):
                self._compute_partner_woe_iv(enc_labels, i)

            # Persist results locally and distribute to result receivers.
            self._save_and_sync_fe_results()

            log.info(
                f"Feature engineering finished, task_id: {task_id}, time_costs: {time.time() - start_time}s, "
                f"iv_selected: {self.iv_selected_df}")

    def _compute_self_woe_iv(self):
        """Compute WOE/IV for this party's own features in plaintext."""
        log = self.ctx.components.logger()
        start_time = time.time()
        for i in range(self.ctx.feature.shape[1]):
            field = self.ctx.feature_name_list[i]
            is_continuous = is_continuous_feature(self.ctx.categorical, field)
            grouped, iv_total = calculate_woe_iv(self.ctx.feature[:, i],
                                                 self.ctx.label,
                                                 self.ctx.group_num,
                                                 is_continuous)
            # Append one row per bin to the accumulated WOE/IV table.
            # NOTE(review): pd.concat per row is O(n^2) over all bins;
            # consider batching if this shows up in profiles.
            for index, row in grouped.iterrows():
                self.woe_iv_df = pd.concat([self.woe_iv_df, pd.DataFrame({
                    'feature': field,
                    'bins': index,
                    'count': [row['count']],
                    'pos_event': [row['pos_event']],
                    'pos_event_rate': [row['pos_event_rate']],
                    'neg_event': [row['neg_event']],
                    'neg_event_rate': [row['neg_event_rate']],
                    'woe': [row['woe']],
                    'iv': [row['iv']],
                    'iv_total': [row['iv_total']]
                })], ignore_index=True)

            self.iv_selected_df.loc[len(self.iv_selected_df)] = {'feature': field,
                                                                 'iv_selected': int(iv_total >= self.ctx.iv_thresh)}
        log.info(
            f"Computing self woe/iv finished, task_id: {self.ctx.task_id}, time_costs: {time.time() - start_time}s")

    def _compute_partner_woe_iv(self, enc_labels, partner_index):
        """Send encrypted labels to one partner, decrypt the aggregates it
        returns, and fold the resulting WOE/IV rows into the tables."""
        log = self.ctx.components.logger()
        start_time = time.time()

        partner_id = self.ctx.participant_id_list[partner_index]
        self._send_enc_labels(enc_labels, partner_id)
        enc_aggr_labels = self._get_all_enc_aggr_labels(partner_id)

        # NOTE: a ProcessPoolExecutor variant existed here and was removed
        # in favor of multiprocessing.Pool.
        pool = multiprocessing.Pool()
        # NOTE(review): the comprehension target shadows the outer
        # `enc_aggr_labels` list; the inner name is the per-feature cipher
        # list from each (field, count_list, ciphers) tuple.
        tasks = [(self.ctx.phe, field, count_list, enc_aggr_labels) for field, count_list, enc_aggr_labels in
                 enc_aggr_labels]
        results = pool.starmap(self._process_one_feature, tasks)
        pool.close()
        pool.join()

        for field, field_woe_iv_df, iv_total in results:
            # Record the partner feature's WOE/IV rows and selection flag.
            self.woe_iv_df = pd.concat(
                [self.woe_iv_df, field_woe_iv_df], ignore_index=True)
            self.iv_selected_df.loc[len(self.iv_selected_df)] = {'feature': field,
                                                                 'iv_selected': int(iv_total >= self.ctx.iv_thresh)}
        log.info(
            f"Computing {partner_id}'s woe/iv finished, task_id: {self.ctx.task_id}, "
            f"time_costs: {time.time() - start_time}s")

    @staticmethod
    def _process_one_feature(phe, field, count_list, enc_aggr_labels):
        """Decrypt one feature's aggregated labels and compute its WOE/IV.

        Runs in a worker process; must stay a staticmethod so its arguments
        pickle cleanly.
        """
        pos_event = phe.decrypt_batch(enc_aggr_labels)
        field_woe_iv_df = pd.DataFrame({'bins': range(len(count_list)), 'count': count_list,
                                        'pos_event': pos_event, 'feature': field})
        field_woe_iv_df, iv_total = calculate_woe_iv_with_pos_event(
            field_woe_iv_df)
        return field, field_woe_iv_df, iv_total

    def _get_all_enc_aggr_labels(self, partner_id):
        """Pull and decode the partner's per-feature aggregated ciphertexts.

        Returns a list of (field, count_list, cipher_list) tuples.
        """
        log = self.ctx.components.logger()
        start_time = time.time()
        data = self.ctx.components.stub.pull(PullRequest(
            sender=partner_id,
            task_id=self.ctx.task_id,
            key=FeMessage.AGGR_LABELS.value
        ))

        enc_aggr_labels_list_pb = EncAggrLabelsList()
        utils.bytes_to_pb(enc_aggr_labels_list_pb, data)
        public_key = self.ctx.codec.decode_enc_key(enc_aggr_labels_list_pb.public_key)

        res = []
        for enc_aggr_labels_pb in enc_aggr_labels_list_pb.enc_aggr_labels_list:
            enc_aggr_labels = [self.ctx.codec.decode_cipher(public_key,
                                                            cipher.ciphertext,
                                                            cipher.exponent
                                                            ) for cipher in enc_aggr_labels_pb.cipher_list]
            field = enc_aggr_labels_pb.field
            res.append(
                (field, list(enc_aggr_labels_pb.count_list), enc_aggr_labels))

        log.info(
            f"All enc aggr labels received, task_id: {self.ctx.task_id}, feature_num: {len(res)}, "
            f"size: {len(data) / 1024}KB, time_costs: {time.time() - start_time}s")
        return res

    def _send_enc_labels(self, enc_labels, receiver):
        """Push this party's encrypted label vector to one partner."""
        log = self.ctx.components.logger()
        start_time = time.time()

        data = PheMessage.packing_data(self.ctx.codec, self.ctx.phe.public_key, enc_labels)
        self.ctx.components.stub.push(PushRequest(
            receiver=receiver,
            task_id=self.ctx.task_id,
            key=FeMessage.ENC_LABELS.value,
            data=data
        ))
        log.info(
            f"Sending enc labels to {receiver} finished, task_id: {self.ctx.task_id}, label_num: {len(enc_labels)}, "
            f"size: {len(data) / 1024}KB, time_costs: {time.time() - start_time}s")

    def _save_and_sync_fe_results(self):
        """Write WOE/IV and selection CSVs, upload the WOE file to storage,
        then push both files to every partner in the receiver list."""
        log = self.ctx.components.logger()
        task_id = self.ctx.task_id
        self.woe_iv_df.to_csv(self.ctx.woe_iv_file, sep=',', header=True, index=None)
        self.iv_selected_df.to_csv(self.ctx.iv_selected_file, sep=',', header=True, index=None)
        self.ctx.components.storage_client.upload_file(self.ctx.woe_iv_file,
                                                       self.ctx.job_id + os.sep + self.ctx.WOE_IV_FILE)
        log.info(f"Saving fe results finished, task_id: {task_id}")

        with open(self.ctx.woe_iv_file, 'rb') as f:
            woe_iv = f.read()
        with open(self.ctx.iv_selected_file, 'rb') as f:
            iv_selected = f.read()
        for i in range(1, len(self.ctx.participant_id_list)):
            partner_id = self.ctx.participant_id_list[i]
            if partner_id in self.ctx.result_receiver_id_list:
                self.ctx.components.stub.push(PushRequest(
                    receiver=partner_id,
                    task_id=self.ctx.task_id,
                    key=FeMessage.WOE_FILE.value,
                    data=woe_iv
                ))
                self.ctx.components.stub.push(PushRequest(
                    receiver=partner_id,
                    task_id=self.ctx.task_id,
                    key=FeMessage.IV_SELECTED_FILE.value,
                    data=iv_selected
                ))
        log.info(f"Sending fe results finished, task_id: {task_id}")
    def __init__(self, ctx: FeatureEngineeringContext):
        super().__init__(ctx)

    def fit(self, *args, **kwargs) -> None:
        """Run the passive party's side of the WOE/IV protocol.

        Receives encrypted labels, aggregates them per feature bin under
        homomorphic addition, returns the aggregates, then stores the
        results the active party sends back.
        """
        log = self.ctx.components.logger()

        task_id = self.ctx.task_id
        start_time = time.time()
        if self.ctx.use_iv:
            log.info(f"Start feature engineering, task_id: {task_id}, shape: {self.ctx.feature.shape}, "
                     f"feature_name_list: {self.ctx.feature_name_list}.")

            # Receive the active party's encrypted label vector.
            public_key, enc_labels = self._get_enc_labels()

            # Bin each feature and aggregate the encrypted labels per bin.
            aggr_labels_bytes_list = self._binning_and_aggregating_all(public_key, enc_labels)

            # Return the aggregated ciphertexts to the active party.
            self._send_all_enc_aggr_labels(public_key, aggr_labels_bytes_list)

            self._get_and_save_result()
            log.info(
                f"Feature engineering finished, task_id: {task_id}, timecost: {time.time() - start_time}s")

    def _get_enc_labels(self):
        """Pull and decode the active party's encrypted labels."""
        log = self.ctx.components.logger()
        start_time = time.time()
        active_party = self.ctx.participant_id_list[0]
        data = self.ctx.components.stub.pull(PullRequest(
            sender=active_party,
            task_id=self.ctx.task_id,
            key=FeMessage.ENC_LABELS.value
        ))

        public_key, enc_labels = PheMessage.unpacking_data(self.ctx.codec, data)
        log.info(f"All enc labels received, task_id: {self.ctx.task_id}, label_num: {len(enc_labels)}, "
                 f"size: {len(data) / 1024}KB, timecost: {time.time() - start_time}s")
        return public_key, enc_labels

    def _binning_and_aggregating_all(self, public_key, enc_labels) -> list:
        """Bin every feature and aggregate ciphertexts, one worker per feature.

        Returns a list of serialized EncAggrLabels messages (bytes).
        """
        log = self.ctx.components.logger()
        start_time = time.time()
        params = []
        for i in range(self.ctx.feature.shape[1]):
            is_continuous = is_continuous_feature(self.ctx.categorical, self.ctx.feature_name_list[i])
            params.append({
                'is_continuous': is_continuous,
                'feature_index': i,
                'field': self.ctx.feature_name_list[i],
                'feature': self.ctx.feature[:, i],
                'public_key': public_key,
                'enc_labels': enc_labels,
                'group_num': self.ctx.group_num,
                'codec': self.ctx.codec
            })
        # NOTE: a ProcessPoolExecutor variant existed here and was removed
        # in favor of multiprocessing.Pool.
        pool = multiprocessing.Pool()
        aggr_labels_str = pool.map(self._binning_and_aggregating_one, params)
        pool.close()
        pool.join()

        log.info(f"Feature binning and aggregating finished, task_id: {self.ctx.task_id}, "
                 f"feature_num: {len(params)}, timecost: {time.time() - start_time}s")
        return aggr_labels_str

    @staticmethod
    def _binning_and_aggregating_one(param):
        """Bin one feature, then sum encrypted labels per bin.

        Runs in a worker process (arguments must pickle). Returns the
        serialized EncAggrLabels bytes for this feature.
        """
        feature = param['feature']
        if param['is_continuous']:
            bins = FeatureBinning.binning_continuous_feature(feature, param['group_num'])[0]
        else:
            bins = FeatureBinning.binning_categorical_feature(feature)[0]

        enc_labels = param['enc_labels']
        data_dict = {}
        for key, value in zip(bins, enc_labels):
            if key in data_dict:
                data_dict[key]['count'] += 1
                # Homomorphic addition of the encrypted label.
                data_dict[key]['sum'] = data_dict[key]['sum'] + value
            else:
                data_dict[key] = {'count': 1, 'sum': value}

        # Emit counts and cipher sums in ascending bin order so the active
        # party can line them up with bin indices.
        count_list = [data_dict[key]['count'] for key in sorted(data_dict.keys())]
        aggr_enc_labels = [data_dict[key]['sum'] for key in sorted(data_dict.keys())]

        return VerticalFeatureEngineeringPassiveParty._encode_enc_aggr_labels(
            param['codec'], param['field'], count_list, aggr_enc_labels)

    @staticmethod
    def _encode_enc_aggr_labels(codec, field, count_list, aggr_enc_labels):
        """Serialize one feature's (field, counts, cipher sums) to protobuf bytes."""
        enc_aggr_labels_pb = EncAggrLabels()
        enc_aggr_labels_pb.field = field
        for count in count_list:
            enc_aggr_labels_pb.count_list.append(count)
        for cipher in aggr_enc_labels:
            model_cipher = ModelCipher()
            # be_secure=False: obfuscation is unnecessary for aggregates
            # that only the key holder can decrypt.
            model_cipher.ciphertext, model_cipher.exponent = \
                codec.encode_cipher(cipher, be_secure=False)
            enc_aggr_labels_pb.cipher_list.append(model_cipher)
        return utils.pb_to_bytes(enc_aggr_labels_pb)

    def _send_all_enc_aggr_labels(self, public_key, aggr_labels_bytes_list):
        """Bundle all per-feature aggregates into one message and push it
        to the active party."""
        start_time = time.time()
        enc_aggr_labels_list_pb = EncAggrLabelsList()
        enc_aggr_labels_list_pb.public_key = self.ctx.codec.encode_enc_key(public_key)

        for aggr_labels_bytes in aggr_labels_bytes_list:
            enc_aggr_labels_pb = EncAggrLabels()
            utils.bytes_to_pb(enc_aggr_labels_pb, aggr_labels_bytes)
            enc_aggr_labels_list_pb.enc_aggr_labels_list.append(enc_aggr_labels_pb)

        data = utils.pb_to_bytes(enc_aggr_labels_list_pb)

        self.ctx.components.logger().info(
            f"Encoding all enc aggr labels finished, task_id: {self.ctx.task_id}, "
            f"size: {len(data) / 1024}KB, timecost: {time.time() - start_time}s")

        self.ctx.components.stub.push(PushRequest(
            receiver=self.ctx.participant_id_list[0],
            task_id=self.ctx.task_id,
            key=FeMessage.AGGR_LABELS.value,
            data=data
        ))

        self.ctx.components.logger().info(
            f"Sending all enc aggr labels finished, task_id: {self.ctx.task_id}, "
            f"feature_num: {len(aggr_labels_bytes_list)}, "
            f"size: {len(data) / 1024}KB, timecost: {time.time() - start_time}s")

    def _get_and_save_result(self):
        """Receive and persist the WOE/IV and iv-selection files from the
        active party, when this party is among the result receivers."""
        active_party = self.ctx.participant_id_list[0]
        if self.ctx.components.stub.agency_id in self.ctx.result_receiver_id_list:
            # WOE/IV details from the label holder.
            data = self.ctx.components.stub.pull(PullRequest(
                sender=active_party,
                task_id=self.ctx.task_id,
                key=FeMessage.WOE_FILE.value
            ))
            self.ctx.components.logger().info(
                f"Result of woe/iv received, task_id: {self.ctx.task_id}, size: {len(data) / 1024}KB")
            with open(self.ctx.woe_iv_file, 'wb') as f:
                f.write(data)
            self.ctx.components.storage_client.upload_file(self.ctx.woe_iv_file,
                                                           self.ctx.job_id + os.sep + self.ctx.WOE_IV_FILE)

            # iv-based selection flags from the label holder.
            data = self.ctx.components.stub.pull(PullRequest(
                sender=active_party,
                task_id=self.ctx.task_id,
                key=FeMessage.IV_SELECTED_FILE.value))
            self.ctx.components.logger().info(
                f"Result of iv_select received, task_id: {self.ctx.task_id}, size: {len(data) / 1024}KB")
            with open(self.ctx.iv_selected_file, 'wb') as f:
                f.write(data)
b/python/ppc_model/feature_engineering/vertical/utils.py
@@ -0,0 +1,51 @@
+import numpy as np
+import pandas as pd
+
+from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning
+
+
+def is_continuous_feature(categorical: str, field):
+    if categorical == '0':
+        return True
+    return field not in categorical.split(',')
+
+
+def calculate_woe_iv_with_pos_event(grouped):
+    grouped['neg_event'] = grouped['count'] - grouped['pos_event']
+
+    # 避免出现无穷大woe
+    grouped['pos_event'] = grouped['pos_event'].astype(np.float64)
+    grouped['neg_event'] = grouped['neg_event'].astype(np.float64)
+    grouped.loc[grouped['neg_event'] == 0, 'pos_event'] += 0.5
+    grouped.loc[grouped['neg_event'] == 0, 'neg_event'] = 0.5
+    grouped.loc[grouped['pos_event'] == 0, 'neg_event'] += 0.5
+    grouped.loc[grouped['pos_event'] == 0, 'pos_event'] = 0.5
+
+    # 计算WOE和IV值
+    grouped['pos_event_rate'] = grouped['pos_event'] / \
+        (grouped['pos_event'].sum())
+    grouped['neg_event_rate'] = grouped['neg_event'] / \
+        (grouped['neg_event'].sum())
+    grouped['woe'] = np.log(grouped['pos_event_rate'] /
+                            (grouped['neg_event_rate']))
+    grouped['iv'] = (grouped['pos_event_rate'] -
+                     grouped['neg_event_rate']) * grouped['woe']
+    iv_total = grouped['iv'].sum()
+    grouped['iv_total'] = iv_total
+    return grouped, iv_total
+
+
+def calculate_woe_iv(feature: np.ndarray, label: np.ndarray, num_bins: int = 10, is_continuous: bool = True,
+                     is_equal_freq: bool = True):
+    # 将特征和目标变量合并
+    combined = pd.DataFrame({'feature': feature, 'label': label})
+    # 按特征值对数据集进行分箱
+    if is_continuous:
+        combined['bins'] = FeatureBinning.binning_continuous_feature(feature, num_bins, is_equal_freq)[0]
+    else:
+        combined['bins'] = FeatureBinning.binning_categorical_feature(feature)[0]
+    # 计算每个分箱中的正负样本数量和总体样本数量
+    grouped = combined.groupby('bins')['label'].agg(['count', 'sum'])
+    grouped = grouped.rename(columns={'sum': 'pos_event'})
+
+    return calculate_woe_iv_with_pos_event(grouped)
diff --git a/python/ppc_model/interface/__init__.py b/python/ppc_model/interface/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_model/interface/model_base.py b/python/ppc_model/interface/model_base.py
new file mode 100644
index 00000000..050980ae
--- /dev/null
+++ b/python/ppc_model/interface/model_base.py
@@ -0,0 +1,33 @@
+from abc import ABC
+
+from pandas import DataFrame
+
+
+class ModelBase(ABC):
+    mode: str
+
+    def __init__(self, ctx):
+        self.ctx = ctx
+
+    def fit(
+        self,
+        *args,
+        **kwargs
+    ) -> None:
+        pass
+
+    def transform(self, transform_data: DataFrame) -> DataFrame:
+        pass
+
+    def predict(self, predict_data: DataFrame) -> DataFrame:
+        pass
+
+    def save_model(self, file_path):
+        pass
+
+    def load_model(self, file_path):
+        pass
+
+
+class VerticalModel(ModelBase):
+    mode = "VERTICAL"
diff --git a/python/ppc_model/interface/rpc_client.py b/python/ppc_model/interface/rpc_client.py
new file mode 100644
index 00000000..9473be6d
--- /dev/null
+++ b/python/ppc_model/interface/rpc_client.py
@@ -0,0 +1,8 @@
+from abc import ABC
+
+ +class RpcClient(ABC): + rpc_type: str + + def send(self, request): + ... diff --git a/python/ppc_model/interface/task_engine.py b/python/ppc_model/interface/task_engine.py new file mode 100644 index 00000000..123b0eb9 --- /dev/null +++ b/python/ppc_model/interface/task_engine.py @@ -0,0 +1,9 @@ +from abc import ABC + + +class TaskEngine(ABC): + task_type: str + + @staticmethod + def run(args: dict): + ... diff --git a/python/ppc_model/metrics/__init__.py b/python/ppc_model/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/metrics/evaluation.py b/python/ppc_model/metrics/evaluation.py new file mode 100644 index 00000000..a759dc9d --- /dev/null +++ b/python/ppc_model/metrics/evaluation.py @@ -0,0 +1,276 @@ +import gc +import time +import random +import traceback +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from typing import Dict, Union, Tuple +from sklearn import metrics +from sklearn.metrics import accuracy_score +from sklearn.metrics import precision_recall_curve +from sklearn.metrics import roc_curve, auc + +from ppc_model.common.context import Context +from ppc_model.common.global_context import plot_lock +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.common.model_result import ResultFileHandling +from ppc_model.secure_lgbm.monitor.feature.feature_evaluation_info import EvaluationType +from ppc_model.secure_lgbm.monitor.feature.feature_evaluation_info import FeatureEvaluationResult + + +_Score = Union[float, Tuple[float, float]] + + +class Evaluation: + + def __init__(self, + ctx: Context, + dataset: SecureDataset, + train_praba:np.ndarray = None, + test_praba:np.ndarray = None) -> None: + + self.job_id = ctx.job_id + self.storage_client = ctx.components.storage_client + self.summary_evaluation_file = ctx.summary_evaluation_file + self.remote_summary_evaluation_file = ctx.remote_summary_evaluation_file + self.model_result_file = ctx.test_model_result_file + 
self.model_output_file = ctx.test_model_output_file + self.remote_model_output_file = ctx.remote_test_model_output_file + + self.metric_roc_file = ctx.test_metric_roc_file + self.metric_ks_file = ctx.test_metric_ks_file + self.metric_pr_file = ctx.test_metric_pr_file + self.metric_acc_file = ctx.test_metric_acc_file + self.metric_ks_table = ctx.test_metric_ks_table + self.remote_metric_roc_file = ctx.remote_test_metric_roc_file + self.remote_metric_ks_file = ctx.remote_test_metric_ks_file + self.remote_metric_pr_file = ctx.remote_test_metric_pr_file + self.remote_metric_acc_file = ctx.remote_test_metric_acc_file + self.remote_metric_ks_table = ctx.remote_test_metric_ks_table + + # if test_praba is None or dataset.test_y is None: + # raise Exception('test_praba or test_y is None') + + if test_praba is not None: + test_ks, test_auc = self.evaluation_file( + ctx, dataset.test_idx, dataset.test_y, test_praba, 'test') + if train_praba is not None: + train_ks, train_auc = self.evaluation_file( + ctx, dataset.train_idx, dataset.train_y, train_praba, 'train') + if dataset.train_y is not None: + self.summary_evaluation(dataset, test_ks, test_auc, train_ks, train_auc) + + @staticmethod + def fevaluation( + y_true: np.ndarray, + y_pred: np.ndarray, + decimal_num: int = 4 + ) -> Dict[str, _Score]: + auc = metrics.roc_auc_score(y_true, y_pred) + + y_pred_label = [0 if p <= 0.5 else 1 for p in y_pred] + acc = metrics.accuracy_score(y_true, y_pred_label) + recall = metrics.recall_score(y_true, y_pred_label) + precision = metrics.precision_score(y_true, y_pred_label) + + scores_dict = { + 'auc': auc, + 'acc': acc, + 'recall': recall, + 'precision': precision + } + for metric_name in scores_dict: + scores_dict[metric_name] = round( + scores_dict[metric_name], decimal_num) + return scores_dict + + def summary_evaluation(self, dataset, test_ks, test_auc, train_ks, train_auc): + train_evaluation = FeatureEvaluationResult( + type=EvaluationType.TRAIN, ks_value=train_ks, 
auc_value=train_auc, label_list=dataset.train_y) + test_evaluation = FeatureEvaluationResult( + type=EvaluationType.VALIDATION, ks_value=test_ks, auc_value=test_auc, label_list=dataset.test_y) + FeatureEvaluationResult.store_and_upload_summary( + [train_evaluation, test_evaluation], + self.summary_evaluation_file, self.remote_summary_evaluation_file, + self.storage_client) + + @staticmethod + def calculate_ks_and_stats(predicted_proba, actual_label, num_buckets=10): + # 合并预测概率和实际标签为一个 DataFrame + df = pd.DataFrame({'predicted_proba': predicted_proba.reshape(-1), 'actual_label': actual_label.reshape(-1)}) + # 根据预测概率降序排列 + df_sorted = df.sort_values(by='predicted_proba', ascending=False) + # 将数据划分为 num_buckets 个分组 + try: + df_sorted['bucket'] = pd.qcut(df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0] + except Exception: + df_sorted['bucket'] = pd.cut(df_sorted['predicted_proba'], num_buckets, retbins=True, labels=False)[0] + # 统计每个分组的信息 + stats = df_sorted.groupby('bucket').agg({ + 'actual_label': ['count', 'sum'], + 'predicted_proba': ['min', 'max'] + }) + # 计算其他指标 + stats.columns = ['count', 'positive_count', 'predict_proba_min', 'predict_proba_max'] + stats['positive_ratio'] = stats['positive_count'] / stats['count'] + stats['negative_ratio'] = 1 - stats['positive_ratio'] + stats['count_ratio'] = stats['count'] / stats['count'].sum() + # stats['累计坏客户占比'] = stats['坏客户数'].cumsum() / stats['坏客户数'].sum() + # 计算累计坏客户占比,从第 9 组开始计算 + stats['cum_positive_ratio'] = stats['positive_count'].iloc[::-1].cumsum()[::-1] / stats['positive_count'].sum() + stats = stats[['count_ratio', 'count', 'positive_count', + 'positive_ratio', 'negative_ratio', 'cum_positive_ratio']].reset_index() + stats.columns = ['分组', '样本占比', '样本数', '正样本数', '正样本比例', '负样本比例', '累积正样本占比'] + return stats + + def evaluation_file(self, ctx, data_index: np.ndarray, + y_true: np.ndarray, y_praba: np.ndarray, label: str = 'test'): + if label == 'train': + self.model_result_file = 
ctx.train_model_result_file + self.model_output_file = ctx.train_model_output_file + self.remote_model_output_file = ctx.remote_train_model_output_file + + self.metric_roc_file = ctx.train_metric_roc_file + self.metric_ks_file = ctx.train_metric_ks_file + self.metric_pr_file = ctx.train_metric_pr_file + self.metric_acc_file = ctx.train_metric_acc_file + self.metric_ks_table = ctx.train_metric_ks_table + self.remote_metric_roc_file = ctx.remote_train_metric_roc_file + self.remote_metric_ks_file = ctx.remote_train_metric_ks_file + self.remote_metric_pr_file = ctx.remote_train_metric_pr_file + self.remote_metric_acc_file = ctx.remote_train_metric_acc_file + self.remote_metric_ks_table = ctx.remote_train_metric_ks_table + + if y_true is not None: + # metrics plot + max_retry = 3 + retry_num = 0 + while retry_num < max_retry: + retry_num += 1 + try: + with plot_lock: + ks_value, auc_value = Evaluation.plot_two_class_graph(self, y_true, y_praba) + except: + ctx.components.logger().info(f'y_true = {len(y_true)}, {y_true[0:2]}') + ctx.components.logger().info(f'y_praba = {len(y_praba)}, {y_praba[0:2]}') + err = traceback.format_exc() + # ctx.components.logger().exception(err) + ctx.components.logger().info( + f'plot metrics in times-{retry_num} failed, traceback: {err}.') + time.sleep(random.uniform(0.1, 3)) + + ResultFileHandling._upload_file(self.storage_client, self.metric_roc_file, self.remote_metric_roc_file) + ResultFileHandling._upload_file(self.storage_client, self.metric_ks_file, self.remote_metric_ks_file) + ResultFileHandling._upload_file(self.storage_client, self.metric_pr_file, self.remote_metric_pr_file) + ResultFileHandling._upload_file(self.storage_client, self.metric_acc_file, self.remote_metric_acc_file) + + # ks table + ks_table = self.calculate_ks_and_stats(y_praba, y_true) + ks_table.to_csv(self.metric_ks_table, header=True, index=None) + ResultFileHandling._upload_file(self.storage_client, self.metric_ks_table, self.remote_metric_ks_table) + else: + 
ks_value = auc_value = None + + # predict result + self._parse_model_result(data_index, y_true, y_praba) + ResultFileHandling._upload_file(self.storage_client, self.model_output_file, self.remote_model_output_file) + + return ks_value, auc_value + + def _parse_model_result(self, data_index, y_true=None, y_praba=None): + + np.savetxt(self.model_result_file, y_praba, delimiter=',', fmt='%f') + + if y_true is None: + df = pd.DataFrame(np.column_stack((data_index, y_praba)), columns=['id', 'class_pred']) + else: + df = pd.DataFrame(np.column_stack((data_index, y_true, y_praba)), + columns=['id', 'class_label', 'class_pred']) + df['class_label'] = df['class_label'].astype(int) + + df['id'] = df['id'].astype(int) + df['class_pred'] = df['class_pred'].astype(float) + df = df.sort_values(by='id') + df.to_csv(self.model_output_file, index=None) + + def plot_two_class_graph(self, y_true, y_scores): + + y_label_probs = y_true + y_pred_probs = y_scores + # plt.cla() + plt.rcParams['figure.figsize'] = (12.0, 8.0) + + # plot ROC + fpr, tpr, thresholds = roc_curve(y_label_probs, y_pred_probs, pos_label=1) + auc_value = auc(fpr, tpr) + plt.figure(f'roc-{self.job_id}') + plt.title('ROC Curve') # give plot a title + plt.xlabel('False Positive Rate (1 - Specificity)') + plt.ylabel('True Positive Rate (Sensitivity)') + plt.plot([0, 1], [0, 1], 'k--', lw=2) + plt.plot(fpr, tpr, label='area = {0:0.5f}' + ''.format(auc_value)) + plt.legend(loc="lower right") + plt.savefig(self.metric_roc_file, dpi=1000) + # plt.show() + + plt.close('all') + gc.collect() + + # plot KS + plt.figure(f'ks-{self.job_id}') + threshold_x = np.sort(thresholds) + threshold_x[-1] = 1 + ks_value = max(abs(fpr - tpr)) + plt.title('KS Curve') + plt.xlabel('Threshold') + plt.plot(threshold_x, tpr, label='True Positive Rate') + plt.plot(threshold_x, fpr, label='False Positive Rate') + # 标记最大ks值 + x_index = np.argwhere(abs(fpr - tpr) == ks_value)[0, 0] + plt.plot((threshold_x[x_index], threshold_x[x_index]), 
(fpr[x_index], tpr[x_index]), + label='ks = {:.3f}'.format(ks_value), color='r', marker='o', markerfacecolor='r', markersize=5) + plt.legend(loc="lower right") + plt.savefig(self.metric_ks_file, dpi=1000) + # plt.show() + + plt.close('all') + gc.collect() + + # plot Precision Recall + plt.figure(f'pr-{self.job_id}') + plt.title('Precision/Recall Curve') + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.xlim(0.0, 1.0) + plt.ylim(0.0, 1.05) + precision, recall, thresholds = precision_recall_curve( + y_label_probs, y_pred_probs) + plt.plot(recall, precision) + plt.savefig(self.metric_pr_file, dpi=1000) + # plt.show() + + plt.close('all') + gc.collect() + + # plot accuracy + plt.figure(f'accuracy-{self.job_id}') + thresholds = np.linspace(0, 1, num=100) # 在0~1之间生成100个阈值 + accuracies = [] + for threshold in thresholds: + predicted_labels = (y_pred_probs >= threshold).astype(int) + accuracy = accuracy_score(y_label_probs, predicted_labels) + accuracies.append(accuracy) + plt.title('Accuracy Curve') + plt.xlabel('Threshold') + plt.ylabel('Accuracy') + plt.xlim(0.0, 1.0) + plt.ylim(0.0, 1.05) + plt.plot(thresholds, accuracies) + plt.savefig(self.metric_acc_file, dpi=1000) + # plt.show() + + plt.close('all') + gc.collect() + return (ks_value, auc_value) diff --git a/python/ppc_model/metrics/loss.py b/python/ppc_model/metrics/loss.py new file mode 100644 index 00000000..b2b4e6b3 --- /dev/null +++ b/python/ppc_model/metrics/loss.py @@ -0,0 +1,32 @@ +import numpy as np + + +class Loss: + pass + + +class BinaryLoss(Loss): + + def __init__(self, objective: str) -> None: + super().__init__() + self.objective = objective + + @staticmethod + def sigmoid(x: np.ndarray): + return 1 / (1 + np.exp(-x)) + + @staticmethod + def compute_gradient(y_true: np.ndarray, y_pred: np.ndarray): + return y_pred - y_true + + @staticmethod + def compute_hessian(y_pred: np.ndarray): + return y_pred * (1 - y_pred) + + @staticmethod + def compute_loss(y_true: np.ndarray, y_pred: np.ndarray): + 
'''binary_cross_entropy''' + # 避免log(0)错误 + epsilon = 1e-15 + y_pred = np.clip(y_pred, epsilon, 1 - epsilon) + return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)) diff --git a/python/ppc_model/metrics/model_plot.py b/python/ppc_model/metrics/model_plot.py new file mode 100644 index 00000000..ca169859 --- /dev/null +++ b/python/ppc_model/metrics/model_plot.py @@ -0,0 +1,181 @@ +import gc +import time +import random +import traceback +import matplotlib.pyplot as plt +import networkx as nx +from networkx.drawing.nx_pydot import graphviz_layout + +from ppc_model.common.model_result import ResultFileHandling +from ppc_model.common.global_context import plot_lock +from ppc_model.secure_lgbm.vertical.booster import VerticalBooster + + +class ModelPlot: + + def __init__(self, model: VerticalBooster) -> None: + + self.ctx = model.ctx + self.model = model + self._tree_id = 0 + self._leaf_id = None + self._G = None + self._split = None + self.storage_client = self.ctx.components.storage_client + + if model._trees is not None and \ + self.ctx.components.config_data['AGENCY_ID'] in self.ctx.result_receiver_id_list: + self.plot_tree() + + def plot_tree(self): + trees = self.model._trees + self._split = self.model._X_split + + for i, tree in enumerate(trees): + if i < 6: + tree_file_path = self.ctx.model_tree_prefix + '_' + str(self._tree_id)+'.svg' + remote_tree_file_path = self.ctx.remote_model_tree_prefix + '_' + str(self._tree_id)+'.svg' + self._tree_id += 1 + self._leaf_id = 0 + self._G = DiGraphTree() + if not isinstance(tree, list) or tree == 0: + continue + else: + self._graph_gtree(tree) + + max_retry = 3 + retry_num = 0 + while retry_num < max_retry: + retry_num += 1 + try: + with plot_lock: + self._G.tree_plot(figsize=(10, 5), save_filename=tree_file_path) + except: + self.ctx.components.logger().info(f'tree_id = {i}, tree = {tree}') + self.ctx.components.logger().info(f'G = {self._G}') + err = traceback.format_exc() + # 
self.ctx.components.logger().exception(err) + self.ctx.components.logger().info( + f'plot tree-{i} in times-{retry_num} failed, traceback: {err}.') + time.sleep(random.uniform(0.1, 3)) + + ResultFileHandling._upload_file(self.storage_client, tree_file_path, remote_tree_file_path) + + def _graph_gtree(self, tree, leaf_id=0, depth=0, orient=None, split_info=None): + self._leaf_id += 1 + self._G.add_node(self._leaf_id) + if split_info is not None: + if self.ctx.participant_id_list[split_info.agency_idx] == self.ctx.components.config_data['AGENCY_ID']: + feature = str(self.model.dataset.feature_name[split_info.agency_feature]) + value = str(round(float(self._split[split_info.agency_feature][split_info.value]), 4)) + else: + feature = str(split_info.feature) + value = str(split_info.value) + else: + feature = value = '' + + if isinstance(tree, list): + best_split_info, left_tree, right_tree = tree[0] + if leaf_id != 0: + if orient == 'left': + self._G.add_weighted_edges_from( + [(leaf_id, self._leaf_id, orient+'_'+feature+'_'+value+'_'+str(split_info.w_left))]) + elif orient == 'right': + self._G.add_weighted_edges_from( + [(leaf_id, self._leaf_id, orient+'_'+feature+'_'+value+'_'+str(split_info.w_right))]) + my_leaf_id = self._leaf_id + self._graph_gtree(left_tree, my_leaf_id, depth+1, 'left', best_split_info) + self._graph_gtree(right_tree, my_leaf_id, depth+1, 'right', best_split_info) + else: + if leaf_id != 0: + self._G.add_weighted_edges_from( + [(leaf_id, self._leaf_id, orient+'_'+feature+'_'+value+'_'+str(tree))]) + + +class DiGraphTree(nx.DiGraph): + + def __init__(self): + + super().__init__() + + def tree_leaves(self): + leaves_list = [x for x in self.nodes() if self.out_degree(x)==0 and self.in_degree(x)<=1] + return leaves_list + + def tree_dfs_nodes(self): + nodes_list = list(nx.dfs_preorder_nodes(self)) + return nodes_list + + def tree_dfs_leaves(self): + dfs_leaves = [x for x in self.tree_dfs_nodes() if x in self.tree_leaves()] + return dfs_leaves + + 
def tree_depth(self): + max_depth = max(nx.shortest_path_length(self, 0).values()) + return max_depth + + def tree_shortest_path(self, node0, node1): + path_length = nx.shortest_path_length(self, node0, node1) + return path_length + + def tree_plot(self, split=True, figsize=(20, 10), dpi=300, save_filename=None): + # plt.cla() + pos = graphviz_layout(self, prog='dot') + # pos = nx.nx_agraph.graphviz_layout(self, prog='dot') + edge_labels = nx.get_edge_attributes(self, 'weight') + + if split: + labels = {} + # leaves = self.tree_leaves() + leaves = [x for x in self.nodes() if self.out_degree(x)==0 and self.in_degree(x)<=1] + + if leaves == [0]: + leaves = [] + self.remove_node(0) + + for n in self.nodes(): + + if n in leaves: + # in_node = list(nx.all_neighbors(self, n))[0] + in_node = list(self.predecessors(n))[0] + weight = edge_labels[(in_node, n)] + try: + labels[n] = round(float(str(weight).split('_')[3]), 4) + except: + labels[n] = str(weight).split('_')[3] + else: + in_node = list(nx.neighbors(self, n))[0] + weight = edge_labels[(n, in_node)] + labels[n] = weight.split('_')[1] + ':' + weight.split('_')[2] + + # for key, value in edge_labels.items(): + # edge_labels[key] = round(float(value.split('_')[3]), 4) + + plt.figure(figsize=figsize, dpi=dpi) + nx.draw(self, pos, + node_size=1000, node_color='#72BFC5', node_shape='o', alpha=None, + with_labels=True, labels=labels, font_weight='normal', font_color='black') + # nx.draw_networkx_edge_labels(self, pos, edge_labels=edge_labels) + # plt.show() + if save_filename is not None: + plt.savefig(save_filename) + else: + plt.show() + + else: + labels = {n: n for n in self.nodes()} + for key, value in edge_labels.items(): + edge_labels[key] = value.split('_')[1] + '-' + value.split('_')[2] + \ + '-' + str(round(float(value.split('_')[3]), 4)) + + plt.figure(figsize=figsize, dpi=dpi) + nx.draw(self, pos, with_labels=True, labels=labels, font_weight='bold') + nx.draw_networkx_edge_labels(self, pos, 
edge_labels=edge_labels) + # plt.show() + if save_filename is not None: + plt.savefig(save_filename) + else: + plt.show() + + plt.close('all') + gc.collect() diff --git a/python/ppc_model/metrics/test/__init__.py b/python/ppc_model/metrics/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/metrics/test/test_metrics.py b/python/ppc_model/metrics/test/test_metrics.py new file mode 100644 index 00000000..3545aaa7 --- /dev/null +++ b/python/ppc_model/metrics/test/test_metrics.py @@ -0,0 +1,165 @@ +import os +import unittest +import numpy as np + +from ppc_model.common.initializer import Initializer +from ppc_common.ppc_mock.mock_objects import MockLogger +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.metrics.evaluation import Evaluation +from ppc_model.metrics.model_plot import ModelPlot, DiGraphTree +from ppc_model.common.model_result import ResultFileHandling +from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext +from ppc_model.secure_lgbm.vertical import VerticalLGBMActiveParty, VerticalLGBMPassiveParty + + +ACTIVE_PARTY = 'ACTIVE_PARTY' +PASSIVE_PARTY = 'PASSIVE_PARTY' + +data_size = 1000 +feature_dim = 20 + + +def mock_args(): + job_id = 'j-111' + task_id = 't-111' + + model_dict = { + 'objective': 'regression', + 'categorical_feature': [], + 'max_bin': 10, + 'n_estimators': 2, + 'max_depth': 3, + 'use_goss': 1, + 'feature_rate': 0.8, + 'random_state': 2024 + } + + args_a = { + 'job_id': job_id, + 'task_id': task_id, + 'is_label_holder': True, + 'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY], + 'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY], + 'model_predict_algorithm': None, + 'algorithm_type': 'Train', + 'algorithm_subtype': 'HeteroXGB', + 'model_dict': model_dict + } + + args_b = { + 'job_id': job_id, + 'task_id': task_id, + 'is_label_holder': False, + 'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY], + 'participant_id_list': [ACTIVE_PARTY, 
PASSIVE_PARTY], + 'model_predict_algorithm': None, + 'algorithm_type': 'Train', + 'algorithm_subtype': 'HeteroXGB', + 'model_dict': model_dict + } + + return args_a, args_b + + +class TestXgboostTraining(unittest.TestCase): + + args_a, args_b = mock_args() + + def test_active_metrics(self): + + active_components = Initializer(log_config_path='', config_path='') + active_components.config_data = { + 'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY} + active_components.mock_logger = MockLogger() + task_info_a = SecureLGBMContext(self.args_a, active_components) + model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True) + secure_dataset_a = SecureDataset(task_info_a, model_data) + booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a) + print(secure_dataset_a.feature_name) + print(secure_dataset_a.train_idx.shape) + print(secure_dataset_a.train_X.shape) + print(secure_dataset_a.train_y.shape) + print(secure_dataset_a.test_idx.shape) + print(secure_dataset_a.test_X.shape) + print(secure_dataset_a.test_y.shape) + + # booster_a._train_praba = np.random.rand(len(secure_dataset_a.train_y)) + booster_a._test_praba = np.random.rand(len(secure_dataset_a.test_y)) + + Evaluation(task_info_a, secure_dataset_a, booster_a._train_praba, booster_a._test_praba) + + def test_passive_metrics(self): + + passive_components = Initializer(log_config_path='', config_path='') + passive_components.config_data = { + 'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY} + passive_components.mock_logger = MockLogger() + task_info_b = SecureLGBMContext(self.args_b, passive_components) + model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=False) + secure_dataset_b = SecureDataset(task_info_b, model_data) + booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b) + print(secure_dataset_b.feature_name) + print(secure_dataset_b.train_idx.shape) + print(secure_dataset_b.train_X.shape) + 
print(secure_dataset_b.test_idx.shape) + print(secure_dataset_b.test_X.shape) + + # booster_b._train_praba = np.random.rand(len(secure_dataset_b.train_idx)) + booster_b._test_praba = np.random.rand(len(secure_dataset_b.test_idx)) + + Evaluation(task_info_b, secure_dataset_b, booster_b._train_praba, booster_b._test_praba) + + def test_model_plot(self): + + active_components = Initializer(log_config_path='', config_path='') + active_components.config_data = { + 'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY} + active_components.mock_logger = MockLogger() + task_info_a = SecureLGBMContext(self.args_a, active_components) + model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True) + secure_dataset_a = SecureDataset(task_info_a, model_data) + booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a) + if os.path.exists(booster_a.ctx.model_data_file): + booster_a.load_model() + ModelPlot(booster_a) + + def test_digraphtree(self): + Gtree = DiGraphTree() + Gtree.add_node(0) + Gtree.add_nodes_from([1, 2]) + Gtree.add_weighted_edges_from( + [(0, 1, 'left_'+str(2)+'_'+str(3)+'_'+str(0.5)), + (0, 2, 'right_'+str(2)+'_'+str(3)+'_'+str(0.9))]) + Gtree.add_nodes_from([3, 4]) + Gtree.add_weighted_edges_from( + [(1, 3, 'left_'+str(20)+'_'+str(4)+'_'+str(0.5)), + (1, 4, 'right_'+str(20)+'_'+str(4)+'_'+str(0.9))]) + Gtree.add_nodes_from([5, 6]) + Gtree.add_weighted_edges_from( + [(2, 5, 'left_'+str(2)+'_'+str(7)+'_'+str(0.5)), + (2, 6, 'right_'+str(2)+'_'+str(7)+'_'+str(0.9))]) + Gtree.add_nodes_from([7, 8]) + Gtree.add_weighted_edges_from( + [(3, 7, 'left_'+str(1)+'_'+str(11)+'_'+str(0.5)), + (3, 8, 'right_'+str(1)+'_'+str(11)+'_'+str(0.9))]) + Gtree.add_nodes_from([9, 10]) + Gtree.add_weighted_edges_from( + [(4, 9, 'left_'+str(18)+'_'+str(2)+'_'+str(0.5)), + (4, 10, 'right_'+str(18)+'_'+str(2)+'_'+str(0.9))]) + Gtree.add_nodes_from([11, 12]) + Gtree.add_weighted_edges_from( + [(5, 11, 'left_'+str(23)+'_'+str(25)+'_'+str(0.5)), + (5, 12, 
'right_'+str(23)+'_'+str(25)+'_'+str(0.9))]) + Gtree.add_nodes_from([13, 14]) + Gtree.add_weighted_edges_from( + [(6, 13, 'left_'+str(16)+'_'+str(10)+'_'+str(0.5)), + (6, 14, 'right_'+str(16)+'_'+str(10)+'_'+str(0.9))]) + + # Gtree.tree_plot() + # Gtree.tree_plot(split=False, figsize=(10, 5)) + # Gtree.tree_plot(figsize=(6, 3)) + Gtree.tree_plot(figsize=(10, 5), save_filename='tree.svg') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/ppc_model/model_result/task_result_handler.py b/python/ppc_model/model_result/task_result_handler.py new file mode 100644 index 00000000..935446d1 --- /dev/null +++ b/python/ppc_model/model_result/task_result_handler.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- +from ppc_common.ppc_utils.utils import PpcException, PpcErrorCode +from ppc_common.ppc_utils import utils +from ppc_model.common.protocol import ModelTask +from ppc_common.ppc_ml.model.algorithm_info import ClassificationType +from ppc_model.common.model_result import ResultFileHandling +from ppc_common.ppc_ml.model.algorithm_info import EvaluationType +from ppc_model.common.base_context import BaseContext +from enum import Enum + + +class TaskResultRequest: + def __init__(self, job_id, task_type): + self.job_id = job_id + self.task_type = task_type + + +class DataType(Enum): + TEXT = "str", + TABLE = "table", + IMAGE = "image" + + +class DataItem: + DEFAULT_NAME_PROPERTY = "metricsName" + DEFAULT_DATA_PROPERTY = "metricsData" + DEFAULT_TYPE_PROPERTY = "metricsType" + + def __init__(self, name, data, type, + name_property=DEFAULT_NAME_PROPERTY, + data_property=DEFAULT_DATA_PROPERTY, + type_property=DEFAULT_TYPE_PROPERTY): + self.name = name + self.data = data + self.type = type + self.name_property = name_property + self.data_property = data_property + self.type_property = type_property + + def to_dict(self): + return {self.name_property: self.name, + self.data_property: self.data, + self.type_property: self.type.name} + + +class ResultFileMeta: + def 
__init__(self, table_file_name, retrieve_lines=-1): + self.table_file_name = table_file_name + self.retrieve_lines = retrieve_lines + + +class JobEvaluationResult: + DEFAULT_TRAIN_EVALUATION_FILES = { + EvaluationType.ROC: utils.MPC_TRAIN_METRIC_ROC_FILE, + EvaluationType.PR: utils.MPC_TRAIN_METRIC_PR_FILE, + EvaluationType.KS: utils.MPC_TRAIN_METRIC_KS_FILE, + EvaluationType.ACCURACY: utils.MPC_TRAIN_METRIC_ACCURACY_FILE, + EvaluationType.CONFUSION_MATRIX: utils.MPC_TRAIN_METRIC_CONFUSION_MATRIX_FILE} + + DEFAULT_VALIDATION_EVALUATION_FILES = { + EvaluationType.ROC: utils.MPC_TRAIN_SET_METRIC_ROC_FILE, + EvaluationType.PR: utils.MPC_TRAIN_SET_METRIC_PR_FILE, + EvaluationType.KS: utils.MPC_TRAIN_SET_METRIC_KS_FILE, + EvaluationType.ACCURACY: utils.MPC_TRAIN_SET_METRIC_ACCURACY_FILE} + + DEFAULT_EVAL_EVALUATION_FILES = { + EvaluationType.ROC: utils.MPC_EVAL_METRIC_ROC_FILE, + EvaluationType.PR: utils.MPC_EVAL_METRIC_PR_FILE, + EvaluationType.KS: utils.MPC_EVAL_METRIC_KS_FILE, + EvaluationType.ACCURACY: utils.MPC_EVAL_METRIC_ACCURACY_FILE + } + + def __init__(self, property_name, classification_type, + job_id, evaluation_files, components): + self.job_id = job_id + self.classification_type = classification_type + self.components = components + self.logger = self.components.logger() + self.classification_type = classification_type + self.property_name = property_name + self.evaluation_files = evaluation_files + self.evaluation_results = [] + try: + self._fetch_evaluation_result() + self._fetch_two_classifcation_evaluation_result() + self._fetch_multi_classifcation_evaluation_result() + except Exception as e: + pass + + def _fetch_evaluation_result(self): + self.logger.info( + f"fetch roc-evaluation from: {self.evaluation_files[EvaluationType.ROC]}") + self.evaluation_results.append(DataItem("ROC", ResultFileHandling.make_graph_data( + self.components, + self.job_id, + self.evaluation_files[EvaluationType.ROC]), + DataType.IMAGE)) + self.logger.info( + f"fetch 
    def _fetch_two_classifcation_evaluation_result(self):
        """Fetch the binary-classification-only graphs (K-S and Accuracy).

        No-op unless ``self.classification_type`` is ``ClassificationType.TWO``.
        Each graph is appended to ``self.evaluation_results`` as an IMAGE item.
        """
        if self.classification_type is not ClassificationType.TWO:
            return

        self.logger.info(
            f"fetch ks-evaluation from: {self.evaluation_files[EvaluationType.KS]}")
        self.evaluation_results.append(DataItem("K-S", ResultFileHandling.make_graph_data(
            self.components,
            self.job_id,
            self.evaluation_files[EvaluationType.KS]),
            DataType.IMAGE))

        self.logger.info(
            f"fetch accuracy-evaluation from: {self.evaluation_files[EvaluationType.ACCURACY]}")
        self.evaluation_results.append(DataItem("Accuracy",
                                                ResultFileHandling.make_graph_data(
                                                    self.components,
                                                    self.job_id,
                                                    self.evaluation_files[EvaluationType.ACCURACY]),
                                                DataType.IMAGE))
class TableResult:
    """Loads a remote csv result file and renders it as
    ``{'columns': [...], 'data': [...]}``."""

    def __init__(self, components, job_id, file_meta):
        """
        :param components: provides logger() and remote storage access
        :param job_id: id of the job the file belongs to
        :param file_meta: ResultFileMeta carrying file name and row limit
        """
        self.components = components
        self.job_id = job_id
        self.file_meta = file_meta

    def to_dict(self):
        """Fetch the csv and return its columns and rows, truncated to
        ``file_meta.retrieve_lines`` rows (-1 means unlimited).

        Returns None when the file is missing or unreadable — result files
        are optional depending on the job type, so this stays best-effort,
        but the failure is now logged instead of silently swallowed.
        """
        try:
            df = ResultFileHandling.make_csv_data(self.components, self.job_id,
                                                  self.file_meta.table_file_name)
            csv_columns = list(df.columns)
            retrieve_lines = self.file_meta.retrieve_lines
            if retrieve_lines == -1 or df.shape[0] <= retrieve_lines:
                csv_data = df.values.tolist()
            else:
                csv_data = df.iloc[:retrieve_lines].values.tolist()
            return {'columns': csv_columns, 'data': csv_data}
        except Exception as e:
            self.components.logger().warn(
                f"load table result for job {self.job_id} failed: {e}")
            return None
None + self.feature_importance_table = None + self.iteration_metrics = None + + def fetch_model_result(self): + self.model_result_list = [] + i = 0 + # while True: + while i < 6: + try: + tree_data = DataItem(data=ResultFileHandling.make_graph_data(self.components, + self.job_id, + utils.XGB_TREE_PERFIX + '_' + str(i) + '.svg'), + name='tree-' + str(i), name_property="ModelPlotName", data_property="ModelPlotData", + type=DataType.IMAGE) + self.model_result_list.append(tree_data.to_dict()) + i += 1 + except Exception: + break + + def load_result(self, result_path, result_property): + self.result_property = result_property + job_result_object = TableResult(self.components, + self.job_id, ResultFileMeta(result_path, 5)) + self.job_result = job_result_object.to_dict() + + def load_model_result_path(self, predict: bool): + self.xgb_result_path = dict() + self.model_result_path = ResultFileHandling.get_remote_path( + self.components, self.job_id, BaseContext.MODEL_DATA_FILE) + self.xgb_result_path.update( + {XGBJobResult.MODEL_RESULT_PATH: self.model_result_path}) + + self.train_result_path = ResultFileHandling.get_remote_path( + self.components, self.job_id, BaseContext.TRAIN_MODEL_OUTPUT_FILE) + self.xgb_result_path.update( + {XGBJobResult.TRAIN_RESULT_PATH: self.train_result_path}) + + self.xgb_result_path.update( + {XGBJobResult.TEST_RESULT_PATH: ResultFileHandling.get_remote_path( + self.components, self.job_id, BaseContext.TEST_MODEL_OUTPUT_FILE)}) + + self.woe_iv_result_path = ResultFileHandling.get_remote_path( + self.components, self.job_id, BaseContext.WOE_IV_FILE) + self.xgb_result_path.update( + {XGBJobResult.WOE_RESULT_PATH: self.woe_iv_result_path}) + + def load_evaluation_table(self, evaluation_path, property): + evaluation_table_object = TableResult(self.components, + self.job_id, ResultFileMeta(evaluation_path)) + self.evaluation_table = {property: DataItem(name=property, data=evaluation_table_object.to_dict(), + type=DataType.TABLE).to_dict()} + + def 
load_feature_importance_table(self, feature_importance_path, property): + feature_importance_table = TableResult(self.components, + self.job_id, ResultFileMeta(feature_importance_path)) + self.feature_importance_table = {property: DataItem(name=property, data=feature_importance_table.to_dict(), + type=DataType.TABLE).to_dict()} + + def load_iteration_metrics(self, iteration_path, property): + try: + iteration_metrics_data = DataItem(data=ResultFileHandling.make_graph_data(self.components, self.job_id, utils.METRICS_OVER_ITERATION_FILE), + name='iteration_metrics', name_property="ModelPlotName", data_property="ModelPlotData", + type=DataType.IMAGE) + self.iteration_metrics = [] + self.iteration_property = property + self.iteration_metrics.append(iteration_metrics_data.to_dict()) + except: + pass + + def to_dict(self): + result = dict() + if self.model_result_list is not None: + result.update({self.property_name: self.model_result_list}) + if self.job_result is not None: + result.update({self.result_property: self.job_result}) + if self.evaluation_table is not None: + result.update(self.evaluation_table) + if self.feature_importance_table is not None: + result.update(self.feature_importance_table) + if self.iteration_metrics is not None: + result.update({self.iteration_property: self.iteration_metrics}) + if self.xgb_result_path is not None: + result.update( + {XGBJobResult.MODEL_RESULT: self.xgb_result_path}) + return result + + +class TaskResultHandler: + def __init__(self, task_result_request: TaskResultRequest, components): + self.task_result_request = task_result_request + self.components = components + self.logger = components.logger() + self.result_list = [] + self.predict = False + if self.task_result_request.task_type == ModelTask.XGB_PREDICTING.name: + self.predict = True + self.logger.info( + f"Init jobResultHandler for: {self.task_result_request.job_id}") + self._get_evaluation_result() + self._get_feature_processing_result() + + def get_response(self): + 
merged_result = dict() + for result in self.result_list: + merged_result.update(result.to_dict()) + response = {"jobPlanetResult": merged_result} + return utils.make_response(PpcErrorCode.SUCCESS.get_code(), PpcErrorCode.SUCCESS.get_msg(), response) + + def _get_evaluation_result(self): + if self.task_result_request.task_type == ModelTask.XGB_TRAINING.name: + # the train evaluation result + self.train_evaluation_result = JobEvaluationResult( + property_name="outputMetricsGraphs", + classification_type=ClassificationType.TWO, + job_id=self.task_result_request.job_id, + evaluation_files=JobEvaluationResult.DEFAULT_TRAIN_EVALUATION_FILES, + components=self.components) + # load the ks table + self.train_evaluation_result.load_ks_table( + "mpc_train_metric_ks.csv", "TrainKSTable") + self.result_list.append(self.train_evaluation_result) + + self.validation_evaluation_result = JobEvaluationResult( + property_name="outputTrainMetricsGraphs", + classification_type=ClassificationType.TWO, + job_id=self.task_result_request.job_id, + evaluation_files=JobEvaluationResult.DEFAULT_VALIDATION_EVALUATION_FILES, + components=self.components) + # load the ks_table + self.validation_evaluation_result.load_ks_table( + "mpc_metric_ks.csv", "KSTable") + self.result_list.append(self.validation_evaluation_result) + + self.xgb_model = XGBJobResult( + self.task_result_request.job_id, self.components, XGBJobResult.DEFAULT_PROPERTY_NAME) + self.xgb_model.fetch_model_result() + # the ks-auc table + self.xgb_model.load_evaluation_table( + utils.MPC_XGB_EVALUATION_TABLE, "EvaluationTable") + # the feature-importance table + self.xgb_model.load_feature_importance_table( + utils.XGB_FEATURE_IMPORTANCE_TABLE, "FeatureImportance") + self.result_list.append(self.xgb_model) + # the metrics iteration graph + self.xgb_model.load_iteration_metrics( + utils.METRICS_OVER_ITERATION_FILE, "IterationGraph") + + if self.predict: + # the train evaluation result + self.predict_evaluation_result = 
    def _get_feature_processing_result(self):
        # Preview tables for the preprocessing / feature-engineering stages
        # (selected columns and the WOE/IV table); appended unconditionally,
        # regardless of the requested task type.
        self.feature_processing_result = FeatureProcessingResult(
            self.components, self.task_result_request.job_id, FeatureProcessingResult.DEFAULT_FEATURE_PROCESSING_FILES)
        self.result_list.append(self.feature_processing_result)
grpc_options, ssl_switch: int = 0, + ca_path=None, ssl_key_path=None, ssl_crt_path=None): + self._logger = logger + self._endpoint = endpoint + self._ssl_switch = ssl_switch + self._grpc_options = grpc_options + self._ca_path = ca_path + self._ssl_key_path = ssl_key_path + self._ssl_crt_path = ssl_crt_path + if self._ssl_switch == 0: + insecure_channel = grpc.insecure_channel( + self._endpoint, options=grpc_options) + self._client = ModelServiceStub(insecure_channel) + else: + channel = self._create_secure_channel(self._endpoint) + self._client = ModelServiceStub(channel) + + def _create_secure_channel(self, target): + grpc_root_crt = utils.load_credential_from_file( + os.path.abspath(self._ca_path)) + grpc_ssl_key = utils.load_credential_from_file( + os.path.abspath(self._ssl_key_path)) + grpc_ssl_crt = utils.load_credential_from_file( + os.path.abspath(self._ssl_crt_path)) + credentials = grpc.ssl_channel_credentials( + root_certificates=grpc_root_crt, + private_key=grpc_ssl_key, + certificate_chain=grpc_ssl_crt + ) + return grpc.secure_channel(target, credentials, options=self._grpc_options) + + @staticmethod + def _build_error_model_response(message: str): + model_response = ModelResponse() + model_response.base_response.error_code = -1 + model_response.base_response.message = message + return model_response + + def send(self, request: ModelRequest): + start_time = time.time() + try: + self._logger.debug( + f"start sending data to {request.receiver}, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}") + response = self._client.MessageInteraction(request) + end_time = time.time() + if response.base_response.error_code != 0: + self._logger.warn( + f"[OnWarn]send data to {request.receiver} failed, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}, slice_num: {request.slice_num}, " + f"ret_code: {response.base_response.error_code}, message: {response.base_response.message}, " + f"time_costs: {str(end_time - 
class ModelService(ppc_model_pb2_grpc.ModelServiceServicer):
    """gRPC servicer that receives model-exchange packets and hands them to
    the global stub for caching."""

    def MessageInteraction(self, model_request: ModelRequest, context):
        """Handle one incoming data slice.

        Logs the packet metadata, forwards it to ``components.stub`` (whose
        on_message_received de-duplicates and caches it for later pull()),
        and acks with error_code 0 — this acknowledges receipt only, not
        task-level processing.
        """
        components.logger().info(
            f"receive a package, sender: {model_request.sender}, task_id: {model_request.task_id}, "
            f"key: {model_request.key}, seq: {model_request.seq}, slice_num: {model_request.slice_num}")

        components.stub.on_message_received(model_request)
        model_response = ModelResponse()
        model_response.base_response.error_code = 0
        model_response.base_response.message = 'success'
        return model_response
# Shared base schema: every HTTP response carries a return code and a message.
response_base = api.model('Response base info', {
    'errorCode': fields.Integer(description='return code'),
    'message': fields.String(description='return message')
})

# Task-status schema: the base fields plus a free-form key/value payload
# (status, traffic volume, time costs — see ModelCollection.get).
response_task_status = api.inherit('Task status', response_base, {
    'data': fields.Raw(description='Task status data as key-value dictionary', example={
        'status': 'RUNNING',
        'traffic_volume': '10MB',
        'time_costs': '30s'
    })
})
task_id. + """ + args = request.get_json() + task_id = model_id + components.logger().info(f"run task request, task_id: {task_id}, args: {args}") + task_type = args['task_type'] + components.task_manager.run_task( + task_id, ModelTask(task_type), (args,)) + return utils.BASE_RESPONSE + + @api.response(200, 'Task status retrieved successfully.', response_task_status) + def get(self, model_id): + """ + Get the status of a specific task by task_id. + """ + response = utils.BASE_RESPONSE + task_id = model_id + status, traffic_volume, time_costs = components.task_manager.status(task_id) + response['data'] = { + 'status': status, + 'traffic_volume': traffic_volume, + 'time_costs': time_costs + } + return response + + @api.response(200, 'Task killed successfully.', response_base) + def delete(self, model_id): + """ + Kill a specific task by job_id. + """ + job_id = model_id + components.logger().info(f"kill request, job_id: {job_id}") + components.task_manager.kill_task(job_id) + return utils.BASE_RESPONSE + + +@ns2.route('/') +class ModelLogCollection(Resource): + @api.response(200, 'Task status retrieved successfully.', response_task_status) + def get(self, job_id): + log_content = components.task_manager.record_model_job_log(job_id) + return utils.make_response(utils.PpcErrorCode.SUCCESS.get_code(), + utils.PpcErrorCode.SUCCESS.get_msg(), log_content) + + +@ns_get_job_result.route('/') +class ModelResultCollection(Resource): + @api.response(201, 'Get task result successfully.', response_base) + def post(self, task_id): + """ + Get the result related to the task_id + """ + start_t = time.time() + args = request.get_json() + components.logger().info( + f"run task request, task_id: {task_id}, args: {args}") + user_name = args['user'] + task_type = args['jobType'] + components.logger().info( + f"get_job_direct_result_response, job: {task_id}") + task_result_request = TaskResultRequest(task_id, task_type) + job_result_handler = TaskResultHandler( + 
@api.errorhandler(PpcException)
def default_error_handler(e):
    """Map a domain PpcException to a 500 JSON payload with its code/message.

    NOTE(review): the unknown-error handler defined after this one reuses the
    same function name, rebinding it at module level; both appear to stay
    registered because @api.errorhandler registers at decoration time —
    confirm before renaming either.
    """
    components.logger().exception(e)
    info = e.to_dict()
    response = {'errorCode': info['code'], 'message': info['message']}
    components.logger().error(f"OnError: code: {info['code']}, message: {info['message']}")
    return response, 500
@dataclass
class PushRequest:
    """One outbound message: a binary payload addressed to a receiver agency,
    identified by (task_id, key), sent in fixed-size slices."""
    receiver: str  # agency id of the data receiver
    task_id: str
    key: str  # data key within the task
    data: bytes  # raw binary payload
    slice_size_MB: int = 2  # slice size in MiB, default 2

    def slice_data(self):
        """Split ``data`` into ``slice_size_MB``-sized chunks.

        An empty payload still produces one empty slice so the receiver
        always gets at least one packet per key.
        """
        if not self.data:
            return [b'']
        step = self.slice_size_MB * 1024 * 1024
        total = len(self.data)
        return [self.data[offset:offset + step]
                for offset in range(0, total, step)]
    def push(self, request: PushRequest) -> bytes:
        """Send one message, sliced and pushed concurrently.

        The payload is split into slice_size_MB chunks; each chunk is
        submitted to the thread pool and sent (with retry) as a ModelRequest
        carrying its sequence number and the total slice count.  The
        per-slice response payloads are concatenated in slice order.

        :param request: the push request (receiver, task_id, key, data)
        :return: the concatenated response data of all slices
        """
        slices = request.slice_data()
        futures = []
        for seq, data in enumerate(slices):
            model_request = ModelRequest(
                sender=self.agency_id,
                receiver=request.receiver,
                task_id=request.task_id,
                key=request.key,
                seq=seq,
                slice_num=len(slices),
                data=data
            )
            future = self._executor.submit(
                self._send_with_retry, model_request)
            futures.append(future)

        ret = bytearray()
        for future in futures:
            # future.result() re-raises any send failure from the worker
            ret.extend(future.result())

        # account the pushed bytes towards this task's traffic volume
        self._accumulate_traffic_volume(request.task_id, len(request.data))

        return bytes(ret)
task_id): + with self._data_rw_lock.gen_wlock(): + if task_id in self._received_data: + del self._received_data[task_id] + if task_id in self._received_uuid: + del self._received_uuid[task_id] + if task_id in self._traffic_volume: + del self._traffic_volume[task_id] + + def _is_new_data(self, model_request: ModelRequest) -> bool: + # 返回是否需要继续处理消息 + task_id = model_request.task_id + uuid = f"{task_id}:{model_request.sender}:{model_request.key}:{model_request.seq}" + with self._data_rw_lock.gen_wlock(): + if task_id in self._received_uuid and uuid in self._received_uuid[task_id]: + # 收到重复的消息 + return False + elif task_id in self._received_uuid and uuid not in self._received_uuid[task_id]: + # 收到task_id的新消息 + self._received_uuid[task_id].add(uuid) + else: + # 首次收到task_id的消息 + self._received_uuid[task_id] = {uuid} + return True + + def _handle_received_data(self, model_request: ModelRequest): + task_id = model_request.task_id + sender = model_request.sender + key = model_request.key + seq = model_request.seq + slice_num = model_request.slice_num + data = model_request.data + with self._data_rw_lock.gen_wlock(): + if task_id not in self._received_data: + self._received_data[task_id] = { + model_request.sender: {key: {seq: (slice_num, data)}}} + elif sender not in self._received_data[task_id]: + self._received_data[task_id][sender] = { + key: {seq: (slice_num, data)}} + elif key not in self._received_data[task_id][sender]: + self._received_data[task_id][sender][key] = { + seq: (slice_num, data)} + else: + self._received_data[task_id][sender][key][seq] = ( + slice_num, data) + + def _is_all_data_ready(self, pull_request: PullRequest): + task_id = pull_request.task_id + sender = pull_request.sender + key = pull_request.key + with self._data_rw_lock.gen_rlock(): + if task_id not in self._received_data: + return False + if sender not in self._received_data[task_id]: + return False + if key not in self._received_data[task_id][sender]: + return False + if 
len(self._received_data[task_id][sender][key]) == 0: + return False + _, first_value = next( + iter(self._received_data[task_id][sender][key].items())) + if first_value[0] != len(self._received_data[task_id][sender][key]): + return False + return True + + def _send_with_retry(self, model_request: ModelRequest): + retry_times = 0 + while retry_times <= self._send_retry_times: + model_response = self._rpc_client.send(model_request) + if model_response.base_response.error_code == 0: + return model_response.data + if retry_times <= self._send_retry_times: + retry_times += 1 + time.sleep(self._retry_interval_s) + else: + raise PpcException(PpcErrorCode.NETWORK_ERROR.get_code( + ), model_response.base_response.message) + + def _accumulate_traffic_volume(self, task_id, length): + with self._data_rw_lock.gen_wlock(): + if task_id not in self._traffic_volume: + self._traffic_volume[task_id] = 0 + self._traffic_volume[task_id] += length diff --git a/python/ppc_model/network/test/__init__.py b/python/ppc_model/network/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/network/test/stub_unittest.py b/python/ppc_model/network/test/stub_unittest.py new file mode 100644 index 00000000..63e8e5c7 --- /dev/null +++ b/python/ppc_model/network/test/stub_unittest.py @@ -0,0 +1,83 @@ +import unittest + +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager +from ppc_model.common.mock.rpc_client_mock import RpcClientMock +from ppc_model.network.stub import ModelStub, PushRequest, PullRequest + + +class TestStub(unittest.TestCase): + def setUp(self): + super().__init__() + self._agency_id = 'TEST_AGENCY' + self._message_type = 'TEST_MESSAGE' + self._rpc_client = RpcClientMock() + self._thread_event_manager = ThreadEventManager() + self._stub = ModelStub( + agency_id=self._agency_id, + thread_event_manager=self._thread_event_manager, + rpc_client=self._rpc_client, + send_retry_times=3, + retry_interval_s=0.1 + ) + 
    def test_push_pull(self):
        # Round-trip a 31MB zero-filled payload through the stub under two
        # distinct keys. setUp wired the mock rpc client back to
        # on_message_received, so pushed slices land in the local cache and
        # pull() must reassemble the exact original bytes for each key.
        task_id = '0x12345678'
        byte_array = bytearray(31 * 1024 * 1024)
        bytes_data = bytes(byte_array)
        self._stub.push(PushRequest(
            receiver=self._agency_id,
            task_id=task_id,
            key=self._message_type,
            data=bytes_data
        ))
        self._stub.push(PushRequest(
            receiver=self._agency_id,
            task_id=task_id,
            key=self._message_type + 'other',
            data=bytes_data
        ))
        received_data = self._stub.pull(PullRequest(
            sender=self._agency_id,
            task_id=task_id,
            key=self._message_type
        ))
        other_data = self._stub.pull(PullRequest(
            sender=self._agency_id,
            task_id=task_id,
            key=self._message_type + 'other',
        ))
        self.assertEqual(bytes_data, received_data)
        self.assertEqual(bytes_data, other_data)
+from cheroot.wsgi import Server as WSGIServer +from flask import Flask, Blueprint +from paste.translogger import TransLogger + +from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc +from ppc_common.ppc_utils import utils +from ppc_model.common.global_context import components +from ppc_model.common.protocol import ModelTask +from ppc_model.feature_engineering.feature_engineering_engine import FeatureEngineeringEngine +from ppc_model.network.grpc.grpc_server import ModelService +from ppc_model.network.http.model_controller import ns as task_namespace +from ppc_model.network.http.model_controller import ns2 as log_namespace +from ppc_model.network.http.restx import api +from ppc_model.preprocessing.preprocessing_engine import PreprocessingEngine +from ppc_model.secure_lgbm.secure_lgbm_prediction_engine import SecureLGBMPredictionEngine +from ppc_model.secure_lgbm.secure_lgbm_training_engine import SecureLGBMTrainingEngine + +app = Flask(__name__) + + +def initialize_app(app): + # 初始化应用功能组件 + components.init_all() + + app.config.update(components.config_data) + blueprint = Blueprint('api', __name__, url_prefix='/api') + api.init_app(blueprint) + api.add_namespace(task_namespace) + api.add_namespace(log_namespace) + app.register_blueprint(blueprint) + + +def register_task_handler(): + task_manager = components.task_manager + task_manager.register_task_handler( + ModelTask.PREPROCESSING, PreprocessingEngine.run) + task_manager.register_task_handler( + ModelTask.FEATURE_ENGINEERING, FeatureEngineeringEngine.run) + task_manager.register_task_handler( + ModelTask.XGB_TRAINING, SecureLGBMTrainingEngine.run) + task_manager.register_task_handler( + ModelTask.XGB_PREDICTING, SecureLGBMPredictionEngine.run) + + +def model_serve(): + if app.config['SSL_SWITCH'] == 0: + ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), + options=components.grpc_options) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), 
ppc_serve) + address = "[::]:{}".format(app.config['RPC_PORT']) + ppc_serve.add_insecure_port(address) + else: + grpc_root_crt = utils.load_credential_from_file( + os.path.abspath(app.config['SSL_CA'])) + grpc_ssl_key = utils.load_credential_from_file( + os.path.abspath(app.config['SSL_KEY'])) + grpc_ssl_crt = utils.load_credential_from_file( + os.path.abspath(app.config['SSL_CRT'])) + server_credentials = grpc.ssl_server_credentials((( + grpc_ssl_key, + grpc_ssl_crt, + ),), grpc_root_crt, True) + + ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), + options=components.grpc_options) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), ppc_serve) + address = "[::]:{}".format(app.config['RPC_PORT']) + ppc_serve.add_secure_port(address, server_credentials) + + ppc_serve.start() + components.logger().info( + f"Starting model grpc server at ://{app.config['HOST']}:{app.config['RPC_PORT']}") + ppc_serve.wait_for_termination() + + +if __name__ == '__main__': + initialize_app(app) + register_task_handler() + + # 启动子进程不继承父进程的锁状态,防止死锁 + multiprocessing.set_start_method('spawn') + + Thread(target=model_serve).start() + + app.config['SECRET_KEY'] = os.urandom(24) + server = WSGIServer((app.config['HOST'], app.config['HTTP_PORT']), + TransLogger(app, setup_console_handler=False), numthreads=2) + + ssl_switch = app.config['SSL_SWITCH'] + protocol = 'http' + if ssl_switch == 1: + protocol = 'https' + server.ssl_adapter = BuiltinSSLAdapter( + certificate=app.config['SSL_CRT'], + private_key=app.config['SSL_KEY'], + certificate_chain=app.config['CA_CRT']) + + message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']}" + print(message) + components.logger().info(message) + server.start() diff --git a/python/ppc_model/preprocessing/__init__.py b/python/ppc_model/preprocessing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/python/ppc_model/preprocessing/local_processing/local_processing_party.py b/python/ppc_model/preprocessing/local_processing/local_processing_party.py new file mode 100644 index 00000000..6cc13675 --- /dev/null +++ b/python/ppc_model/preprocessing/local_processing/local_processing_party.py @@ -0,0 +1,100 @@ +import os +import time +from abc import ABC + +import pandas as pd + +from ppc_common.ppc_utils import utils +from ppc_model.preprocessing.local_processing.preprocessing import process_dataframe +from ppc_model.preprocessing.processing_context import ProcessingContext + + +class LocalProcessingParty(ABC): + + def __init__(self, ctx: ProcessingContext): + self.ctx = ctx + + def processing(self): + log = self.ctx.components.logger() + start = time.time() + need_psi = self.ctx.need_run_psi + job_id = self.ctx.job_id + log.info( + f"run data preprocessing job, job: {job_id}, need_psi: {need_psi}") + dataset_path = self.ctx.dataset_path + dataset_file_path = self.ctx.dataset_file_path + storage_client = self.ctx.components.storage_client + job_algorithm_type = self.ctx.job_algorithm_type + if job_algorithm_type == utils.AlgorithmType.Predict.name: + storage_client.download_file(os.path.join(self.ctx.training_job_id, self.ctx.PREPROCESSING_RESULT_FILE), + self.ctx.preprocessing_result_file) + psi_result_path = self.ctx.psi_result_path + model_prepare_file = self.ctx.model_prepare_file + storage_client.download_file(dataset_path, dataset_file_path) + if need_psi and (not utils.file_exists(psi_result_path)): + storage_client.download_file( + self.ctx.remote_psi_result_path, psi_result_path) + self.handle_local_psi_result(psi_result_path) + log.info( + f"prepare_xgb_after_psi, make_dataset_to_xgb_data_plus_psi_data, dataset_file_path={dataset_file_path}, " + f"psi_result_path={dataset_file_path}, model_prepare_file={model_prepare_file}") + self.make_dataset_to_xgb_data() + storage_client.upload_file( + model_prepare_file, job_id + os.sep + self.ctx.model_prepare_file) 
+ log.info(f"upload model_prepare_file to hdfs, job_id={job_id}") + if job_algorithm_type == utils.AlgorithmType.Train.name: + log.info(f"upload column_info to hdfs, job_id={job_id}") + storage_client.upload_file(self.ctx.preprocessing_result_file, + job_id + os.sep + self.ctx.PREPROCESSING_RESULT_FILE) + log.info( + f"call prepare_xgb_after_psi success, job_id={job_id}, timecost: {time.time() - start}") + + def handle_local_psi_result(self, local_psi_result_path): + try: + log = self.ctx.components.logger() + log.info( + f"handle_local_psi_result: start handle_local_psi_result, psi_result_path={local_psi_result_path}") + with open(local_psi_result_path, 'r+', encoding='utf-8') as psi_result_file: + content = psi_result_file.read() + psi_result_file.seek(0, 0) + psi_result_file.write('id\n' + content) + log.info( + f"handle_local_psi_result: call handle_local_psi_result success, psi_result_path={local_psi_result_path}") + except BaseException as e: + log.exception( + f"handle_local_psi_result: handle_local_psi_result, psi_result_path={local_psi_result_path}, error:{e}") + raise e + + def make_dataset_to_xgb_data(self): + log = self.ctx.components.logger() + dataset_file_path = self.ctx.dataset_file_path + psi_result_file_path = self.ctx.psi_result_path + model_prepare_file = self.ctx.model_prepare_file + log.info(f"dataset_file_path:{dataset_file_path}") + log.info(f"model_prepare_file:{model_prepare_file}") + need_run_psi = self.ctx.need_run_psi + job_id = self.ctx.job_id + if not utils.file_exists(dataset_file_path): + raise FileNotFoundError( + f"dataset_file_path not found: {dataset_file_path}") + dataset_df = pd.read_csv(dataset_file_path) + if need_run_psi: + log.info(f"psi_result_file_path:{psi_result_file_path}") + psi_data = pd.read_csv(psi_result_file_path, + delimiter=utils.CSV_SEP) + dataset_df = pd.merge(dataset_df, psi_data, on=[ + 'id']).sort_values(by='id', ascending=True) + + ppc_job_type = self.ctx.job_algorithm_type + column_info = 
process_dataframe( + dataset_df, self.ctx.model_setting, model_prepare_file, ppc_job_type, job_id, self.ctx) + + column_info_pd = pd.DataFrame(column_info).transpose() + # 如果是训练任务先写本地 + log.info(f"jobid {job_id}, job_algorithm_type {ppc_job_type}") + if ppc_job_type == utils.AlgorithmType.Train.name: + log.info( + f"write {column_info} to {self.ctx.preprocessing_result_file}") + column_info_pd.to_csv( + self.ctx.preprocessing_result_file, sep=utils.CSV_SEP, header=True) + log.info("finish make_dataset_to_xgb_data_plus_psi_data") diff --git a/python/ppc_model/preprocessing/local_processing/preprocessing.py b/python/ppc_model/preprocessing/local_processing/preprocessing.py new file mode 100644 index 00000000..55082a62 --- /dev/null +++ b/python/ppc_model/preprocessing/local_processing/preprocessing.py @@ -0,0 +1,660 @@ +# -*- coding: utf-8 -*- +# from concurrent.futures import ProcessPoolExecutor, as_completed + +import json +import os + +import numpy as np +import pandas as pd +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from ppc_model.common.model_setting import ModelSetting + +from ppc_common.ppc_utils import utils +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_model.common.global_context import components +from ppc_model.preprocessing.local_processing.psi_select import calculate_psi +from ppc_model.preprocessing.local_processing.standard_type_enum import standardType +# from ppc_common.ppc_mock.mock_objects import MockLogger + +# components.mock_logger = MockLogger() +# log = components.mock_logger +log = components.logger() + + +def process_train_dataframe(dataset_df: pd.DataFrame, column_info_dict: dict): + """ + 使用column_info对dataset_df进行处理 只保留column_info中的列 + + 参数: + - dataset_df: 待处理的DataFrame数据 + - column_info: 字段信息, 字典类型 + + 返回值: + - dataset_df_filled: 处理后的DataFrame数据 + """ + # dataset_df_filled = None + # Iterate over column_info_dict keys, if 'isExisted' is False, drop the column + for key, value in 
column_info_dict.items(): + if value.get('isExisted') is None: + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "column_info_dict isExisted is None") + if value.get('isExisted') is False: + dataset_df = dataset_df.drop(key, axis=1) + + return dataset_df + + +def process_dataframe(dataset_df: pd.DataFrame, model_setting: ModelSetting, xgb_data_file_path: str, + ppc_job_type: str = utils.AlgorithmType.Train.name, job_id: str = None, ctx=None): + """ + 对数据集进行预处理的函数。 + 共执行6步操作 + 1. 去除唯一属性: 为了让无意义的参数不影响模型, 比如id, id不代表样本本身的规律, 所以要把这些属性删掉 + 2. 缺失值处理: 删除含有缺失值的特征,若变量的缺失率较高(大于80%),覆盖率较低,且重要性较低,可以直接将变量删除. + 每一行的数据不一定拥有所有的模型标签,支持不填充,和均值插补, + 我们目前使用均值插补,以该属性存在值的平均值来插补缺失的值 + 连续性:普通均值 + 类别:单独类别,当做新类别 + 3. 离群值处理方法: 数据过大或过小会一个峡谷分析结果,要先调整因子值的离群值上下限,减少离群值的影响,防止下一步归一化后的偏差过大 + 我们用的3 \sigma 法,又叫做标准差法 + 4. 特征编码: 将特征编码为固定的值 比如one-hot就是将不同的type 编码为1 2 3 4 5...这样的值 + 5. 数据标准化: min-max标准化(归一化):最大值1. 最小值0或-1 以及z-score标准化 规范化:均值为0 标准差为1 + 6. 特征选择: 从给定的特征集合中选出相关特征子集的过程称为特征选择, 我们使用的是PSI 风控模型,群体稳定性指标 PSI-Population Stability Index + 参考 https://zhuanlan.zhihu.com/p/79682292 + + + 参数: + dataset_df (pandas.DataFrame): 输入的数据集。 + xgb_data_file_path (str): XGBoost数据文件路径。 + + 返回: + column_info: 处理后的数据集字段。 + """ + log.info( + f"jobid: {job_id}, xgb_data_file_path:{xgb_data_file_path}, ppc_job_type: {ppc_job_type}") + if model_setting is None: + log.error("model_setting is None") + raise PpcException( + PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), "model_setting is None") + + column_info = {} + + if ppc_job_type == utils.AlgorithmType.Predict.name: + column_info_fm = pd.read_csv( + ctx.preprocessing_result_file, index_col=0) + column_info_train_str = json.dumps( + column_info_fm.to_dict(orient='index')) + if column_info_train_str is None: + raise PpcException(-1, "column_info_train is None") + try: + # 对应orient='records' + # column_info_train = json.loads(column_info_train_str, orient='records') + column_info_train = json.loads(column_info_train_str) + except Exception as e: + 
log.error( + f"jobid: {job_id} column_info_train json.loads error, e:{e}") + raise PpcException(-1, "column_info_train json.loads error") + dataset_df = process_train_dataframe(dataset_df, column_info_train) + column_info = column_info_train + elif ppc_job_type == utils.AlgorithmType.Train.name: + # 如果是训练任务 先默认所有数据都存在 + column_info = {col: {'isExisted': True} for col in dataset_df.columns} + + if model_setting.eval_set_column is not None: + if model_setting.eval_set_column in dataset_df.columns: + eval_column = model_setting.eval_set_column + dataset_df[['id', eval_column]].to_csv(ctx.eval_column_file, index=None) + ctx.components.storage_client.upload_file(ctx.eval_column_file, job_id + os.sep + ctx.EVAL_COLUMN_FILE) + if model_setting.eval_set_column != model_setting.psi_select_col: + dataset_df = dataset_df.drop(columns=[eval_column]) + + categorical_cols = ['id', 'y'] + + # 判断 model_setting['categorical']是否为None + if model_setting.categorical is None: + log.error( + f"jobid: {job_id} model_setting['categorical'] is None, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "xgb_model_dict['categorical'] is None") + + if model_setting.fillna is None: + log.error( + f"jobid: {job_id} xgb_model_dict['fillna'] is None, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "xgb_model_dict['fillna'] is None") + + # 指定分类特征索引 + if model_setting.categorical != '0': + categoricals = model_setting.categorical.split(',') + categorical_cols.extend(categoricals) + # 去除categorical_cols中的重复元素 + categorical_cols = list(set(categorical_cols)) + + df_filled = dataset_df + # 预处理表格信息 包含每一列缺失值比例 缺失值是否被筛掉 psi筛选和相关性筛选 如果在某个阶段被筛掉则设置为0 保留则为1 + # 1.去除唯一属性 + if 'id' in df_filled.columns: + log.info(f"jobid: {job_id} move id column start.") + df_filled = df_filled.drop('id', axis=1) + log.info(f"jobid: {job_id} move id column finish.") + + # 2.1 缺失值筛选 + if ppc_job_type == 
utils.AlgorithmType.Train.name: + if 0 <= model_setting.na_select <= 1: + log.info(f"jobid: {job_id} run fillna start") + df_filled, column_info = process_na_dataframe( + df_filled, model_setting.na_select) + log.info(f"jobid: {job_id} run fillna finish") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['na_select'] is range not 0 to 1, xgb_model_dict:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['na_select'] range not 0 to 1") + elif ppc_job_type == utils.AlgorithmType.Predict.name: + log.info(f"jobid: {job_id} don't need run fillna for predict job.") + else: + log.error( + f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "ppc_job_type is not Train or Predict") + # 2.2 缺失值处理 + if model_setting.fillna == 1: + # 填充 + log.info(f"jobid: {job_id} run fillna with means start") + try: + df_filled = process_na_fill_dataframe( + df_filled, categorical_cols, model_setting.psi_select_col) + log.info(f"jobid: {job_id} run fillna with means finish") + except Exception as e: + log.error( + f"jobid: {job_id} process_na_fill_dataframe error, e:{e}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "process_na_fill_dataframe error") + elif model_setting.fillna == 0: + # 不填充 + log.info(f"jobid: {job_id} don't need run fillna ") + # 如果本身是None就不需要处理 + df_filled.replace('None', np.nan) + else: + log.error( + f"jobid: {job_id} xgb_model_dict['fillna'] is not 0 or 1, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "xgb_model_dict['fillna'] is not 0 or 1") + + # 6.1 特征选择 进行 psi稳定性指标筛选 计算特征相关性 降维可以减少模型的复杂度,提高模型的泛化能力 + if ppc_job_type == utils.AlgorithmType.Train.name: + if model_setting.psi_select_col in df_filled.columns.tolist() and model_setting.psi_select_col != 0: + log.info(f"jobid: {job_id} run psi_select_col start") + 
psi_select_base = model_setting.psi_select_base + psi_select_thresh = model_setting.psi_select_thresh + psi_select_bins = model_setting.psi_select_bins + psi_select_col = model_setting.psi_select_col + + if psi_select_base is None or psi_select_thresh is None or psi_select_bins is None: + log.error( + f"jobid: {job_id} psi_select_base or psi_select_thresh or psi_select_bins is None, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "psi_select_base or psi_select_thresh or psi_select_bins is None") + df_filled, psi_selected_cols = process_psi(df_filled, categorical_cols, psi_select_col, psi_select_base, + psi_select_thresh, psi_select_bins) + # 使用column_info和psi_selected_cols 追加psi_selected列 如果列还在则选中是1 否则是0 + for col in column_info.keys(): + if col in psi_selected_cols: + column_info[col]['psi_selected'] = 1 + column_info[col]['isExisted'] = True + else: + column_info[col]['psi_selected'] = 0 + column_info[col]['isExisted'] = False + log.info(f"jobid: {job_id} run psi_select_col finish") + elif model_setting.psi_select_col == 0 or model_setting.psi_select_col == "": + log.info(f"jobid: {job_id} don't need run psi_select_col") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['psi_select_col'] is not 0 or in col, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['psi_select_col'] is not 0 or in col") + elif ppc_job_type == utils.AlgorithmType.Predict.name: + log.info( + f"jobid: {job_id} don't need run psi_select_col for predict job.") + else: + log.error( + f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "ppc_job_type is not Train or Predict") + + # 6.2 特征选择 进行 corr_select 计算特征相关性 + if ppc_job_type == utils.AlgorithmType.Train.name: + if model_setting.corr_select > 0: + log.info(f"jobid: {job_id} run corr_select start") + 
corr_select = model_setting.corr_select + df_filled = remove_high_correlation_features( + df_filled, categorical_cols, corr_select) + # 设置相关性筛选的column_info + for col in column_info.keys(): + if col in df_filled.columns.tolist(): + column_info[col]['corr_selected'] = 1 + column_info[col]['isExisted'] = True + else: + column_info[col]['corr_selected'] = 0 + column_info[col]['isExisted'] = False + log.info(f"jobid: {job_id} run corr_select finish") + elif model_setting.corr_select == 0: + log.info(f"jobid: {job_id} don't need run corr_select") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['corr_select'] is not >= 0, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['corr_select'] is not >= 0") + elif ppc_job_type == utils.AlgorithmType.Predict.name: + log.info( + f"jobid: {job_id} don't need run corr_select for predict job.") + else: + log.error( + f"jobid: {job_id} ppc_job_type is not Train or Predict, ppc_job_type:{ppc_job_type}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "ppc_job_type is not Train or Predict") + + # 3. 
离群值处理 3-sigma 法 + if model_setting.filloutlier == 1: + log.info(f"jobid: {job_id} run filloutlier start") + df_filled = process_outliers(df_filled, categorical_cols) + log.info(f"jobid: {job_id} run filloutlier finish") + elif model_setting.filloutlier == 0: + log.info(f"jobid: {job_id} don't need run filloutlier") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['filloutlier'] is not 0 or 1, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['filloutlier'] is not 0 or 1") + + # 5.1 数据标准化 支持max-min normalized + if model_setting.normalized == 1: + log.info(f"jobid: {job_id} run normalized start") + df_filled = normalize_dataframe( + df_filled, categorical_cols, standardType.min_max.value) + log.info(f"jobid: {job_id} run normalized finish") + elif model_setting.normalized == 0: + log.info(f"jobid: {job_id} don't need run normalized") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['normalized'] is not 0 or 1, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['normalized'] is not 0 or 1") + + # 5.2 z-score标准化(规范化) standardized + if model_setting.standardized == 1: + log.info(f"jobid: {job_id} run standardized start") + df_filled = normalize_dataframe( + df_filled, categorical_cols, standardType.z_score.value) + log.info(f"jobid: {job_id} run standardized finish") + elif model_setting.standardized == 0: + log.info(f"jobid: {job_id} don't need run standardized") + else: + log.error( + f"jobid: {job_id} xgb_model_dict['standardized'] is not 0 or 1, model_setting:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "xgb_model_dict['standardized'] is not 0 or 1") + + # 4. 
特征编码,对分类特征进行one-hot编码 这里会多一些列 所以放到最后 + if model_setting.one_hot == 1: + log.info(f"jobid: {job_id} run one_hot start") + df_filled = one_hot_encode_and_merge(df_filled, categorical_cols) + log.info(f"jobid: {job_id} run one_hot finish") + elif model_setting.one_hot == 0: + log.info(f"jobid: {job_id} don't need run one_hot") + else: + log.error( + f"jobid: {job_id} model_setting['one_hot'] is not 0 or 1, xgb_model_dict:{model_setting}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "model_setting['one_hot'] is not 0 or 1") + df_filled.to_csv(xgb_data_file_path, mode='w', + sep=utils.BLANK_SEP, header=True, index=None) + # log.info(f"jobid: {job_id} column_info:{column_info} in type: {ppc_job_type}, process_dataframe succeed") + log.info( + f"jobid: {job_id} in type: {ppc_job_type}, process_dataframe succeed") + return column_info + + +def process_na_dataframe(df: pd.DataFrame, na_select: float): + """ + 缺失值处理 如果小于阈值则移除该列 + + 参数 + - df: 待处理的DataFrame数据 + - na_select: 缺失值占比阈值 + + 返回值: + 处理后的DataFrame数据 + """ + missing_ratios = df.isnull().mean() + + # 剔除缺失值占比大于na_select的列 + selected_cols = missing_ratios[missing_ratios <= na_select].index + + column_info = {col: {'missing_ratio': missing_ratios[col], 'na_selected': 1 if col in selected_cols else 0, } for + col in df.columns} + # 如果没被选择 将isExisted设置为false + for col in column_info.keys(): + if col not in selected_cols: + column_info[col]['isExisted'] = False + else: + column_info[col]['isExisted'] = True + df_processed = df[selected_cols] + return df_processed, column_info + + +def process_na_fill_dataframe(df: pd.DataFrame, categorical_cols: list = None, psi_select_col: str = None): + """ + 处理DataFrame数据: + 1. 计算每一列的缺失值占比; + 2. 剔除缺失值占比大于阈值的列; + 3. 
使用每列的均值填充剩余的缺失值。 + + 参数: + - df: 待处理的DataFrame数据 + - na_select: 缺失值占比阈值 + - categorical_cols: 分类特征列 + + + 返回值: + 处理后的DataFrame数据 + """ + # 计算每一列的缺失值占比 + + # 判断categorical_cols是否在df中,如果是,填充则用col的max+1, 否则用均值插补 + df_processed = df.copy() # Assign the sliced DataFrame to a new variable + for col in df_processed.columns.to_list(): + # 如果col是y,则忽略 + if col == 'y': + continue + elif psi_select_col and col == psi_select_col: + continue + if col in categorical_cols: + df_processed.fillna( + {col: df_processed[col].max() + 1}, inplace=True) + else: + df_processed.fillna({col: df_processed[col].mean()}, inplace=True) + return df_processed + + +def process_outliers(df: pd.DataFrame, categorical_cols: list): + """ + 处理DataFrame数据中的异常值 + 1. 计算每一列的均值和标准差; + 2. 对于超出均值+-3倍标准差的数据 使用均值填充。 + + 参数: + - df: 待处理的DataFrame数据 + - categorical_cols: 列表,包含分类列的名称 + + 返回值: + 处理后的DataFrame数据 + """ + # 计算每一列的均值和标准差 + means = df.mean() + threshold = 3 * df.std() + + # 定义处理异常值的函数 + def replace_outliers(col): + # 如果列是分类列,不处理 + if col.name in categorical_cols: + return col + else: + lower_bound = means[col.name] - threshold[col.name] + upper_bound = means[col.name] + threshold[col.name] + # 如果元素是空值或不是异常值,保持不变;否则使用均值填充 + return np.where(col.notna() & ((col < lower_bound) | (col > upper_bound)), means[col.name], col) + + # 应用处理异常值的函数到DataFrame的每一列 + df_processed = df.apply(replace_outliers) + return df_processed + + +def one_hot_encode_and_merge(df_filled: pd.DataFrame, categorical_features: list): + """ + 对DataFrame中指定的分类特征列进行One-Hot编码 并将编码后的结果合并到DataFrame中。 + + 参数: + - df: 待处理的DataFrame数据 + - categorical_features: 分类特征列表 需要进行One-Hot编码的列名 + + 返回值: + 处理后的DataFrame数据 + """ + categorical_cols_without_y_and_id = categorical_features.copy() + if 'id' in categorical_cols_without_y_and_id: + categorical_cols_without_y_and_id.remove('id') + if 'y' in categorical_cols_without_y_and_id: + categorical_cols_without_y_and_id.remove('y') + # 如果categorical_cols_without_y_and_id不在df_filled的col name中 报错 + if not 
set(categorical_cols_without_y_and_id).issubset(set(df_filled.columns)): + log.error( + f"categorical_cols_without_y_and_id is not in df_filled columns, categorical_cols_without_y_and_id:{categorical_cols_without_y_and_id}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code(), + "categorical_cols_without_y_and_id is not in df_filled columns") + df_merged = pd.get_dummies( + df_filled, columns=categorical_cols_without_y_and_id) + + return df_merged + + +def normalize_dataframe(df: pd.DataFrame, categorical_cols: list, type: standardType = standardType.min_max.value): + """ + 对DataFrame进行标准化,排除指定的分类特征 + + 参数: + - df: 待标准化的DataFrame数据 + - categorical_cols: 列表,包含不需要标准化的分类特征的名称 + + 返回值: + 标准化后的DataFrame数据 + """ + # 创建MinMaxScaler对象 + if type == standardType.min_max.value: + scaler = MinMaxScaler() + elif type == standardType.z_score.value: + scaler = StandardScaler() + else: + log.error(f"unspport type in normalize_dataframe type:{type}") + raise PpcException(PpcErrorCode.XGB_PREPROCESSING_ERROR.get_code( + ), "unspport type in normalize_dataframe type") + + # 获取数值变量的列名 + numeric_cols = df.select_dtypes(include=['int', 'float']).columns + + # 排除指定的分类特征 + numeric_cols = [col for col in numeric_cols if col not in categorical_cols] + + # 对数值变量进行标准化 + df_normalized = df.copy() + df_normalized[numeric_cols] = scaler.fit_transform( + df_normalized[numeric_cols]) + return df_normalized + +# def calculate_correlation(df_i_col_name, df_j_col_name, i, j, col_i, col_j, categorical_cols): +# if df_i_col_name in categorical_cols or df_j_col_name in categorical_cols: +# return None # 返回NaN表示忽略这个计算 +# common_index = col_i.index.intersection(col_j.index) +# return i,j,np.corrcoef(col_i.loc[common_index], col_j.loc[common_index])[0, 1] + + +def remove_high_correlation_features(df: pd.DataFrame, categorical_cols: list, corr_select: float): + """ + 删除DataFrame中相关系数大于corr_select的特征中的一个。 + + 参数: + - df: 待处理的DataFrame数据 + - categorical_cols: 列表,包含分类特征的名称 + - corr_select: 
相关系数阈值 + + 返回值: + 处理后的DataFrame数据 + """ + # ===========原有逻辑============== + # 计算特征之间的相关系数矩阵 + # num_features = df.shape[1] + # correlation_matrix = np.zeros((num_features, num_features)) + + # for i in range(num_features): + # for j in range(i+1, num_features): + # # if i == j: + # # continue + # # 忽略categorical_cols的列 + # if df.columns[i] in categorical_cols or df.columns[j] in categorical_cols: + # continue + # # 当有缺失值时 去除该行对应的缺失值进行比较, 比如col1有缺失值, col2没有, 那么col1的缺失值对应的位置,和col2一起去掉,共同参与计算 + # col_i = df.iloc[:, i].dropna() + # col_j = df.iloc[:, j].dropna() + # common_index = col_i.index.intersection(col_j.index) + # correlation_matrix[i, j] = np.corrcoef(col_i.loc[common_index], col_j.loc[common_index])[0, 1] + # # print(f"correlation_matrix: {correlation_matrix}") + # # correlation_matrix[i, j] = np.corrcoef(df.iloc[:, i], df.iloc[:, j])[0, 1] + + # # 获取相关系数大于corr_select的特征对, 获取列名 + # high_correlation = np.argwhere(np.abs(correlation_matrix) > corr_select) + + # high_correlation_col_name = [] + # for i, j in high_correlation: + # high_correlation_col_name.append((df.columns[i], df.columns[j])) + # # 删除相关性大于corr_select的特征中的一个 + # for col_left, col_right in high_correlation_col_name: + # try: + # # 如果col_left和col_right都在df中, 删除右边, 否则不动 + # if col_left in df.columns and col_right in df.columns: + # df.drop(col_right, axis=1, inplace=True) + # # df.drop(df.columns[j], axis=1, inplace=True) + # except: + # log.warning( + # f"remove_high_correlation_features error, i:{i}, j:{j}, df[col_left]:{df[col_left]}, df[col_right]:{df[col_right]}") + # pass + # return df + # ===========原有逻辑============== + num_features = df.shape[1] + correlation_matrix = np.zeros((num_features, num_features)) + + df_copy = df.copy() + for col_name in categorical_cols: + if col_name in df_copy.columns: + df_copy.drop(col_name, axis=1, inplace=True) + correlation_matrix = df_copy.corr() + # ===========尝试已提交多线程============== + # 提前设置好需要的列 + # col_list = {} + # for i in range(num_features): + # 
def process_psi(df_filled: pd.DataFrame, categorical_cols: list, psi_select_col: str,
                psi_select_base: int, psi_select_thresh: float, psi_select_bins: int):
    """Filter numeric columns by Population Stability Index (PSI) over periods.

    For every numeric column, the PSI is computed between the base period
    (rows where ``psi_select_col == psi_select_base``) and each other period;
    the column is kept only when the worst (maximum) PSI stays below
    ``psi_select_thresh``. Categorical columns are always kept, and the
    period column itself is always dropped.

    Args:
        df_filled: input DataFrame with missing values already filled.
        categorical_cols: names of categorical columns (kept unconditionally).
        psi_select_col: name of the period column used to split the data.
        psi_select_base: value identifying the base period.
        psi_select_thresh: PSI threshold; columns whose max PSI >= it are dropped.
        psi_select_bins: number of buckets for the PSI computation.

    Returns:
        (filtered DataFrame, list of kept column names).
    """
    # TODO(review): empty-value handling inside the PSI formula follows the
    # upstream reference implementation; correctness still to be validated.
    kept_cols = []
    base_mask = df_filled[psi_select_col] == psi_select_base
    for col in df_filled.columns:
        if col == psi_select_col:
            # The period column is never part of the output.
            continue
        if col in categorical_cols:
            # Categorical features bypass the PSI screen.
            kept_cols.append(col)
            continue
        # Numeric column: compare the base period against every other period.
        base_values = df_filled[col][base_mask].values
        worst_psi = 0
        for period in set(df_filled[psi_select_col]):
            if period == psi_select_base:
                continue
            period_values = df_filled[col][df_filled[psi_select_col]
                                           == period].values
            period_psi = calculate_psi(base_values, period_values,
                                       buckettype='quantiles',
                                       buckets=psi_select_bins, axis=1)
            if period_psi > worst_psi:
                worst_psi = period_psi
        if worst_psi < psi_select_thresh:
            kept_cols.append(col)

    df_filled = df_filled[kept_cols]
    log.info(f"process_psi psi_selected_cols:{kept_cols}")
    return df_filled, kept_cols


def union_column_info(column_info1: pd.DataFrame, column_info2: pd.DataFrame):
    """Union two column-info frames on their index.

    Performs an outer index-aligned merge (missing cells become NaN) and
    reorders the rows so that ``column_info1``'s index comes first, followed
    by any index values that only appear in ``column_info2``.

    Args:
        column_info1: first column-info frame (its row order wins).
        column_info2: second column-info frame.

    Returns:
        The merged column-info DataFrame.
    """
    combined = column_info1.merge(
        column_info2, how='outer', left_index=True, right_index=True,
        sort=False)
    ordered = list(column_info1.index)
    seen = set(ordered)
    ordered += [label for label in column_info2.index if label not in seen]
    return combined.reindex(ordered)


# --- ppc_model/preprocessing/local_processing/psi_select.py ---
import numpy as np


def scale_range(input, min, max):
    """Affinely rescale *input* so its values span ``[min, max]``.

    NOTE(review): parameter names shadow builtins and ndarray inputs are
    mutated in place via the augmented assignments — kept for backward
    compatibility with existing callers.
    """
    input += -(np.min(input))
    input /= np.max(input) / (max - min)
    input += min
    return input


def sub_psi(e_perc, a_perc):
    """PSI contribution of one bucket.

    Zero fractions are floored at 1e-4 so the log term stays finite.
    """
    floored_a = a_perc if a_perc != 0 else 0.0001
    floored_e = e_perc if e_perc != 0 else 0.0001
    return (floored_e - floored_a) * np.log(floored_e / floored_a)


def calculate_psi(expected, actual, buckettype='bins', buckets=10, axis=0):
    """Population Stability Index across all variables.

    Args:
        expected: numpy array/matrix of original values.
        actual: numpy array/matrix of new values.
        buckettype: 'bins' for even-width splits, 'quantiles' for quantile splits.
        buckets: number of buckets per variable.
        axis: 0 when variables are columns, 1 when variables are rows.

    Returns:
        A scalar PSI for a single variable, otherwise an ndarray of PSI values.

    Credit: adapted from Matthew Burke (github.com/mwburke).
    """

    def _one_variable(expected_array, actual_array, buckets):
        """PSI between two 1-D samples of a single variable."""
        breakpoints = np.arange(0, buckets + 1) / (buckets) * 100
        if buckettype == 'bins':
            breakpoints = scale_range(
                breakpoints, np.min(expected_array), np.max(expected_array))
        elif buckettype == 'quantiles':
            breakpoints = np.stack(
                [np.percentile(expected_array, b) for b in breakpoints])
        expected_fractions = np.histogram(expected_array, breakpoints)[
            0] / len(expected_array)
        actual_fractions = np.histogram(actual_array, breakpoints)[
            0] / len(actual_array)
        return sum(sub_psi(e_frac, a_frac)
                   for e_frac, a_frac in zip(expected_fractions,
                                             actual_fractions))

    if len(actual.shape) == 1:
        variable_count = 1
    else:
        variable_count = actual.shape[1 - axis]

    if variable_count == 1:
        # Single variable: return a scalar, matching the original contract.
        return _one_variable(expected, actual, buckets)

    psi_values = np.empty(variable_count)
    for i in range(variable_count):
        if axis == 0:
            one_expected = expected if len(
                expected.shape) == 1 else expected[:, i]
            psi_values[i] = _one_variable(one_expected, actual[:, i], buckets)
        elif axis == 1:
            one_expected = expected if len(
                expected.shape) == 1 else expected[i, :]
            psi_values[i] = _one_variable(one_expected, actual[i, :], buckets)

    return psi_values


# --- ppc_model/preprocessing/local_processing/standard_type_enum.py ---
from enum import Enum, unique


@unique
class standardType(Enum):
    # Class name kept lowercase for backward compatibility with existing imports.
    min_max = "min-max"
    z_score = "z-score"
import os

from ppc_model.common.context import Context
from ppc_model.common.initializer import Initializer
from ppc_common.ppc_utils import common_func
from ppc_model.common.model_setting import ModelSetting


class ProcessingContext(Context):
    """Execution context for one preprocessing task.

    Unpacks the raw ``args`` dict handed to the preprocessing engine into
    the named attributes consumed by the local processing party.
    """

    def __init__(self, args, components: Initializer):
        # The base context is keyed by job/task id; preprocessing has no
        # fixed participant role.
        super().__init__(args['job_id'], args['task_id'], components,
                         role=None)
        self.dataset_path = args['dataset_path']
        # Local working copy of the dataset lives under the job workspace.
        self.dataset_file_path = os.path.join(self.workspace,
                                              args['dataset_id'])
        self.job_algorithm_type = args['algorithm_type']
        self.need_run_psi = args['need_run_psi']
        self.model_dict = args['model_dict']
        # Optional id of an earlier training job — presumably used to reuse
        # its column info; TODO(review) confirm helper semantics.
        self.training_job_id = common_func.get_config_value(
            "training_job_id", None, args, False)
        # Only set when a PSI result was supplied; the attribute is absent
        # otherwise, so callers must probe with hasattr/getattr.
        if "psi_result_path" in args:
            self.remote_psi_result_path = args["psi_result_path"]
        self.model_setting = ModelSetting(self.model_dict)
import utils + + +def test_process_na_dataframe(): + # Create a sample DataFrame with missing values + df = pd.DataFrame({ + 'col1': [1, 2, None, 4, 5], + 'col2': [None, 2, None, 4, None], + 'col3': [1, 2, 3, 4, 5], + 'col4': [1, None, 3, 4, 5], + 'col5': [1, 2, None, None, None], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [1, 2, 3.0, 4, 5], + 'col3': [1, 2, 3, 4, 5], + 'col4': [1, 6.0, 3, 4, 5], + 'y': [0, 1, 0, 1, 0] + }) + expected_column_info = {'col1': {'missing_ratio': 0.2, 'na_selected': 1, 'isExisted': True}, + 'col2': {'missing_ratio': 0.6, 'na_selected': 0, 'isExisted': False}, + 'col3': {'missing_ratio': 0.0, 'na_selected': 1, 'isExisted': True}, + 'col4': {'missing_ratio': 0.2, 'na_selected': 1, 'isExisted': True}, + 'col5': {'missing_ratio': 0.6, 'na_selected': 0, 'isExisted': False}, + 'y': {'missing_ratio': 0.0, 'na_selected': 1, 'isExisted': True}} + + # Call the function under test + processed_df, column_info = process_na_dataframe(df, 0.5) + assert column_info == expected_column_info + processed_df = process_na_fill_dataframe(processed_df, ['col2', 'col4']) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + +def test_process_nan_dataframe(): + # Create a sample DataFrame with missing values + df = pd.DataFrame({ + 'col1': [1, 2, None, 4, 5], + 'col2': [None, 2, None, 4, None], + 'col3': [1, 2, 3, 4, 5], + 'col4': [1, None, 3, 4, 5], + 'col5': [1, 2, None, None, None], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [1, 2, np.nan, 4, 5], + 'col2': [np.nan, 2, np.nan, 4, np.nan], + 'col3': [1, 2, 3, 4, 5], + 'col4': [1, np.nan, 3, 4, 5], + 'col5': [1, 2, np.nan, np.nan, np.nan], + 'y': [0, 1, 0, 1, 0] + }) + + # Call the function under test + processed_df = df.replace('None', np.nan) + + # Assert that the processed DataFrame matches the 
expected output + assert processed_df.equals(expected_output) + + +def test_process_outliers(): + # Create a sample DataFrame with outliers + df = pd.DataFrame({ + 'col3': [100, 200, None, 400, 500, 100, 200, 300, 400, 500, 100, 200, None, 400, 500], + 'col5': [100000, None, 1, -1, 1, -1, 1, -1, 1, 1, -1, None, None, None, 1], + }) + # + # outlier = np.random.normal(df['col5'].mean() + 3 * df['col5'].std(), df['col5'].std()) + # df.at[1, 'col5'] = outlier + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col3': [100, 200, None, 400, 500, 100, 200, 300, 400, 500, 100, 200, None, 400, 500], + 'col5': [df['col5'].mean(), None, 1, -1, 1, -1, 1, -1, 1, 1, -1, None, None, None, 1], + }) + + # Call the function under test + processed_df = process_outliers(df, []) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + +def test_one_hot_encode_and_merge(): + # Create a sample DataFrame + df = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'y': [0, 1, 0, 1, 0], + 'col2_A': [True, False, False, True, False], + 'col2_B': [False, True, False, False, True], + 'col2_C': [False, False, True, False, False], + 'col3_X': [True, False, False, True, False], + 'col3_Y': [False, True, False, False, True], + 'col3_Z': [False, False, True, False, False] + }) + + # Call the function under test + processed_df = one_hot_encode_and_merge(df, ['col2', 'col3']) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + +def test_normalize_dataframe(): + # Create a sample DataFrame + df = pd.DataFrame({ + 'col1': [1, 2, None, 4, 5], + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'col4': [10, 20, 30, 40, 50], + 
'col5': [100, 200, None, 400, 500] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [0.0, 0.25, None, 0.75, 1.0], + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'col4': [0.0, 0.25, 0.5, 0.75, 1.0], + 'col5': [0.0, 0.25, None, 0.75, 1.0] + }) + + # Call the function under test + processed_df = normalize_dataframe( + df, ['col2', 'col3'], standardType.min_max.value) + assert processed_df.equals(expected_output) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [-1.265, -0.632, None, 0.632, 1.265], + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'col4': [-1.414, -0.707, 0.0, 0.707, 1.414], + 'col5': [-1.265, -0.632, None, 0.632, 1.265] + }) + # Call the function under test + processed_df = normalize_dataframe( + df, ['col2', 'col3'], standardType.z_score.value) + # Assert that the processed DataFrame matches the expected output + print(expected_output) + print(processed_df) + assert processed_df.round(3).equals(expected_output) + + +def test_remove_high_correlation_features(): + # Create a sample DataFrame + df = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': [2, 4, 6, 8, 10], + 'col3': [3, 6, 9, 12, 15], + 'col4': [4, 8, 12, 16, 20], + 'col5': [5, 10, 15, 20, 25], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'y': [0, 1, 0, 1, 0] + }) + + # Call the function under test + processed_df = remove_high_correlation_features(df, ['y'], 0.8) + + df = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': [2, 4, None, 8, 10], + 'col3': [3, 6, 9, 12, 15], + 'col4': [4, None, 12, 16, 20], + 'col5': [5, 10, None, 20, 25], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'y': [0, 1, 0, 1, 0] + }) + processed_df = remove_high_correlation_features(df, ['y'], 0.8) + 
print(f"processed_df:{processed_df}") + print(f"expected_output:{expected_output}") + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + df = pd.DataFrame({ + 'col1': [-1, 222, 3.4, -22, 5.1], + 'col2': [2, None, 6, 8, 10], + 'col3': [3, 6, 9, 12, 15], + 'col4': [4, None, 12, 16, 20], + 'col5': [5, 10, None, 20, 25], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col1': [-1, 222, 3.4, -22, 5.1], + 'col2': [2, None, 6, 8, 10], + 'y': [0, 1, 0, 1, 0] + }) + processed_df = remove_high_correlation_features(df, ['y'], 0.8) + + print(processed_df) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + +def test_process_psi(): + # Create a sample DataFrame + df_filled = pd.DataFrame({ + # 'col1': [0, 1, 0, 1, 0], + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'col4': [10, 20, 30, 40, 50], + 'col5': [100, 200, 300, 400, 500], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'col2': ['A', 'B', 'C', 'A', 'B'], + 'col3': ['X', 'Y', 'Z', 'X', 'Y'], + 'y': [0, 1, 0, 1, 0] + }) + + # Call the function under test + processed_df, _ = process_psi( + df_filled, ['col2', 'col3', 'y'], 'col1', 1, 0.9, 5) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + # read test csv as DataFrame + test_file_path = "./癌症纵向训练数据集psi.csv" + df_filled = pd.read_csv(test_file_path) + expected_output_col = ['id', 'y', 'x1', 'x5', + 'x6', 'x9', 'x10', 'x11', 'x12', 'x14'] + processed_df, _ = process_psi( + df_filled, ['id', 'y', 'x1', 'x5', 'x23'], 'x15', 0, 0.3, 4) + + assert processed_df.columns.tolist() == expected_output_col + + +def test_process_dataframe(): + # Create a sample DataFrame + test_file_path = "./癌症纵向训练数据集psi.csv" + 
df_filled = pd.read_csv(test_file_path) + + # Create a mock JobContext object + xgb_model = { + "algorithm_subtype": "HeteroXGB", + "participants": 2, # 提供数据集机构数 + # XGBoost参数 + "use_psi": 0, # 0否,1是 + "use_goss": 0, # 0否,1是 + "test_dataset_percentage": 0.3, + "learning_rate": 0.1, + "num_trees": 6, + "max_depth": 3, + "max_bin": 4, # 分箱数(计算XGB直方图) + "threads": 8, + # 调度服务各方预处理 + "na_select": 0.8, # 缺失值筛选阈值,缺失值比例超过阈值则移除该特征(0表示只要有缺失值就移除,1表示仅移除全为缺失值的列) + "fillna": 0, # 是否缺失值填充(均值) + "psi_select_col": "x15", # PSI稳定性筛选时间列名称(0代表不进行PSI筛选,例如:"month") + "psi_select_base": 1, # PSI稳定性筛选时间基期(例如以0为基期,统计其他时间的psi,时间按照周/月份/季度等提前处理为0,1,2,...,n格式) + "psi_select_thresh": 0.3, # 若最大的逐月PSI>0.1,则剔除该特征 + "psi_select_bins": 4, # 计算逐月PSI时分箱数 + "filloutlier": 1, # 是否异常值处理(+-3倍标准差使用均值填充) + "normalized": 1, # 是否归一化,每个值减去最小值,然后除以最大值与最小值的差 + "standardized": 1, # 是否标准化,计算每个数据点与均值的差,然后除以标准差 + "one_hot": 0, # 是否进行onehot + # 指定分类特征索引,例如"x1,x12,x23" ("0"代表无分类特征)(分类特征需要在此处标注,使用onehot预处理/建模) + "categorical": "id,y,x1,x5,x23", + # 建模节点特征工程 + "use_iv": 0, # 是否计算woe/iv,并使用iv进行特征筛选 + 'group_num': 3, # 分箱数(计算woe分箱,等频) + 'iv_thresh': 0.1, # 使用iv进行特征筛选的阈值(仅保留iv大于阈值的特征) + 'corr_select': 0.8 # 计算特征相关性,相关性大于阈值特征仅保留iv最大的(如use_iv=0,则随机保留一个)(corr_select=0时不进行相关性筛选) + } + + xgb_dict = dict() + # 根据xgb_model 生成xgb_dict + for key, value in xgb_model.items(): + xgb_dict[key] = value + + # Call the function under test + column_info1 = process_dataframe( + df_filled, xgb_dict, "./xgb_data_file_path", utils.AlgorithmType.Train.name, "j-123456") + + xgb_model = { + "algorithm_subtype": "HeteroXGB", + "participants": 2, # 提供数据集机构数 + # XGBoost参数 + "use_psi": 0, # 0否,1是 + "use_goss": 0, # 0否,1是 + "test_dataset_percentage": 0.3, + "learning_rate": 0.1, + "num_trees": 6, + "max_depth": 3, + "max_bin": 4, # 分箱数(计算XGB直方图) + "threads": 8, + # 调度服务各方预处理 + "na_select": 0.8, # 缺失值筛选阈值,缺失值比例超过阈值则移除该特征(0表示只要有缺失值就移除,1表示仅移除全为缺失值的列) + "fillna": 1, # 是否缺失值填充(均值) + "psi_select_col": "x15", # PSI稳定性筛选时间列名称(0代表不进行PSI筛选,例如:"month") + 
"psi_select_base": 1, # PSI稳定性筛选时间基期(例如以0为基期,统计其他时间的psi,时间按照周/月份/季度等提前处理为0,1,2,...,n格式) + "psi_select_thresh": 0.3, # 若最大的逐月PSI>0.1,则剔除该特征 + "psi_select_bins": 4, # 计算逐月PSI时分箱数 + "filloutlier": 1, # 是否异常值处理(+-3倍标准差使用均值填充) + "normalized": 1, # 是否归一化,每个值减去最小值,然后除以最大值与最小值的差 + "standardized": 1, # 是否标准化,计算每个数据点与均值的差,然后除以标准差 + "one_hot": 0, # 是否进行onehot + # 指定分类特征索引,例如"x1,x12,x23" ("0"代表无分类特征)(分类特征需要在此处标注,使用onehot预处理/建模) + "categorical": "id,y,x1,x5,x23", + # 建模节点特征工程 + "use_iv": 0, # 是否计算woe/iv,并使用iv进行特征筛选 + 'group_num': 3, # 分箱数(计算woe分箱,等频) + 'iv_thresh': 0.1, # 使用iv进行特征筛选的阈值(仅保留iv大于阈值的特征) + 'corr_select': 0.8 # 计算特征相关性,相关性大于阈值特征仅保留iv最大的(如use_iv=0,则随机保留一个)(corr_select=0时不进行相关性筛选) + } + for key, value in xgb_model.items(): + xgb_dict[key] = value + column_info2 = process_dataframe( + df_filled, xgb_dict, "./xgb_data_file_path2", utils.AlgorithmType.Train.name, "j-123456") + assert column_info1 == column_info2 + expected_column_info = {'y': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x0': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x1': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x2': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 0, + 'isExisted': False + }, + 'x3': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 0, + 'isExisted': False + }, + 'x4': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 0, + 'corr_selected': 0, + 'isExisted': False + }, + 'x5': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x6': {'missing_ratio': 0.018, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x7': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 0, + 'corr_selected': 0, + 
'isExisted': False + }, + 'x8': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x9': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x10': {'missing_ratio': 0.012, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x11': {'missing_ratio': 0.998, + 'na_selected': 0, + 'psi_selected': 0, + 'corr_selected': 0, + 'isExisted': False + }, + 'x12': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 0, + 'isExisted': False + }, + 'x13': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 0, + 'isExisted': False + }, + 'x14': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 1, + 'corr_selected': 1, + 'isExisted': True + }, + 'x15': {'missing_ratio': 0.0, + 'na_selected': 1, + 'psi_selected': 0, + 'corr_selected': 0, + 'isExisted': False + } + } + # 转成python字典 + expected_column_info_dict = dict(expected_column_info) + assert column_info1 == expected_column_info_dict + + +def test_process_train_dataframe(): + # Create a sample DataFrame + df = pd.DataFrame({ + 'id': [1, 2, 3, 4, 5], + 'x1': [10, 20, 30, 40, 50], + 'x2': [100, 200, 300, 400, 500], + 'x3': [1000, 2000, 3000, 4000, 5000], + 'x4': [10000, 20000, 30000, 40000, 50000], + 'x5': [100000, 200000, 300000, 400000, 500000], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'id': [1, 2, 3, 4, 5], + 'x2': [100, 200, 300, 400, 500], + 'x3': [1000, 2000, 3000, 4000, 5000] + }) + + # Define the column_info dictionary + column_info = { + 'id': {'isExisted': True}, + 'x1': {'isExisted': False}, + 'x2': {'isExisted': True}, + 'x3': {'isExisted': True}, + 'x4': {'isExisted': False}, + 'x5': {'isExisted': False}, + 'y': {'isExisted': False}, + } + + # Call the function under test + processed_df = process_train_dataframe(df, 
column_info) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + + +def test_process_train_dataframe_with_additional_columns(): + # Create a sample DataFrame + df = pd.DataFrame({ + 'id': [1, 2, 3, 4, 5], + 'x1': [10, 20, 30, 40, 50], + 'x2': [100, 200, 300, 400, 500], + 'x3': [1000, 2000, 3000, 4000, 5000], + 'x4': [10000, 20000, 30000, 40000, 50000], + 'x5': [100000, 200000, 300000, 400000, 500000], + 'y': [0, 1, 0, 1, 0] + }) + + # Define the expected output DataFrame + expected_output = pd.DataFrame({ + 'id': [1, 2, 3, 4, 5], + 'x1': [10, 20, 30, 40, 50], + 'x3': [1000, 2000, 3000, 4000, 5000], + 'x4': [10000, 20000, 30000, 40000, 50000], + 'x5': [100000, 200000, 300000, 400000, 500000] + }) + + # Define the column_info dictionary + column_info = { + 'id': {'isExisted': True}, + 'x1': {'isExisted': True}, + 'x2': {'isExisted': False}, + 'x3': {'isExisted': True}, + 'x4': {'isExisted': True}, + 'x5': {'isExisted': True}, + 'y': {'isExisted': False}, + } + + # Call the function under test + processed_df = process_train_dataframe(df, column_info) + + # Assert that the processed DataFrame matches the expected output + assert processed_df.equals(expected_output) + +def test_merge_column_info_from_file(): + col_info_file_path = "./test_column_info_merge.csv" + iv_info_file_path = "./test_column_info_iv.csv" + column_info_fm = pd.read_csv(col_info_file_path, index_col=0) + column_info_iv = pd.read_csv(iv_info_file_path, index_col=0) + union_df = union_column_info(column_info_fm, column_info_iv) + + col_str_expected = '{"y": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x0": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x1": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x2": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, 
"psi_selected": 1, "corr_selected": 0}, "x3": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 1, "corr_selected": 0}, "x4": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 0, "corr_selected": 0}, "x5": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x6": {"missing_ratio": 0.018, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x7": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 0, "corr_selected": 0}, "x8": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x9": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x10": {"missing_ratio": 0.012, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x11": {"missing_ratio": 0.998, "na_selected": 0, "isExisted": false, "psi_selected": 0, "corr_selected": 0}, "x12": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 1, "corr_selected": 0}, "x13": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 1, "corr_selected": 0}, "x14": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": true, "psi_selected": 1, "corr_selected": 1}, "x15": {"missing_ratio": 0.0, "na_selected": 1, "isExisted": false, "psi_selected": 0, "corr_selected": 0}}' + + # expected_df_file_path = './test_union_column.csv' + # expected_df = pd.read_csv(expected_df_file_path, index_col=0) + # assert expected_df.equals(union_df) + column_info_str = json.dumps(column_info_fm.to_dict(orient='index')) + assert column_info_str == col_str_expected + +def construct_dataset(num_samples, num_features, file_path): + np.random.seed(0) + # 生成标签列 + labels = np.random.choice([0, 1], size=num_samples) + # 生成特征列 + features = np.random.rand(num_samples, num_features) + # 将标签转换为DataFrame + labels_df = pd.DataFrame(labels, 
columns=['Label']) + + # 将特征转换为DataFrame + features_df = pd.DataFrame(features) + + # 合并标签和特征DataFrame + dataset_df = pd.concat([labels_df, features_df], axis=1) + + # 将DataFrame写入CSV文件 + dataset_df.to_csv(file_path, index=False) + + return labels, features + +def test_gen_file(): + num_samples = 400000 + num_features = 100 + file_path = "./dataset-{}-{}.csv".format(num_samples, num_features) + construct_dataset(num_samples, num_features, file_path) + +def test_large_process_train_dataframe(): + num_samples = 400000 + num_features = 100 + test_file_path = "./dataset-{}-{}.csv".format(num_samples, num_features) + df_filled = pd.read_csv(test_file_path) + + # Create a mock JobContext object + xgb_model = { + "algorithm_subtype": "HeteroXGB", + "participants": 2, # 提供数据集机构数 + # XGBoost参数 + "use_psi": 0, # 0否,1是 + "use_goss": 0, # 0否,1是 + "test_dataset_percentage": 0.3, + "learning_rate": 0.1, + "num_trees": 6, + "max_depth": 3, + "max_bin": 4, # 分箱数(计算XGB直方图) + "threads": 8, + # 调度服务各方预处理 + "na_select": 0.8, # 缺失值筛选阈值,缺失值比例超过阈值则移除该特征(0表示只要有缺失值就移除,1表示仅移除全为缺失值的列) + "fillna": 0, # 是否缺失值填充(均值) + "psi_select_col": 0, # PSI稳定性筛选时间列名称(0代表不进行PSI筛选,例如:"month") + "psi_select_base": 1, # PSI稳定性筛选时间基期(例如以0为基期,统计其他时间的psi,时间按照周/月份/季度等提前处理为0,1,2,...,n格式) + "psi_select_thresh": 0.3, # 若最大的逐月PSI>0.1,则剔除该特征 + "psi_select_bins": 4, # 计算逐月PSI时分箱数 + "filloutlier": 1, # 是否异常值处理(+-3倍标准差使用均值填充) + "normalized": 1, # 是否归一化,每个值减去最小值,然后除以最大值与最小值的差 + "standardized": 1, # 是否标准化,计算每个数据点与均值的差,然后除以标准差 + "one_hot": 0, # 是否进行onehot + # 指定分类特征索引,例如"x1,x12,x23" ("0"代表无分类特征)(分类特征需要在此处标注,使用onehot预处理/建模) + "categorical": "id,y,1,5,23", + # 建模节点特征工程 + "use_iv": 0, # 是否计算woe/iv,并使用iv进行特征筛选 + 'group_num': 3, # 分箱数(计算woe分箱,等频) + 'iv_thresh': 0.1, # 使用iv进行特征筛选的阈值(仅保留iv大于阈值的特征) + 'corr_select': 0.8 # 计算特征相关性,相关性大于阈值特征仅保留iv最大的(如use_iv=0,则随机保留一个)(corr_select=0时不进行相关性筛选) + } + + xgb_dict = dict() + # 根据xgb_model 生成xgb_dict + for key, value in xgb_model.items(): + xgb_dict[key] = value + + # Call the function under 
test + start_time = time.time() + column_info1 = process_dataframe( + df_filled, xgb_dict, "./xgb_data_file_path", utils.AlgorithmType.Train.name, "j-123456") + end_time = time.time() + print(f"test_large_process_train_dataframe time cost:{end_time-start_time}, num_samples: {num_samples}, num_features: {num_features}") + + + + + +# Run the tests +# pytest.main() +if __name__=="__main__": + import time + # test_large_process_train_dataframe() + time1 = time.time() + test_process_na_dataframe() + time2 = time.time() + test_process_nan_dataframe() + time3 = time.time() + test_process_outliers() + time4 = time.time() + test_one_hot_encode_and_merge() + time5 = time.time() + test_normalize_dataframe() + time6 = time.time() + test_remove_high_correlation_features() + time7 = time.time() + test_process_psi() + time8 = time.time() + test_process_dataframe() + time9 = time.time() + test_process_train_dataframe() + time10 = time.time() + test_process_train_dataframe_with_additional_columns() + time11 = time.time() + test_merge_column_info_from_file() + time12 = time.time() + print(f"test_process_na_dataframe time cost: {time2-time1}") + print(f"test_process_nan_dataframe time cost: {time3-time2}") + print(f"test_process_outliers time cost: {time4-time3}") + print(f"test_one_hot_encode_and_merge time cost: {time5-time4}") + print(f"test_normalize_dataframe time cosy: {time6-time5}") + print(f"test_remove_high_correlation_features time cost: {time7-time6}") + print(f"test_process_psi time cost: {time8-time7}") + print(f"test_process_dataframe time cost: {time9-time8}") + print(f"test_process_train_dataframe time cost: {time10-time9}") + print(f"test_process_train_dataframe_with_additional_columns time cost: {time11-time10}") + print(f"test_merge_column_info_from_file time cost: {time12-time11}") + print("All tests pass!") diff --git a/python/ppc_model/secure_lgbm/__init__.py b/python/ppc_model/secure_lgbm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/python/ppc_model/secure_lgbm/monitor/__init__.py b/python/ppc_model/secure_lgbm/monitor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/secure_lgbm/monitor/callback.py b/python/ppc_model/secure_lgbm/monitor/callback.py new file mode 100644 index 00000000..fa1247de --- /dev/null +++ b/python/ppc_model/secure_lgbm/monitor/callback.py @@ -0,0 +1,83 @@ +from abc import ABC +from typing import ( + Callable, + Optional, + Sequence +) + +import numpy as np + +from ppc_model.secure_lgbm.monitor.core import _Model + + +class TrainingCallback(ABC): + def __init__(self) -> None: + pass + + def before_training(self, model: _Model) -> _Model: + """Run before training starts.""" + return model + + def after_training(self, model: _Model) -> _Model: + """Run after training is finished.""" + return model + + def before_iteration(self, model: _Model, epoch: int) -> bool: + """Run before each iteration. Returns True when training should stop.""" + return False + + def after_iteration(self, model: _Model, epoch: int) -> bool: + """Run after each iteration. Returns `True` when training should stop.""" + return False + + +class CallbackContainer: + """A special internal callback for invoking a list of other callbacks.""" + + def __init__( + self, + callbacks: Sequence[TrainingCallback], + feval: Optional[Callable] = None + ) -> None: + self.callbacks = set(callbacks) + for cb in callbacks: + if not isinstance(cb, TrainingCallback): + raise TypeError("callback must be an instance of `TrainingCallback`.") + + msg = ( + "feval must be callable object for monitoring. For builtin metrics" + ", passing them in training parameter invokes monitor automatically." 
+ ) + if feval is not None and not callable(feval): + raise TypeError(msg) + + self.feval = feval + + def before_training(self, model: _Model) -> _Model: + for c in self.callbacks: + model = c.before_training(model=model) + return model + + def after_training(self, model: _Model) -> _Model: + for c in self.callbacks: + model = c.after_training(model=model) + return model + + def before_iteration( + self, + model: _Model + ) -> bool: + return any( + c.before_iteration(model, model.get_epoch()) for c in self.callbacks + ) + + def after_iteration( + self, + model: _Model, + pred: np.ndarray, + eval_on_test: bool = True + ) -> bool: + model.after_iteration(pred, eval_on_test) + model.eval(self.feval) + ret = any(c.after_iteration(model, model.get_epoch()) for c in self.callbacks) + return ret diff --git a/python/ppc_model/secure_lgbm/monitor/core.py b/python/ppc_model/secure_lgbm/monitor/core.py new file mode 100644 index 00000000..6e080a04 --- /dev/null +++ b/python/ppc_model/secure_lgbm/monitor/core.py @@ -0,0 +1,144 @@ +import collections +from typing import Dict, Any, Union, Tuple, List, Callable + +import numpy as np +from sklearn import metrics + +_Score = Union[float, Tuple[float, float]] +_ScoreList = Union[List[float], List[Tuple[float, float]]] + +_BoosterParams = Dict[str, Any] +_Metric = Callable[[np.ndarray, np.ndarray], Dict[str, _Score]] + +""" +A dictionary containing the evaluation history: +{"metric_name": [0.5, ...]} +""" +_EvalsLog = Dict[str, _ScoreList] + + +class Booster: + """A Booster of XGBoost. + + Booster is the model of xgboost, that contains low level routines for + training, prediction and evaluation. 
+ """ + + def __init__( + self, + y_true: np.ndarray, + test_y_true: np.ndarray, + workspace: str = None, + job_id: str = None, + storage_client: str = None + ) -> None: + self.params: _BoosterParams = {} + self.y_true = y_true + self.test_y_true = test_y_true + self.y_pred = None + self.test_y_pred = None + self.eval_on_test = True + self.epoch = 0 + self.workspace = workspace + self.job_id = job_id + self.storage_client = storage_client + self.history: _EvalsLog = collections.OrderedDict() + + def get_y_true(self) -> np.ndarray: + return self.y_true + + def get_test_y_true(self) -> np.ndarray: + return self.test_y_true + + def get_y_pred(self) -> np.ndarray: + return self.y_pred + + def get_test_y_pred(self) -> np.ndarray: + return self.test_y_pred + + def get_epoch(self) -> int: + return self.epoch + + def get_workspace(self) -> str: + return self.workspace + + def get_job_id(self) -> str: + return self.job_id + + def get_storage_client(self): + return self.storage_client + + def set_param( + self, + key: str, + value: Any, + ) -> None: + self.params[key] = value + + def get_param( + self, + key: str + ) -> Any: + return self.params[key] + + def get_history(self) -> _EvalsLog: + return self.history + + def after_iteration( + self, + pred: np.ndarray, + eval_on_test: bool = True + ) -> None: + if eval_on_test: + self.test_y_pred = pred + else: + self.y_pred = pred + self.eval_on_test = eval_on_test + self.epoch += 1 + + def _update_history( + self, + scores: Dict[str, _Score] + ) -> None: + for key, value in scores.items(): + if key in self.history: + self.history[key].append(value) + else: + self.history[key] = [value] + + def eval( + self, + feval: _Metric + ) -> Dict[str, _Score]: + if self.eval_on_test: + scores = feval(self.test_y_true, self.test_y_pred) + else: + scores = feval(self.y_true, self.y_pred) + self._update_history(scores) + return scores + + +_Model = Booster + + +def fevaluation( + y_true: np.ndarray, + y_pred: np.ndarray, + decimal_num: int = 
4 +) -> Dict[str, _Score]: + auc = metrics.roc_auc_score(y_true, y_pred) + + y_pred_label = [0 if p <= 0.5 else 1 for p in y_pred] + acc = metrics.accuracy_score(y_true, y_pred_label) + recall = metrics.recall_score(y_true, y_pred_label) + precision = metrics.precision_score(y_true, y_pred_label) + + scores_dict = { + 'auc': auc, + 'acc': acc, + 'recall': recall, + 'precision': precision + } + for metric_name in scores_dict: + scores_dict[metric_name] = round(scores_dict[metric_name], decimal_num) + return scores_dict diff --git a/python/ppc_model/secure_lgbm/monitor/early_stopping.py b/python/ppc_model/secure_lgbm/monitor/early_stopping.py new file mode 100644 index 00000000..de718ee3 --- /dev/null +++ b/python/ppc_model/secure_lgbm/monitor/early_stopping.py @@ -0,0 +1,122 @@ +from typing import Optional, cast + +import numpy + +from ppc_model.secure_lgbm.monitor.callback import TrainingCallback +from ppc_model.secure_lgbm.monitor.core import _Score, _ScoreList, _Model, _EvalsLog + + +class EarlyStopping(TrainingCallback): + """Callback function for early stopping + Parameters + ---------- + rounds : + Early stopping rounds. + metric_name : + Name of metric that is used for early stopping. + maximize : + Whether to maximize evaluation metric. None means auto (discouraged). + min_delta : + Minimum absolute change in score to be qualified as an improvement. 
+ """ + + def __init__( + self, + rounds: int, + metric_name: str, + maximize: Optional[bool] = None, + save_best: Optional[bool] = True, + min_delta: float = 0.0, + ) -> None: + self.metric_name = metric_name + assert self.metric_name in ['auc', 'acc', 'recall', 'precision'] + self.rounds = rounds + self.maximize = maximize + self.save_best = save_best + self._min_delta = min_delta + if self._min_delta < 0: + raise ValueError("min_delta must be greater or equal to 0.") + self.stopping_history: _EvalsLog = {} + self.current_rounds: int = 0 + self.best_scores: dict = {} + super().__init__() + + def before_training(self, model: _Model) -> _Model: + return model + + def _update_rounds( + self, score: _Score, metric_name: str, model: _Model, epoch: int + ) -> bool: + def get_s(value: _Score) -> float: + """get score if it's cross validation history.""" + return value[0] if isinstance(value, tuple) else value + + def maximize(new: _Score, best: _Score) -> bool: + """New score should be greater than the old one.""" + return numpy.greater(get_s(new) - self._min_delta, get_s(best)) + + def minimize(new: _Score, best: _Score) -> bool: + """New score should be lesser than the old one.""" + return numpy.greater(get_s(best) - self._min_delta, get_s(new)) + + if self.maximize is None: + maximize_metrics = ( + "auc", + "aucpr", + "pre", + "pre@", + "map", + "ndcg", + "auc@", + "aucpr@", + "map@", + "ndcg@", + ) + if metric_name != "mape" and any(metric_name.startswith(x) for x in maximize_metrics): + self.maximize = True + else: + self.maximize = False + + if self.maximize: + improve_op = maximize + else: + improve_op = minimize + + if not self.stopping_history: # First round + self.current_rounds = 0 + self.stopping_history[metric_name] = cast(_ScoreList, [score]) + self.best_scores[metric_name] = cast(_ScoreList, [score]) + if self.save_best: + model.set_param('best_score', score) + model.set_param('best_iteration', epoch) + elif not improve_op(score, 
self.best_scores[metric_name][-1]): + # Not improved + self.stopping_history[metric_name].append(score) # type: ignore + self.current_rounds += 1 + else: # Improved + self.stopping_history[metric_name].append(score) # type: ignore + self.best_scores[metric_name].append(score) + self.current_rounds = 0 # reset + if self.save_best: + model.set_param('best_score', score) + model.set_param('best_iteration', epoch) + + if self.current_rounds >= self.rounds: + # Should stop + return True + return False + + def after_iteration( + self, model: _Model, epoch: int + ) -> bool: + history = model.get_history() + if len(history.keys()) < 1: + raise ValueError("Must have at least 1 validation dataset for early stopping.") + + metric_name = self.metric_name + # The latest score + score = history[metric_name][-1] + return self._update_rounds(score, metric_name, model, epoch) + + def after_training(self, model: _Model) -> _Model: + return model diff --git a/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py b/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py new file mode 100644 index 00000000..c1038bbb --- /dev/null +++ b/python/ppc_model/secure_lgbm/monitor/evaluation_monitor.py @@ -0,0 +1,123 @@ +import os +import time +import random +import traceback +from typing import Optional + +import matplotlib.pyplot as plt + +from ppc_common.ppc_utils.utils import METRICS_OVER_ITERATION_FILE +from ppc_model.common.global_context import plot_lock +from ppc_model.secure_lgbm.monitor.callback import TrainingCallback +from ppc_model.secure_lgbm.monitor.core import _Model + + +def _draw_figure(model: _Model): + scores = model.get_history() + path = model.get_workspace() + + iterations = [i + 1 for i in range(len(next(iter(scores.values()))))] + + # plt.cla() + plt.figure(figsize=(int(10 + len(iterations) / 5), 10)) + + for metric, values in scores.items(): + plt.plot(iterations, values, label=metric) + max_index = values.index(max(values)) + plt.scatter(max_index + 1, 
import os
import time
import random
import traceback
from typing import Optional

import matplotlib.pyplot as plt

from ppc_common.ppc_utils.utils import METRICS_OVER_ITERATION_FILE
from ppc_model.common.global_context import plot_lock
from ppc_model.secure_lgbm.monitor.callback import TrainingCallback
from ppc_model.secure_lgbm.monitor.core import _Model


def _draw_figure(model: _Model):
    """Plot every metric history as one curve and save the SVG to the workspace.

    The highest value of each curve is highlighted; note this marks the *max*
    even for metrics where lower is better.
    """
    scores = model.get_history()
    path = model.get_workspace()

    iterations = [i + 1 for i in range(len(next(iter(scores.values()))))]

    plt.figure(figsize=(int(10 + len(iterations) / 5), 10))

    for metric, values in scores.items():
        plt.plot(iterations, values, label=metric)
        max_index = values.index(max(values))
        plt.scatter(max_index + 1, values[max_index], color='green')
        plt.text(max_index + 1, values[max_index], f'{values[max_index]:.4f}', fontsize=9, ha='right')

    plt.legend()
    plt.title('Metrics Over Iterations')
    plt.xlabel('Iteration')
    plt.ylabel('Metric Value')
    plt.grid(True)
    # Past ~60 iterations, label only every 5th tick to keep the axis readable.
    if len(iterations) <= 60:
        plt.xticks(iterations, fontsize=10, rotation=45)
    else:
        plt.xticks(range(0, len(iterations), 5), fontsize=10, rotation=45)
    plt.yticks(fontsize=12)

    file_path = os.path.join(path, METRICS_OVER_ITERATION_FILE)
    plt.savefig(file_path, format='svg', dpi=300)
    plt.close('all')


def _upload_figure(model: _Model):
    """Push the saved metrics figure to remote storage, if a client is set."""
    storage_client = model.get_storage_client()
    if storage_client is not None:
        path = model.get_workspace()
        job_id = model.get_job_id()
        metrics_file_path = os.path.join(path, METRICS_OVER_ITERATION_FILE)
        unique_file_path = os.path.join(job_id, METRICS_OVER_ITERATION_FILE)
        storage_client.upload_file(metrics_file_path, unique_file_path)


def _fmt_metric(
    metric_name: str, score: float
) -> str:
    """Format one metric as a tab-prefixed 'name:score' fragment."""
    msg = f"\t{metric_name}:{score:.5f}"
    return msg


class EvaluationMonitor(TrainingCallback):
    """Print the evaluation result after each period iteration.
    Parameters
    ----------
    period :
        How many epoches between printing.
    """

    def __init__(self, logger, period: int = 1) -> None:
        self.logger = logger
        self.period = period
        assert period > 0
        # last skipped message, useful when early stopping and period are used together.
        self._latest: Optional[str] = None
        super().__init__()

    def after_iteration(
        self, model: _Model, epoch: int
    ) -> bool:
        """Log the latest score of every metric every ``period`` epochs."""
        history = model.get_history()
        if not history:
            return False

        msg: str = f"[{model.get_job_id()}, epoch(iter): {epoch}]"
        for metric_name, scores in history.items():
            # Cross-validation histories store (mean, std) tuples; log the mean.
            if isinstance(scores[-1], tuple):
                score = scores[-1][0]
            else:
                score = scores[-1]
            msg += _fmt_metric(metric_name, score)
        msg += "\n"

        if (epoch % self.period) == 0 or self.period == 1:
            self.logger.info(msg)
            self._latest = None
        else:
            # There is skipped message
            self._latest = msg

        # Monitoring never requests an early stop.
        return False

    def after_training(self, model: _Model) -> _Model:
        """Flush any skipped message, then draw and upload the metrics figure."""
        if self._latest is not None:
            self.logger.info(self._latest)
        max_retry = 3
        for retry_num in range(1, max_retry + 1):
            try:
                # matplotlib's pyplot state machine is not thread-safe,
                # hence the shared lock.
                with plot_lock:
                    _draw_figure(model)
                # Fix: stop retrying once the figure was drawn; the original
                # loop had no success break.
                break
            # Fix: was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt.
            except Exception:
                self.logger.info(f'scores = {model.get_history()}')
                self.logger.info(f'path = {model.get_workspace()}')
                err = traceback.format_exc()
                self.logger.info(
                    f'plot moniter in times-{retry_num} failed, traceback: {err}.')
                time.sleep(random.uniform(0.1, 3))
        _upload_figure(model)
        return model
# -*- coding: utf-8 -*-
from enum import Enum
import pandas as pd


class EvaluationType(Enum):
    """The evaluation (dataset) type.

    Fix: the members were declared with trailing commas
    (``TRAIN = "train",``), which silently made each ``.value`` a 1-tuple
    ``("train",)`` instead of the intended string.
    """
    TRAIN = "train"
    VALIDATION = "validation"


class EvaluationMetric:
    """A single named metric value paired with a human-readable description."""

    def __init__(self, value, desc):
        self.value = value
        self.desc = desc

    def set_desc(self, desc):
        self.desc = desc


class FeatureEvaluationResult:
    """KS/AUC and sample statistics for one dataset (train or validation)."""

    DEFAULT_SAMPLE_STAT_DESC = "总样本"
    DEFAULT_POSITIVE_SAMPLE_STAT_DESC = "正样本"
    DEFAULT_KS_STAT_DESC = "KS"
    DEFAULT_AUC_STAT_DESC = "AUC"

    DEFAULT_TRAIN_EVALUATION_DESC = "训练集"
    DEFAULT_VALIDATION_EVALUATION_DESC = "验证集"
    DEFAULT_ROW_INDEX_LABEL_DESC = "分类"

    def __init__(self, type, type_desc=None, ks_value=0, auc_value=0, label_list=None):
        """Build one result row.

        Parameters
        ----------
        type : EvaluationType
            Which dataset the statistics describe.
        type_desc : str, optional
            Row label; defaults to the canonical description for ``type``.
        ks_value, auc_value : float
            Metric values (0 until known).
        label_list : sequence of 0/1, optional
            When given, total/positive sample counts are derived from it.
        """
        self.type = type
        self.type_desc = type_desc
        if self.type_desc is None:
            if self.type == EvaluationType.TRAIN:
                self.type_desc = FeatureEvaluationResult.DEFAULT_TRAIN_EVALUATION_DESC
            elif self.type == EvaluationType.VALIDATION:
                self.type_desc = FeatureEvaluationResult.DEFAULT_VALIDATION_EVALUATION_DESC
            else:
                raise Exception(
                    f"Create FeatureEvaluationResult for unsupported evaluation type: {type}")
        self.ks = EvaluationMetric(
            ks_value, FeatureEvaluationResult.DEFAULT_KS_STAT_DESC)
        self.auc = EvaluationMetric(
            auc_value, FeatureEvaluationResult.DEFAULT_AUC_STAT_DESC)
        self.positive_samples = EvaluationMetric(
            0, FeatureEvaluationResult.DEFAULT_POSITIVE_SAMPLE_STAT_DESC)
        self.samples = EvaluationMetric(
            0, FeatureEvaluationResult.DEFAULT_SAMPLE_STAT_DESC)

        if label_list is not None:
            self.set_sample_info(label_list)

    def set_sample_info(self, label_list):
        """Record the total and positive (label == 1) sample counts."""
        self.samples.value = len(label_list)
        # Idiom fix: replaces the manual accumulation loop.
        self.positive_samples.value = sum(label_list)

    def columns(self):
        """Column headers for the summary table, in display order."""
        return [FeatureEvaluationResult.DEFAULT_ROW_INDEX_LABEL_DESC, self.samples.desc, self.positive_samples.desc, self.ks.desc, self.auc.desc]

    def to_dict(self):
        """One summary-table row keyed by the metric descriptions."""
        return {FeatureEvaluationResult.DEFAULT_ROW_INDEX_LABEL_DESC: self.type_desc,
                self.samples.desc: self.samples.value,
                self.positive_samples.desc: self.positive_samples.value,
                self.ks.desc: self.ks.value,
                self.auc.desc: self.auc.value}

    @staticmethod
    def summary(evaluation_result_list):
        """Build a DataFrame with one row per result; None for empty input."""
        if evaluation_result_list is None or len(evaluation_result_list) == 0:
            return None
        columns = None
        rows = []
        for evaluation_metric in evaluation_result_list:
            rows.append(evaluation_metric.to_dict())
            if columns is None:
                # Take the column order from the first result.
                columns = evaluation_metric.columns()
        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def store_and_upload_summary(evaluation_result_list, local_file_path, remote_file_path, storage_client):
        """Write the summary CSV locally and optionally upload it."""
        df = FeatureEvaluationResult.summary(evaluation_result_list)
        df.to_csv(local_file_path, index=False)
        if storage_client is not None:
            storage_client.upload_file(local_file_path, remote_file_path)
type=EvaluationType.TRAIN, label_list=train_label_list) + ks_value = 0.4126 + auc_value = 0.7685 + (train_evaluation_result.ks.value, + train_evaluation_result.auc.value) = (ks_value, auc_value) + check_result(self, train_evaluation_result, + ks_value, auc_value, sample_num) + + sample_num = 2000000 + validation_label_list = np.random.randint(0, 2, sample_num) + validation_evaluation_result = FeatureEvaluationResult( + type=EvaluationType.VALIDATION, label_list=validation_label_list) + ks_value = 0.3116 + auc_value = 0.6676 + (validation_evaluation_result.ks.value, + validation_evaluation_result.auc.value) = (ks_value, auc_value) + check_result(self, validation_evaluation_result, + ks_value, auc_value, sample_num) + + local_path = "evaluation_result_case_1.csv" + FeatureEvaluationResult.store_and_upload_summary( + [train_evaluation_result, validation_evaluation_result], local_path, None, None) + + def test_with_given_table_meta(self): + ks_value = 0.4126 + auc_value = 0.7685 + sample_num = 1000000 + train_label_list = np.random.randint(0, 2, sample_num) + train_evaluation_result = FeatureEvaluationResult( + type=EvaluationType.TRAIN, label_list=train_label_list, ks_value=ks_value, auc_value=auc_value) + check_result(self, train_evaluation_result, + ks_value, auc_value, sample_num) + train_evaluation_result.ks.desc = "KS值" + train_evaluation_result.auc.desc = "AUC值" + + ks_value = 0.3116 + auc_value = 0.6676 + sample_num = 2000000 + validation_label_list = np.random.randint(0, 2, sample_num) + validation_evaluation_result = FeatureEvaluationResult( + type=EvaluationType.VALIDATION, label_list=validation_label_list, ks_value=ks_value, auc_value=auc_value) + validation_evaluation_result.ks.desc = "KS值" + validation_evaluation_result.auc.desc = "AUC值" + check_result(self, validation_evaluation_result, + ks_value, auc_value, sample_num) + + local_path = "evaluation_result_case_2.csv" + + FeatureEvaluationResult.store_and_upload_summary( + [train_evaluation_result, 
validation_evaluation_result], local_path, None, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py b/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py new file mode 100644 index 00000000..6429807f --- /dev/null +++ b/python/ppc_model/secure_lgbm/monitor/train_callback_unittest.py @@ -0,0 +1,105 @@ +import time +import unittest + +import numpy as np + +from ppc_model_service import config +from ppc_model.secure_lgbm.monitor.callback import CallbackContainer +from ppc_model.secure_lgbm.monitor.core import Booster, fevaluation +from ppc_model.secure_lgbm.monitor.early_stopping import EarlyStopping +from ppc_model.secure_lgbm.monitor.evaluation_monitor import EvaluationMonitor + +log = config.get_logger() + + +class TestBooster(unittest.TestCase): + def setUp(self): + np.random.seed(int(time.time())) + self.y_true = np.random.randint(0, 2, 10000) + self.test_y_true = np.random.randint(0, 2, 10000) + self.y_pred = np.random.rand(10000) + self.booster = Booster(self.y_true, self.test_y_true) + + def test_set_get_param(self): + self.booster.set_param('learning_rate', 0.1) + self.assertEqual(self.booster.get_param('learning_rate'), 0.1) + + def test_after_iteration(self): + self.booster.after_iteration(self.y_pred, False) + np.testing.assert_array_equal(self.booster.get_y_pred(), self.y_pred) + self.assertEqual(self.booster.get_epoch(), 1) + + def test_eval(self): + self.booster.after_iteration(self.y_pred) + results = self.booster.eval(fevaluation) + self.assertIn('auc', results) + self.assertIn('acc', results) + self.assertIn('recall', results) + self.assertIn('precision', results) + self.assertIsInstance(results['auc'], float) + self.assertIsInstance(results['acc'], float) + self.assertIsInstance(results['recall'], float) + self.assertIsInstance(results['precision'], float) + + +class TestEarlyStopping(unittest.TestCase): + def setUp(self): + np.random.seed(int(time.time())) 
+ self.y_true = np.random.randint(0, 2, 10000) + self.test_y_true = np.random.randint(0, 2, 10000) + self.y_pred = np.random.rand(10000) + self.model = Booster(self.y_true, self.test_y_true) + self.early_stopping = EarlyStopping(rounds=4, metric_name='auc', maximize=True) + + def test_early_stopping(self): + stop = False + while not stop: + np.random.seed(int(time.time()) + self.model.epoch) + y_pred = np.random.rand(10000) + self.model.after_iteration(y_pred) + self.model.eval(fevaluation) + stop = self.early_stopping.after_iteration(self.model, self.model.epoch) + print(self.model.epoch, stop) + + +class TestEvaluationMonitor(unittest.TestCase): + def setUp(self): + np.random.seed(int(time.time())) + self.y_true = np.random.randint(0, 2, 10000) + self.test_y_true = np.random.randint(0, 2, 10000) + self.y_pred = np.random.rand(10000) + self.model = Booster(self.y_true, self.test_y_true, '/tmp/') + self.monitor = EvaluationMonitor(log, period=2) + + def test_after_training(self): + np.random.seed(int(time.time())) + for i in range(10): + np.random.seed(int(time.time()) + self.model.epoch) + y_pred = np.random.rand(10000) + self.model.after_iteration(y_pred) + self.model.eval(fevaluation) + self.monitor.after_training(self.model) + + +class TestCallbackContainer(unittest.TestCase): + def setUp(self): + np.random.seed(int(time.time())) + self.y_true = np.random.randint(0, 2, 10000) + self.test_y_true = np.random.randint(0, 2, 10000) + self.y_pred = np.random.rand(10000) + self.model = Booster(self.y_true, self.test_y_true, 'tmp') + self.early_stopping = EarlyStopping(rounds=4, metric_name='auc', maximize=True) + self.monitor = EvaluationMonitor(log, period=2) + self.container = CallbackContainer([self.early_stopping, self.monitor], fevaluation) + + def test_callback_container(self): + stop = False + while not stop: + np.random.seed(int(time.time()) + self.model.epoch) + y_pred = np.random.rand(10000) + stop = self.container.after_iteration(self.model, y_pred, True) + 
class LGBMModel(BaseEstimator):
    """Hyper-parameter container mirroring lightgbm's LGBMModel interface.

    Only stores parameters; the actual (secure, vertical-federated) training
    logic lives in the vertical booster implementations.
    """

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: str = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: int = None,
        n_jobs: int = None,
        importance_type: str = 'split',
        **kwargs
    ):

        self.boosting_type = boosting_type
        self.objective = objective
        self.num_leaves = num_leaves
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.subsample_for_bin = subsample_for_bin
        self.min_split_gain = min_split_gain
        self.min_child_weight = min_child_weight
        self.min_child_samples = min_child_samples
        self.subsample = subsample
        self.subsample_freq = subsample_freq
        self.colsample_bytree = colsample_bytree
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda
        self.random_state = random_state
        self.n_jobs = n_jobs
        self.importance_type = importance_type
        # Extra keyword parameters not part of the declared signature.
        self._other_params: Dict[str, Any] = {}
        self.set_params(**kwargs)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, optional (default=True)
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : dict
            Parameter names mapped to their values (declared + extra).
        """
        params = super().get_params(deep=deep)
        params.update(self._other_params)
        return params

    def set_model_setting(self, model_setting: ModelSetting) -> "LGBMModel":
        """Copy every public attribute of ``model_setting`` onto this model.

        Best-effort by design: attributes that cannot be assigned (read-only
        or derived) are skipped silently, matching the original behaviour.
        """
        # All public attribute names of the setting object.
        attrs = [attr for attr in dir(model_setting) if not attr.startswith('_')]
        for attr in attrs:
            try:
                setattr(self, attr, getattr(model_setting, attr))
            except Exception:
                # Deliberate best-effort copy; see docstring.
                pass
        return self

    def set_params(self, **params: Any) -> "LGBMModel":
        """Set the parameters of this estimator.

        Parameters
        ----------
        **params
            Parameter names with their new values.

        Returns
        -------
        self : object
            Returns self.
        """
        for key, value in params.items():
            setattr(self, key, value)
            # Mirror onto the private shadow attribute when one exists.
            if hasattr(self, f"_{key}"):
                setattr(self, f"_{key}", value)
            self._other_params[key] = value
        return self


class ModelTaskParams(LGBMModel):
    """LGBM parameters extended with secure-training task options."""

    def __init__(
        self,
        test_size: float = 0.3,
        max_bin: int = 10,
        use_goss: bool = False,
        top_rate: float = 0.2,
        other_rate: float = 0.1,
        feature_rate: float = 1.0,
        colsample_bylevel: float = 1.0,
        gamma: float = 0,
        loss_type: str = 'logistic',
        eval_set_column: str = None,
        train_set_value: str = None,
        eval_set_value: str = None,
        train_feats: str = None,
        early_stopping_rounds: int = 5,
        eval_metric: str = 'auc',
        verbose_eval: int = 1,
        categorical_feature: list = None,
        silent: bool = False
    ):

        super().__init__()

        self.test_size = test_size
        self.max_bin = max_bin
        self.use_goss = use_goss
        self.top_rate = top_rate
        self.other_rate = other_rate
        self.feature_rate = feature_rate
        self.colsample_bylevel = colsample_bylevel
        self.gamma = gamma
        self.loss_type = loss_type
        self.eval_set_column = eval_set_column
        self.train_set_value = train_set_value
        self.eval_set_value = eval_set_value
        self.train_feature = train_feats
        self.early_stopping_rounds = early_stopping_rounds
        self.eval_metric = eval_metric
        self.verbose_eval = verbose_eval
        self.silent = silent
        # Convenience aliases used by the training code.
        self.λ = self.reg_lambda
        self.lr = self.learning_rate
        # Fix: the default was the mutable literal `[]`, shared across every
        # instance constructed without the argument.
        self.categorical_feature = [] if categorical_feature is None else categorical_feature
        self.categorical_idx = []
        self.my_categorical_idx = []


class SecureLGBMParams(ModelTaskParams):
    """Aggregated parameter view for one secure LGBM task."""

    def __init__(self):
        super().__init__()

    def _get_params(self):
        """Return the default LGBMClassifier-style parameters."""
        return LGBMModel().get_params()

    def get_all_params(self):
        """Return every public, non-callable attribute as a dict."""
        params = {}
        # Public attribute names only (skip dunder/private ones).
        for attr in dir(self):
            if attr.startswith('_'):
                continue
            try:
                value = getattr(self, attr)
            except Exception:
                continue
            # Skip methods/functions; we only report plain values.
            if not callable(value):
                params[attr] = value
        return params
检查value是否可调用(例如,方法或函数),如果是,则不打印其值 + if not callable(value): + params[attr] = value + except Exception as e: + pass + return params + + +class SecureLGBMContext(Context): + + def __init__(self, + args, + components: Initializer + ): + + if args['is_label_holder']: + role = TaskRole.ACTIVE_PARTY + else: + role = TaskRole.PASSIVE_PARTY + + super().__init__(args['job_id'], + args['task_id'], + components, + role) + + self.phe = PheCipherFactory.build_phe( + components.homo_algorithm, components.public_key_length) + self.codec = PheCipherFactory.build_codec(components.homo_algorithm) + self.is_label_holder = args['is_label_holder'] + self.result_receiver_id_list = args['result_receiver_id_list'] + self.participant_id_list = args['participant_id_list'] + self.model_predict_algorithm = common_func.get_config_value( + "model_predict_algorithm", None, args, False) + self.algorithm_type = args['algorithm_type'] + if 'dataset_id' in args and args['dataset_id'] is not None: + self.dataset_file_path = os.path.join( + self.workspace, args['dataset_id']) + else: + self.dataset_file_path = None + + self.lgbm_params = SecureLGBMParams() + model_setting = ModelSetting(args['model_dict']) + self.set_lgbm_params(model_setting) + if model_setting.train_features is not None and len(model_setting.train_features) > 0: + self.lgbm_params.train_feature = model_setting.train_features.split( + ',') + self.lgbm_params.n_estimators = model_setting.num_trees + self.lgbm_params.feature_rate = model_setting.colsample_bytree + self.lgbm_params.min_split_gain = model_setting.gamma + self.lgbm_params.random_state = model_setting.seed + + def set_lgbm_params(self, model_setting: ModelSetting): + """设置lgbm参数""" + self.lgbm_params.set_model_setting(model_setting) + + def get_lgbm_params(self): + """获取lgbm参数""" + return self.lgbm_params + + +class LGBMMessage(Enum): + FEATURE_NAME = "FEATURE_NAME" + INSTANCE = "INSTANCE" + ENC_GH_LIST = "ENC_GH_LIST" + ENC_GH_HIST = "ENC_GH_HIST" + SPLIT_INFO = 
class LGBMMessage(Enum):
    """Message tags exchanged between parties during secure LGBM tasks."""
    FEATURE_NAME = "FEATURE_NAME"
    INSTANCE = "INSTANCE"
    ENC_GH_LIST = "ENC_GH_LIST"
    ENC_GH_HIST = "ENC_GH_HIST"
    SPLIT_INFO = 'SPLIT_INFO'
    INSTANCE_MASK = "INSTANCE_MASK"
    PREDICT_LEAF_MASK = "PREDICT_LEAF_MASK"
    TEST_LEAF_MASK = "PREDICT_TEST_LEAF_MASK"
    VALID_LEAF_MASK = "PREDICT_VALID_LEAF_MASK"
    STOP_ITERATION = "STOP_ITERATION"
    PREDICT_PRABA = "PREDICT_PRABA"


class SecureLGBMPredictionEngine(TaskEngine):
    """Runs the prediction phase of a secure LGBM job."""
    task_type = ModelTask.XGB_PREDICTING

    @staticmethod
    def run(args):
        task_info = SecureLGBMContext(args, components)
        secure_dataset = SecureDataset(task_info)

        # Role-based dispatch to the matching vertical booster.
        booster_by_role = {
            TaskRole.ACTIVE_PARTY: VerticalLGBMActiveParty,
            TaskRole.PASSIVE_PARTY: VerticalLGBMPassiveParty,
        }
        booster_cls = booster_by_role.get(task_info.role)
        if booster_cls is None:
            raise PpcException(PpcErrorCode.ROLE_TYPE_ERROR.get_code(),
                               PpcErrorCode.ROLE_TYPE_ERROR.get_message())
        booster = booster_cls(task_info, secure_dataset)

        booster.load_model()
        booster.predict()

        # Predicted probabilities for the test set.
        test_praba = booster.get_test_praba()

        # Evaluate the test-set predictions and collect result files.
        Evaluation(task_info, secure_dataset, test_praba=test_praba)

        ResultFileHandling(task_info)


class SecureLGBMTrainingEngine(TaskEngine):
    """Runs the training phase of a secure LGBM job."""
    task_type = ModelTask.XGB_TRAINING

    @staticmethod
    def run(args):
        task_info = SecureLGBMContext(args, components)
        secure_dataset = SecureDataset(task_info)

        # Role-based dispatch to the matching vertical booster.
        booster_by_role = {
            TaskRole.ACTIVE_PARTY: VerticalLGBMActiveParty,
            TaskRole.PASSIVE_PARTY: VerticalLGBMPassiveParty,
        }
        booster_cls = booster_by_role.get(task_info.role)
        if booster_cls is None:
            raise PpcException(PpcErrorCode.ROLE_TYPE_ERROR.get_code(),
                               PpcErrorCode.ROLE_TYPE_ERROR.get_message())
        booster = booster_cls(task_info, secure_dataset)

        booster.fit()
        booster.save_model()

        # Predicted probabilities for the train and validation sets.
        train_praba = booster.get_train_praba()
        test_praba = booster.get_test_praba()

        # Evaluate predictions, plot the model and collect result files.
        Evaluation(task_info, secure_dataset, train_praba, test_praba)
        ModelPlot(booster)
        ResultFileHandling(task_info)
paillier = PaillierCipher()


class TestCipherPacking(unittest.TestCase):
    """Timing checks for Paillier encryption and protobuf cipher packing.

    Fix: the class previously did not inherit ``unittest.TestCase``, so
    ``unittest.main()`` silently collected no tests from this module.
    """

    def test_cipher_list(self):
        data_list = np.random.randint(1, 10001, size=1000)

        start_time = time.time()
        ciphers = paillier.encrypt_batch_parallel(data_list)
        print("enc:", time.time() - start_time, "seconds")

        # Pack the raw ciphertexts with the default (secure) encoding.
        self._pack_and_time(ciphers, be_secure=None)

        # 100 partial sums of 10 ciphertexts each (homomorphic addition),
        # packed once without obfuscation and once with the default encoding.
        summed = [np.array(ciphers[10 * i:10 * (i + 1)]).sum() for i in range(100)]
        self._pack_and_time(summed, be_secure=False)
        self._pack_and_time(summed, be_secure=None)

    @staticmethod
    def _pack_and_time(ciphers, be_secure):
        """Pack ``ciphers`` into a CipherList pb, printing the elapsed time.

        ``be_secure=None`` calls ``encode_cipher`` with its default, exactly
        as the original code did for the secure variants.
        """
        start_time = time.time()
        enc_data_pb = CipherList()
        enc_data_pb.public_key = PaillierCodec.encode_enc_key(paillier.public_key)
        for cipher in ciphers:
            model_cipher = ModelCipher()
            if be_secure is None:
                encoded = PaillierCodec.encode_cipher(cipher)
            else:
                encoded = PaillierCodec.encode_cipher(cipher, be_secure=be_secure)
            model_cipher.ciphertext, model_cipher.exponent = encoded
            enc_data_pb.cipher_list.append(model_cipher)
        print("pack ciphers:", time.time() - start_time, "seconds")
        return enc_data_pb


class TestPackGH(unittest.TestCase):
    """Checks the fixed-point packing/unpacking of gradient/hessian pairs."""

    def test_pack_gh(self):
        g_list = np.array([-4, 2, -1.3, 0, 0, -15.3544564544])
        h_list = np.array([2, 1.5, -1.4, 0, -1.68, 1.2356564564])

        gh_list = VerticalBooster.packing_gh(g_list, h_list)

        # Expected packed big-integer values for the pairs above.
        result_array = np.array(
            [429496329600000000000000002000, 200000000000000000001500,
             429496599600000000004294965896, 0,
             4294965616, 429495194200000000000000001235], dtype=object)

        assert np.array_equal(gh_list, result_array)

    def test_unpack_gh(self):
        gh_list = np.array(
            [429496329600000000000000002000, 200000000000000000001500,
             429496599600000000004294965896, 0,
             4294965616, 429495194200000000000000001235], dtype=object)

        # Unpack the aggregated (histogram) sums back into g/h totals.
        gh_sum_list = np.array([sum(gh_list), sum(gh_list) * 2])
        g_hist, h_hist = VerticalBooster.unpacking_gh(gh_sum_list)

        result_g_hist = np.array([-18.654, -37.308])
        result_h_hist = np.array([1.655, 3.31])

        assert np.array_equal(g_hist, result_g_hist)
        assert np.array_equal(h_hist, result_h_hist)
# ---------------------------------------------------------------------------
# python/ppc_model/secure_lgbm/test/test_save_load_model.py
# ---------------------------------------------------------------------------
import unittest
import numpy as np

from ppc_model.common.initializer import Initializer
from ppc_common.ppc_mock.mock_objects import MockLogger
from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo
from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext
from ppc_model.secure_lgbm.vertical.booster import VerticalBooster


class TestSaveLoadModel(unittest.TestCase):
    """Round-trip test: a saved LGBM model must load back identically."""

    n_estimators = 2
    max_depth = 3
    np.random.seed(2024)

    ACTIVE_PARTY = 'ACTIVE_PARTY'
    PASSIVE_PARTY = 'PASSIVE_PARTY'

    job_id = 'j-123'
    task_id = 't-123'

    model_dict = {}

    args = {
        'job_id': job_id,
        'task_id': task_id,
        'is_label_holder': True,
        'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
        'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
        'model_predict_algorithm': None,
        'algorithm_type': 'Train',
        'algorithm_subtype': 'HeteroXGB',
        'model_dict': model_dict
    }

    components = Initializer(log_config_path='', config_path='')
    components.config_data = {'JOB_TEMP_DIR': '/tmp'}
    components.mock_logger = MockLogger()

    def test_save_load_model(self):
        split_points = [[1, 2, 3, 5], [1, 2], [1.23, 3.45, 5.23]]

        # Build a small random forest of n_estimators trees.
        trees = [self._build_tree(self.max_depth)
                 for _ in range(self.n_estimators)]

        task_info = SecureLGBMContext(self.args, self.components)
        booster = VerticalBooster(task_info, dataset=None)
        booster._X_split = split_points
        booster._trees = trees
        booster.save_model()

        # A freshly constructed booster must read back exactly what
        # was written.
        restored = VerticalBooster(task_info, dataset=None)
        restored.load_model()

        assert split_points == restored._X_split
        assert trees == restored._trees

    @staticmethod
    def _build_tree(max_depth, depth=0, weight=0):
        """Recursively build a random tree of BestSplitInfo nodes.

        Leaves are plain weights; internal nodes are one-element lists
        of ``(split_info, left_subtree, right_subtree)`` tuples.
        """
        if depth == max_depth:
            return weight

        best_split_info = BestSplitInfo(
            feature=np.random.randint(0, 10),
            value=np.random.randint(0, 4),
            best_gain=np.random.rand(),
            w_left=np.random.rand(),
            w_right=np.random.rand(),
            agency_idx=np.random.randint(0, 2),
            agency_feature=np.random.randint(0, 5)
        )

        # Only split when the sampled gain clears the threshold;
        # otherwise this node stays a leaf carrying the current weight.
        if best_split_info.best_gain <= 0.2:
            return weight

        left_tree = TestSaveLoadModel._build_tree(
            max_depth, depth + 1, best_split_info.w_left)
        right_tree = TestSaveLoadModel._build_tree(
            max_depth, depth + 1, best_split_info.w_right)
        return [(best_split_info, left_tree, right_tree)]


if __name__ == '__main__':
    unittest.main()


# ---------------------------------------------------------------------------
# python/ppc_model/secure_lgbm/test/test_secure_lgbm_context.py
# ---------------------------------------------------------------------------
class TestSecureLGBMContext(unittest.TestCase):
    """Checks default and user-supplied LGBM parameter handling."""

    components = Initializer(log_config_path='', config_path='')
    components.config_data = {'JOB_TEMP_DIR': '/tmp'}
    components.mock_logger = MockLogger()

    def test_get_lgbm_params(self):
        args = {
            'job_id': 'j-123',
            'task_id': '1',
            'is_label_holder': True,
            'result_receiver_id_list': [],
            'participant_id_list': [],
            'model_predict_algorithm': None,
            'algorithm_type': None,
            'algorithm_subtype': None,
            'model_dict': {}
        }

        task_info = SecureLGBMContext(args, self.components)
        lgbm_params = task_info.get_lgbm_params()
        # Show the LGBMModel default parameters.
        print(lgbm_params._get_params())

        # With an empty model_dict the custom params stay empty.
        assert lgbm_params.get_params() == {}

    def test_set_lgbm_params(self):
        args = {
            'job_id': 'j-123',
            'task_id': '1',
            'is_label_holder': True,
            'result_receiver_id_list': [],
            'participant_id_list': [],
            'model_predict_algorithm': None,
            'algorithm_type': None,
            'algorithm_subtype': None,
            'model_dict': {
                'objective': 'regression',
                'n_estimators': 6,
                'max_depth': 3,
                'test_size': 0.2,
                'use_goss': 1
            }
        }

        task_info = SecureLGBMContext(args, self.components)
        lgbm_params = task_info.get_lgbm_params()
        # Show the user-supplied params and the merged full set.
        print(lgbm_params.get_params())
        print(lgbm_params.get_all_params())

        assert lgbm_params.get_params() == args['model_dict']
        self.assertEqual(lgbm_params.get_all_params()['learning_rate'],
                         lgbm_params._get_params()['learning_rate'])
        self.assertEqual(lgbm_params.learning_rate,
                         lgbm_params._get_params()['learning_rate'])
        self.assertEqual(lgbm_params.n_estimators,
                         args['model_dict']['n_estimators'])
        self.assertEqual(lgbm_params.test_size,
                         args['model_dict']['test_size'])
        self.assertEqual(lgbm_params.use_goss, args['model_dict']['use_goss'])


if __name__ == "__main__":
    unittest.main()
# ---------------------------------------------------------------------------
# python/ppc_model/secure_lgbm/test/test_secure_lgbm_performance_training.py
# ---------------------------------------------------------------------------
import unittest
import threading
import traceback

from ppc_model.common.initializer import Initializer
from ppc_common.ppc_mock.mock_objects import MockLogger
from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from ppc_model.network.stub import ModelStub
from ppc_model.datasets.dataset import SecureDataset
from ppc_model.metrics.evaluation import Evaluation
from ppc_model.metrics.model_plot import ModelPlot
from ppc_model.common.model_result import ResultFileHandling
from ppc_model.common.mock.rpc_client_mock import RpcClientMock
from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext
from ppc_model.secure_lgbm.vertical import VerticalLGBMActiveParty, VerticalLGBMPassiveParty


ACTIVE_PARTY = 'ACTIVE_PARTY'
PASSIVE_PARTY = 'PASSIVE_PARTY'

data_size = 1000
feature_dim = 20


def mock_args():
    """Build matching job argument dicts for the active and passive party."""
    job_id = 'j-111'
    task_id = 't-111'

    model_dict = {
        'objective': 'regression',
        'categorical_feature': [],
        'train_features': "",
        'max_bin': 10,
        'n_estimators': 2,
        'max_depth': 3,
        'feature_rate': 1.0,
        'random_state': 2024
    }

    def build(is_label_holder):
        # Both dicts intentionally share the same model_dict object.
        return {
            'job_id': job_id,
            'task_id': task_id,
            'is_label_holder': is_label_holder,
            'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'model_predict_algorithm': None,
            'algorithm_type': 'Train',
            'algorithm_subtype': 'HeteroXGB',
            'model_dict': model_dict
        }

    return build(True), build(False)


class TestXgboostTraining(unittest.TestCase):
    """End-to-end performance run of vertical LGBM on simulated data."""

    def setUp(self):
        # Two in-process stubs wired back-to-back through mock RPC
        # clients so each party's messages arrive at the other party.
        self._active_rpc_client = RpcClientMock()
        self._passive_rpc_client = RpcClientMock()
        self._thread_event_manager = ThreadEventManager()
        self._active_stub = self._make_stub(ACTIVE_PARTY, self._active_rpc_client)
        self._passive_stub = self._make_stub(PASSIVE_PARTY, self._passive_rpc_client)
        self._active_rpc_client.set_message_handler(self._passive_stub.on_message_received)
        self._passive_rpc_client.set_message_handler(self._active_stub.on_message_received)

    def _make_stub(self, agency_id, rpc_client):
        # One ModelStub per party with common retry settings.
        return ModelStub(
            agency_id=agency_id,
            thread_event_manager=self._thread_event_manager,
            rpc_client=rpc_client,
            send_retry_times=3,
            retry_interval_s=0.1
        )

    def test_fit(self):
        args_a, args_b = mock_args()

        active_components = Initializer(log_config_path='', config_path='')
        active_components.stub = self._active_stub
        active_components.config_data = {
            'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
        active_components.mock_logger = MockLogger()
        task_info_a = SecureLGBMContext(args_a, active_components)
        model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=True)
        secure_dataset_a = SecureDataset(task_info_a, model_data)
        booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a)
        print(secure_dataset_a.feature_name)
        print(secure_dataset_a.train_idx.shape)
        print(secure_dataset_a.train_X.shape)
        print(secure_dataset_a.train_y.shape)
        print(secure_dataset_a.test_idx.shape)
        print(secure_dataset_a.test_X.shape)
        print(secure_dataset_a.test_y.shape)

        passive_components = Initializer(log_config_path='', config_path='')
        passive_components.stub = self._passive_stub
        passive_components.config_data = {
            'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY}
        passive_components.mock_logger = MockLogger()
        task_info_b = SecureLGBMContext(args_b, passive_components)
        model_data = SecureDataset.simulate_dataset(data_size, feature_dim, has_label=False)
        secure_dataset_b = SecureDataset(task_info_b, model_data)
        booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b)
        print(secure_dataset_b.feature_name)
        print(secure_dataset_b.train_idx.shape)
        print(secure_dataset_b.train_X.shape)
        print(secure_dataset_b.test_idx.shape)
        print(secure_dataset_b.test_X.shape)

        def run_party(booster, task_info, dataset):
            # Train, persist, evaluate; then reload the model and predict.
            try:
                booster.fit()
                booster.save_model()
                train_praba = booster.get_train_praba()
                test_praba = booster.get_test_praba()
                Evaluation(task_info, dataset, train_praba, test_praba)
                # ModelPlot(booster)
                ResultFileHandling(task_info)
                booster.load_model()
                booster.predict()
                test_praba = booster.get_test_praba()
                task_info.algorithm_type = 'PPC_PREDICT'
                Evaluation(task_info, dataset, test_praba=test_praba)
                ResultFileHandling(task_info)
            except Exception:
                # NOTE(review): failures are only logged, so a crashing
                # worker never fails the test — confirm this is intended.
                task_info.components.logger().info(traceback.format_exc())

        thread_lgbm_active = threading.Thread(
            target=run_party, args=(booster_a, task_info_a, secure_dataset_a))
        thread_lgbm_active.start()

        thread_lgbm_passive = threading.Thread(
            target=run_party, args=(booster_b, task_info_b, secure_dataset_b))
        thread_lgbm_passive.start()

        thread_lgbm_active.join()
        thread_lgbm_passive.join()


if __name__ == '__main__':
    unittest.main()
# ---------------------------------------------------------------------------
# python/ppc_model/secure_lgbm/test/test_secure_lgbm_training.py
# ---------------------------------------------------------------------------
import unittest
import threading
import traceback
from sklearn.datasets import load_breast_cancer

from ppc_model.common.initializer import Initializer
from ppc_common.ppc_mock.mock_objects import MockLogger
from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from ppc_model.network.stub import ModelStub
from ppc_model.datasets.dataset import SecureDataset
from ppc_model.metrics.evaluation import Evaluation
from ppc_model.metrics.model_plot import ModelPlot
from ppc_model.common.model_result import ResultFileHandling
from ppc_model.common.mock.rpc_client_mock import RpcClientMock
from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext
from ppc_model.secure_lgbm.vertical import VerticalLGBMActiveParty, VerticalLGBMPassiveParty


ACTIVE_PARTY = 'ACTIVE_PARTY'
PASSIVE_PARTY = 'PASSIVE_PARTY'


def mock_args():
    """Build matching job argument dicts for the active and passive party."""
    job_id = 'j-123'
    task_id = 't-123'

    model_dict = {
        'objective': 'regression',
        'categorical_feature': [],
        'train_features': "",
        'max_bin': 10,
        'n_estimators': 2,
        'max_depth': 3,
        'use_goss': 1,
        'feature_rate': 0.8,
        'random_state': 2024
    }

    def build(is_label_holder):
        # Both dicts intentionally share the same model_dict object.
        return {
            'job_id': job_id,
            'task_id': task_id,
            'is_label_holder': is_label_holder,
            'result_receiver_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'participant_id_list': [ACTIVE_PARTY, PASSIVE_PARTY],
            'model_predict_algorithm': None,
            'algorithm_type': 'Train',
            'algorithm_subtype': 'HeteroXGB',
            'model_dict': model_dict
        }

    return build(True), build(False)


class TestXgboostTraining(unittest.TestCase):
    """End-to-end vertical LGBM training on the breast-cancer dataset."""

    cancer = load_breast_cancer()
    X = cancer.data
    y = cancer.target

    # Split the dataset vertically: the active party keeps the label.
    df = SecureDataset.assembling_dataset(X, y)
    df_with_y, df_without_y = SecureDataset.hetero_split_dataset(df)

    def setUp(self):
        # Two in-process stubs wired back-to-back through mock RPC
        # clients so each party's messages arrive at the other party.
        self._active_rpc_client = RpcClientMock()
        self._passive_rpc_client = RpcClientMock()
        self._thread_event_manager = ThreadEventManager()
        self._active_stub = self._make_stub(ACTIVE_PARTY, self._active_rpc_client)
        self._passive_stub = self._make_stub(PASSIVE_PARTY, self._passive_rpc_client)
        self._active_rpc_client.set_message_handler(
            self._passive_stub.on_message_received)
        self._passive_rpc_client.set_message_handler(
            self._active_stub.on_message_received)

    def _make_stub(self, agency_id, rpc_client):
        # One ModelStub per party with common retry settings.
        return ModelStub(
            agency_id=agency_id,
            thread_event_manager=self._thread_event_manager,
            rpc_client=rpc_client,
            send_retry_times=3,
            retry_interval_s=0.1
        )

    def test_fit(self):
        args_a, args_b = mock_args()

        active_components = Initializer(log_config_path='', config_path='')
        active_components.stub = self._active_stub
        active_components.config_data = {
            'JOB_TEMP_DIR': '/tmp/active', 'AGENCY_ID': ACTIVE_PARTY}
        active_components.mock_logger = MockLogger()
        task_info_a = SecureLGBMContext(args_a, active_components)
        secure_dataset_a = SecureDataset(task_info_a, self.df_with_y)
        booster_a = VerticalLGBMActiveParty(task_info_a, secure_dataset_a)
        print(secure_dataset_a.feature_name)
        print(secure_dataset_a.train_idx.shape)
        print(secure_dataset_a.train_X.shape)
        print(secure_dataset_a.train_y.shape)
        print(secure_dataset_a.test_idx.shape)
        print(secure_dataset_a.test_X.shape)
        print(secure_dataset_a.test_y.shape)

        passive_components = Initializer(log_config_path='', config_path='')
        passive_components.stub = self._passive_stub
        passive_components.config_data = {
            'JOB_TEMP_DIR': '/tmp/passive', 'AGENCY_ID': PASSIVE_PARTY}
        passive_components.mock_logger = MockLogger()
        task_info_b = SecureLGBMContext(args_b, passive_components)
        secure_dataset_b = SecureDataset(task_info_b, self.df_without_y)
        booster_b = VerticalLGBMPassiveParty(task_info_b, secure_dataset_b)
        print(secure_dataset_b.feature_name)
        print(secure_dataset_b.train_idx.shape)
        print(secure_dataset_b.train_X.shape)
        print(secure_dataset_b.test_idx.shape)
        print(secure_dataset_b.test_X.shape)

        def run_party(booster, task_info, dataset):
            # Train, persist, evaluate; then reload the model and predict.
            try:
                booster.fit()
                booster.save_model()
                train_praba = booster.get_train_praba()
                test_praba = booster.get_test_praba()
                Evaluation(task_info, dataset, train_praba, test_praba)
                ModelPlot(booster)
                ResultFileHandling(task_info)
                booster.load_model()
                booster.predict()
                test_praba = booster.get_test_praba()
                task_info.algorithm_type = 'PPC_PREDICT'
                Evaluation(task_info, dataset, test_praba=test_praba)
                ResultFileHandling(task_info)
            except Exception:
                # NOTE(review): failures are only logged, so a crashing
                # worker never fails the test — confirm this is intended.
                task_info.components.logger().info(traceback.format_exc())

        thread_lgbm_active = threading.Thread(
            target=run_party, args=(booster_a, task_info_a, secure_dataset_a))
        thread_lgbm_active.start()

        thread_lgbm_passive = threading.Thread(
            target=run_party, args=(booster_b, task_info_b, secure_dataset_b))
        thread_lgbm_passive.start()

        thread_lgbm_active.join()
        thread_lgbm_passive.join()


if __name__ == '__main__':
    unittest.main()


# ---------------------------------------------------------------------------
# python/ppc_model/secure_lgbm/vertical/__init__.py
# ---------------------------------------------------------------------------
from ppc_model.secure_lgbm.vertical.active_party import VerticalLGBMActiveParty
from ppc_model.secure_lgbm.vertical.passive_party import VerticalLGBMPassiveParty

__all__ = ["VerticalLGBMActiveParty", "VerticalLGBMPassiveParty"]
ppc_common.deps_services.serialize_type import SerializeType +from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceStore +from ppc_common.ppc_ml.feature.feature_importance import FeatureImportanceType +from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo, IterationRequest +from ppc_common.ppc_utils import utils +from ppc_model.datasets.data_reduction.feature_selection import FeatureSelection +from ppc_model.datasets.data_reduction.sampling import Sampling +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning +from ppc_model.metrics.evaluation import Evaluation +from ppc_model.metrics.loss import BinaryLoss +from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext, LGBMMessage +from ppc_model.secure_lgbm.monitor.callback import CallbackContainer +from ppc_model.secure_lgbm.monitor.core import Booster +from ppc_model.secure_lgbm.monitor.early_stopping import EarlyStopping +from ppc_model.secure_lgbm.monitor.evaluation_monitor import EvaluationMonitor +from ppc_model.secure_lgbm.vertical.booster import VerticalBooster + + +class VerticalLGBMActiveParty(VerticalBooster): + + def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: + super().__init__(ctx, dataset) + self.params = ctx.lgbm_params + self._loss_func = BinaryLoss(self.params.objective) + self._all_feature_name = [dataset.feature_name] + self._all_feature_num = len(dataset.feature_name) + self.log = ctx.components.logger() + self.storage_client = ctx.components.storage_client + self.feature_importance_store = FeatureImportanceStore( + FeatureImportanceStore.DEFAULT_IMPORTANCE_LIST, None, self.log) + self.log.info(f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') + + def fit( + self, + *args, + **kwargs, + ) -> None: + self.log.info( + f'task {self.ctx.task_id}: Starting the lgbm on the active party.') + self._init_active_data() + 
self._init_valid_data() + self._init_early_stop() + + for _ in range(self.params.n_estimators): + self._tree_id += 1 + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in active party.') + + # 初始化 + feature_select, instance, used_glist, used_hlist = self._init_each_tree() + self.log.info(f'task {self.ctx.task_id}: Sampling number: {len(instance)}, ' + f'feature select: {len(feature_select)}, {feature_select}.') + + # 构建 + tree = self._build_tree( + feature_select, instance, used_glist, used_hlist) + self._trees.append(tree) + # print('tree', tree) + + # 预测 + self._train_weights += self._predict_tree( + tree, self._X_bin, np.ones(self._X_bin.shape[0], dtype=bool), LGBMMessage.PREDICT_LEAF_MASK.value) + self._train_praba = self._loss_func.sigmoid(self._train_weights) + # print('train_praba', set(self._train_praba)) + + # 评估 + if not self.params.silent and self.dataset.train_y is not None: + auc = Evaluation.fevaluation(self.dataset.train_y, self._train_praba)['auc'] + self.log.info(f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, auc: {auc}.') + self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id}, ' + f'time_costs: {time.time() - start_time}s.') + + # 预测验证集 + self._test_weights += self._predict_tree( + tree, self._test_X_bin, np.ones(self._test_X_bin.shape[0], dtype=bool), + LGBMMessage.TEST_LEAF_MASK.value) + self._test_praba = self._loss_func.sigmoid(self._test_weights) + if not self.params.silent and self.dataset.test_y is not None: + auc = Evaluation.fevaluation(self.dataset.test_y, self._test_praba)['auc'] + self.log.info(f'task {self.ctx.task_id}: n_estimators-{self._tree_id}, test auc: {auc}.') + if self._iteration_early_stop(): + self.log.info(f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") + break + + self._end_active_data() + + def transform(self, transform_data: DataFrame) -> DataFrame: + ... 
+ + def predict(self, dataset: SecureDataset = None) -> np.ndarray: + start_time = time.time() + if dataset is None: + dataset = self.dataset + + self.params.my_categorical_idx = self._get_categorical_idx( + dataset.feature_name, self.params.categorical_feature) + + test_weights = self._init_weight(dataset.test_X.shape[0]) + test_X_bin = self._split_test_data(self.ctx, dataset.test_X, self._X_split) + + for tree in self._trees: + test_weights += self._predict_tree( + tree, test_X_bin, np.ones(test_X_bin.shape[0], dtype=bool), LGBMMessage.VALID_LEAF_MASK.value) + test_praba = self._loss_func.sigmoid(test_weights) + self._test_praba = test_praba + + if dataset.test_y is not None: + auc = Evaluation.fevaluation(dataset.test_y, test_praba)['auc'] + self.log.info(f'task {self.ctx.task_id}: predict test auc: {auc}.') + self.log.info(f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') + + self._end_active_data(is_train=False) + + def _init_active_data(self): + + # 初始化预测值和权重 + self._train_praba = self._init_praba(self.dataset.train_X.shape[0]) + self._train_weights = self._init_weight(self.dataset.train_X.shape[0]) + self._tree_id = 0 + + # 初始化所有参与方的特征 + for i in range(1, len(self.ctx.participant_id_list)): + feature_name_bytes = self._receive_byte_data(self.ctx, LGBMMessage.FEATURE_NAME.value, i) + self._all_feature_name.append([s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) + self._all_feature_num += len([s.decode('utf-8') for s in feature_name_bytes.split(b' ') if s]) + + self.log.info(f'task {self.ctx.task_id}: total feature number:{self._all_feature_num}, ' + f'total feature name: {self._all_feature_name}.') + self.params.categorical_idx = self._get_categorical_idx( + list(itertools.chain(*self._all_feature_name)), self.params.categorical_feature) + self.params.my_categorical_idx = self._get_categorical_idx( + self.dataset.feature_name, self.params.categorical_feature) + + # 更新feature_importance中的特征列表 + 
self.feature_importance_store.set_init(list(itertools.chain(*self._all_feature_name))) + + # 初始化分桶数据集 + feat_bin = FeatureBinning(self.ctx) + self._X_bin, self._X_split = feat_bin.data_binning(self.dataset.train_X) + + def _init_each_tree(self): + + if self.callback_container: + self.callback_container.before_iteration(self.model) + + gradient = self._loss_func.compute_gradient(self.dataset.train_y, self._train_praba) + hessian = self._loss_func.compute_hessian(self._train_praba) + + feature_select = FeatureSelection.feature_selecting( + list(itertools.chain(*self._all_feature_name)), + self.params.train_feature, self.params.feature_rate) + instance, used_glist, used_hlist = Sampling.sample_selecting( + gradient, hessian, self.params.subsample, + self.params.use_goss, self.params.top_rate, self.params.other_rate) + self._send_gh_instance_list(instance, used_glist, used_hlist) + + return feature_select, instance, used_glist, used_hlist + + def _send_gh_instance_list(self, instance, glist, hlist): + + self._leaf_id = 0 + gh_list = self.packing_gh(glist, hlist) + + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} ' + f'encrypt g & h in active party.') + enc_ghlist = self.ctx.phe.encrypt_batch_parallel((gh_list).astype('object')) + self.log.info(f'task {self.ctx.task_id}: Finished n_estimators-{self._tree_id} ' + f'encrypt gradient & hessian time_costs: {time.time() - start_time}.') + + for partner_index in range(1, len(self.ctx.participant_id_list)): + self._send_byte_data(self.ctx, f'{LGBMMessage.INSTANCE.value}_{self._tree_id}', + instance.astype('int64').tobytes(), partner_index) + self._send_enc_data(self.ctx, f'{LGBMMessage.ENC_GH_LIST.value}_{self._tree_id}', + enc_ghlist, partner_index) + + def _build_tree(self, feature_select, instance, glist, hlist, depth=0, weight=0): + + if depth == self.params.max_depth: + return weight + if self.params.max_depth < 0 and self._leaf_id >= self.params.num_leaves: + 
return weight + + self._leaf_id += 1 + print('tree', self._tree_id, 'leaf', self._leaf_id, 'instance', len(instance), + 'glist', len(glist), 'hlist', len(hlist)) + if self.params.colsample_bylevel > 0 and self.params.colsample_bylevel < 1: + feature_select_level = sorted(np.random.choice( + feature_select, size=int(len(feature_select) * self.params.colsample_bylevel), replace=False)) + best_split_info = self._find_best_split(feature_select_level, instance, glist, hlist) + else: + best_split_info = self._find_best_split(feature_select, instance, glist, hlist) + + if best_split_info.best_gain > 0 and best_split_info.best_gain > self.params.min_split_gain: + gain_list = {FeatureImportanceType.GAIN: best_split_info.best_gain, + FeatureImportanceType.WEIGHT: 1} + self.feature_importance_store.update_feature_importance(best_split_info.feature, gain_list) + left_mask, right_mask = self._get_leaf_mask(best_split_info, instance) + + if (abs(best_split_info.w_left) * sum(left_mask) / self.params.lr) < self.params.min_child_weight or \ + (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: + return weight + if sum(left_mask) < self.params.min_child_samples or sum(right_mask) < self.params.min_child_samples: + return weight + + left_tree = self._build_tree( + feature_select, instance[left_mask], glist[left_mask], + hlist[left_mask], depth + 1, best_split_info.w_left) + right_tree = self._build_tree( + feature_select, instance[right_mask], glist[right_mask], + hlist[right_mask], depth + 1, best_split_info.w_right) + + return [(best_split_info, left_tree, right_tree)] + else: + return weight + + def _predict_tree(self, tree, X_bin, leaf_mask, key_type): + if not isinstance(tree, list): + return tree * leaf_mask + else: + best_split_info, left_subtree, right_subtree = tree[0] + if self.ctx.participant_id_list[best_split_info.agency_idx] == \ + self.ctx.components.config_data['AGENCY_ID']: + if best_split_info.agency_feature in 
self.params.my_categorical_idx: + left_mask = X_bin[:, best_split_info.agency_feature] == best_split_info.value + else: + left_mask = X_bin[:, best_split_info.agency_feature] <= best_split_info.value + else: + left_mask = np.frombuffer( + self._receive_byte_data( + self.ctx, + f'{key_type}_{best_split_info.tree_id}_{best_split_info.leaf_id}', + best_split_info.agency_idx), dtype='bool') + right_mask = ~left_mask + left_weight = self._predict_tree(left_subtree, X_bin, leaf_mask * left_mask, key_type) + right_weight = self._predict_tree(right_subtree, X_bin, leaf_mask * right_mask, key_type) + return left_weight + right_weight + + def _find_best_split(self, feature_select, instance, glist, hlist): + + self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} in active party.') + grad_hist, hess_hist = self._get_gh_hist(instance, glist, hlist) + best_split_info = self._get_best_split_point(feature_select, glist, hlist, grad_hist, hess_hist) + # print('grad_hist_sum', [sum(sublist) for sublist in grad_hist]) + + best_split_info.tree_id = self._tree_id + best_split_info.leaf_id = self._leaf_id + if best_split_info.best_gain > 0: + agency_idx, agency_feature = self._get_best_split_agency( + self._all_feature_name, best_split_info.feature) + best_split_info.agency_idx = agency_idx + best_split_info.agency_feature = agency_feature + + for partner_index in range(1, len(self.ctx.participant_id_list)): + self._send_byte_data( + ctx=self.ctx, + key_type=f'{LGBMMessage.SPLIT_INFO.value}_{self._tree_id}_{self._leaf_id}', + byte_data=utils.pb_to_bytes(best_split_info), + partner_index=partner_index) + self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} in active party.') + # print('best_split_info', best_split_info) + return best_split_info + + def _get_gh_hist(self, instance, glist, hlist): + ghist, hhist = self._calculate_hist(self._X_bin, instance, glist, hlist) + + for 
partner_index in range(1, len(self.ctx.participant_id_list)): + partner_feature_name = self._all_feature_name[partner_index] + + partner_ghist = [None] * len(partner_feature_name) + partner_hhist = [None] * len(partner_feature_name) + _, gh_hist = self._receive_enc_data( + self.ctx, f'{LGBMMessage.ENC_GH_HIST.value}_{self._tree_id}_{self._leaf_id}', + partner_index, matrix_data=True) + + for feature_index in range(len(partner_feature_name)): + ghk_hist = np.array(self.ctx.phe.decrypt_batch(gh_hist[feature_index]), dtype='object') + gk_hist, hk_hist = self.unpacking_gh(ghk_hist) + partner_ghist[feature_index] = gk_hist + partner_hhist[feature_index] = hk_hist + + ghist.extend(partner_ghist) + hhist.extend(partner_hhist) + + return ghist, hhist + + @staticmethod + def _calculate_hist(X_bin, instance, used_glist, used_hlist): + + g_hist = [] + h_hist = [] + for k in range(X_bin.shape[1]): + Xk_bin = X_bin[instance, k] + gk_hist = [] + hk_hist = [] + sorted_x = sorted(set(X_bin[:, k])) + for v in sorted_x: + gk_hist.append(used_glist[Xk_bin == v].sum()) + hk_hist.append(used_hlist[Xk_bin == v].sum()) + g_hist.append(gk_hist) + h_hist.append(hk_hist) + + return g_hist, h_hist + + def _get_best_split_point(self, feature_select, glist, hlist, grad_hist, hess_hist): + + beat_feature, best_value, best_gain, best_wl, best_wr = None, None, 0, None, None + g = np.sum(glist) + h = np.sum(hlist) + + for feature in feature_select: + gl = 0 + hl = 0 + for value in range(len(grad_hist[feature])): + gl, hl = self._compute_gh_sum( + feature, value, self.params.categorical_idx, gl, hl, grad_hist, hess_hist) + gr = g - gl + hr = h - hl + + gain = self._compute_gain(g, h, gl, hl, gr, hr, self.params.λ) + wl, wr = self._compute_leaf_weight( + self.params.lr, self.params.λ, gl, hl, gr, hr, self.params.reg_alpha) + compare = bool(gain > best_gain) + # print('f', feature, 'v', value, 'gl', gl, 'gr', gr, 'hl', hl, 'hr', hr, + # 'gain', gain, 'wl', wl, 'wr', wr, 'compare', compare) + + if 
compare: + beat_feature = feature + best_value = value + best_wl = wl + best_wr = wr + best_gain = gain + + return BestSplitInfo(feature=beat_feature, + value=best_value, + best_gain=best_gain, + w_left=best_wl, + w_right=best_wr) + + @staticmethod + def _get_best_split_agency(all_feature_name, feature): + """Get the agency index and feature index of the best split point. + + Parameters + ---------- + all_feature_name : two-dimensional list + Feature list of all participating agency. + feature : int + Best split point global feature index. + + Returns + ------- + sublist_index : int + Agency index. + position_in_sublist : int + Feature index. + """ + count = 0 + for sublist_index, sublist in enumerate(all_feature_name): + if count + len(sublist) > feature: + position_in_sublist = feature - count + return sublist_index, position_in_sublist + count += len(sublist) + return None + + def _init_valid_data(self): + self._test_weights = self._init_weight(self.dataset.test_X.shape[0]) + self._test_X_bin = self._split_test_data(self.ctx, self.dataset.test_X, self._X_split) + + def _init_early_stop(self): + + callbacks = [] + early_stopping_rounds = self.params.early_stopping_rounds + if early_stopping_rounds != 0: + eval_metric = self.params.eval_metric + early_stopping = EarlyStopping(rounds=early_stopping_rounds, metric_name=eval_metric, save_best=True) + callbacks.append(early_stopping) + + verbose_eval = self.params.verbose_eval + if verbose_eval != 0: + evaluation_monitor = EvaluationMonitor(logger=self.log, period=verbose_eval) + callbacks.append(evaluation_monitor) + + callback_container = None + if len(callbacks) != 0: + callback_container = CallbackContainer(callbacks=callbacks, feval=Evaluation.fevaluation) + + model = Booster(y_true=self.dataset.train_y, test_y_true=self.dataset.test_y, + workspace=self.ctx.workspace, job_id=self.ctx.job_id, + storage_client=self.storage_client) + + if callback_container: + callback_container.before_training(model) + + self.model 
= model + self.callback_container = callback_container + + def _iteration_early_stop(self): + # check early stopping + early_stopping_rounds = self.params.early_stopping_rounds + if early_stopping_rounds != 0: + # evaluate the model using test sets + eval_on_test = True + pred = self._test_praba + else: + eval_on_test = False + pred = self._train_praba + stop = False + if self.callback_container: + stop = self.callback_container.after_iteration(model=self.model, + pred=pred, + eval_on_test=eval_on_test) + self.log.info(f"task {self.ctx.task_id}: after iteration {self._tree_id} iterations, stop: {stop}.") + + iteration_request = IterationRequest() + iteration_request.epoch = self._tree_id - 1 + iteration_request.stop = stop + + # send stop to passive + for partner_index in range(1, len(self.ctx.participant_id_list)): + self._send_byte_data( + ctx=self.ctx, + key_type=f'{LGBMMessage.STOP_ITERATION.value}_{self._tree_id}', + byte_data=utils.pb_to_bytes(iteration_request), + partner_index=partner_index) + + return stop + + def _end_active_data(self, is_train=True): + if is_train: + self.feature_importance_store.store( + serialize_type=SerializeType.CSV, local_file_path=self.ctx.feature_importance_file, + remote_file_path=self.ctx.remote_feature_importance_file, storage_client=self.storage_client) + + if self.callback_container: + self.callback_container.after_training(self.model) + + for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_train', + self._train_praba.astype('float').tobytes(), partner_index) + + for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_test', + self._test_praba.astype('float').tobytes(), partner_index) + + else: + 
for partner_index in range(1, len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] in self.ctx.result_receiver_id_list: + self._send_byte_data(self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_predict', + self._test_praba.astype('float').tobytes(), partner_index) diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py new file mode 100644 index 00000000..468158cf --- /dev/null +++ b/python/ppc_model/secure_lgbm/vertical/booster.py @@ -0,0 +1,345 @@ +import os +import time +import random +import json +import numpy as np + +from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo +from ppc_common.ppc_utils.utils import AlgorithmType +from ppc_model.interface.model_base import VerticalModel +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.common.protocol import PheMessage +from ppc_model.network.stub import PushRequest, PullRequest +from ppc_model.common.model_result import ResultFileHandling +from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning +from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext, LGBMMessage + + +# 抽离sgb的公共部分 +class VerticalBooster(VerticalModel): + def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: + super().__init__(ctx) + self.dataset = dataset + self._stub = ctx.components.stub + + self._tree_id = None + self._leaf_id = None + self._X_bin = None + self._X_split = None + self._trees = [] + + self._train_weights = None + self._train_praba = None + self._test_weights = None + self._test_praba = None + + random.seed(ctx.lgbm_params.random_state) + np.random.seed(ctx.lgbm_params.random_state) + + def _build_tree(self, *args, **kwargs): + + raise NotImplementedError + + def _predict_tree(self, *args, **kwargs): + + raise NotImplementedError + + def _init_praba(self, n): + return np.full(n, 0.5) + + def _init_weight(self, n): + return np.zeros(n, dtype=float) + + 
@staticmethod + def _get_categorical_idx(feature_name, categorical_feature = []): + categorical_idx = [] + if len(categorical_feature) > 0: + for i in categorical_feature: + if i in feature_name: + categorical_idx.append(feature_name.index(i)) + return categorical_idx + + @staticmethod + def _compute_gh_sum(feature, value, categorical_idx, gl, hl, grad_hist, hess_hist): + if feature in categorical_idx: + gl = grad_hist[feature][value] + hl = hess_hist[feature][value] + else: + gl = gl + grad_hist[feature][value] + hl = hl + hess_hist[feature][value] + return gl, hl + + @staticmethod + def _compute_gain(g, h, gl, hl, gr, hr, λ): + if (h + λ) != 0 and (hl + λ) != 0 and (hr + λ) != 0: + return gl**2 / (hl + λ) + gr**2 / (hr + λ) - g**2 / (h + λ) + else: + return 0 + + @staticmethod + def _compute_leaf_weight(lr, λ, gl, hl, gr, hr, reg_alpha): + + weight_l = VerticalBooster._calulate_weight(lr, λ, gl, hl, reg_alpha) + weight_r = VerticalBooster._calulate_weight(lr, λ, gr, hr, reg_alpha) + + return weight_l, weight_r + + @staticmethod + def _calulate_weight(lr, λ, g, h, reg_alpha): + + # weight = lr * - g / (h + λ) + if (h + λ) != 0 and g > reg_alpha: + weight = lr * - (g - reg_alpha) / (h + λ) + elif (h + λ) != 0 and g < -reg_alpha: + weight = lr * - (g + reg_alpha) / (h + λ) + else: + weight = 0 + + return weight + + @staticmethod + def _get_leaf_instance(X, instance, feature, value, my_categorical_idx): + + if feature in my_categorical_idx: + left_mask = X[instance, feature] == value + right_mask = ~left_mask + else: + left_mask = X[instance, feature] <= value + right_mask = ~left_mask + + return left_mask, right_mask + + def _get_leaf_mask(self, split_info, instance): + + if self.ctx.participant_id_list[split_info.agency_idx] == self.ctx.components.config_data['AGENCY_ID']: + left_mask, right_mask = self._get_leaf_instance( + self._X_bin, instance, split_info.agency_feature, split_info.value, self.params.my_categorical_idx) + for partner_index in range(0, 
len(self.ctx.participant_id_list)): + if self.ctx.participant_id_list[partner_index] != self.ctx.components.config_data['AGENCY_ID']: + self._send_byte_data( + self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', + left_mask.astype('bool').tobytes(), partner_index) + else: + left_mask = np.frombuffer( + self._receive_byte_data( + self.ctx, f'{LGBMMessage.INSTANCE_MASK.value}_{self._tree_id}_{self._leaf_id}', + split_info.agency_idx), dtype='bool') + right_mask = ~left_mask + + return left_mask, right_mask + + def _send_enc_data(self, ctx, key_type, enc_data, partner_index, matrix_data=False): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + if matrix_data: + self._stub.push(PushRequest( + receiver=partner_id, + task_id=ctx.task_id, + key=key_type, + data=PheMessage.packing_2dim_data(ctx.codec, ctx.phe.public_key, enc_data) + )) + else: + self._stub.push(PushRequest( + receiver=partner_id, + task_id=ctx.task_id, + key=key_type, + data=PheMessage.packing_data(ctx.codec, ctx.phe.public_key, enc_data) + )) + + log.info( + f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, " + f"data_length: {len(enc_data)}, time_costs: {time.time() - start_time}s") + + def _receive_enc_data(self, ctx, key_type, partner_index, matrix_data=False): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + byte_data = self._stub.pull(PullRequest( + sender=partner_id, + task_id=ctx.task_id, + key=key_type + )) + + if matrix_data: + public_key, enc_data = PheMessage.unpacking_2dim_data(ctx.codec, byte_data) + else: + public_key, enc_data = PheMessage.unpacking_data(ctx.codec, byte_data) + + log.info( + f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, " + f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s") + return public_key, enc_data + + def _send_byte_data(self, ctx, 
key_type, byte_data, partner_index): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + self._stub.push(PushRequest( + receiver=partner_id, + task_id=ctx.task_id, + key=key_type, + data=byte_data + )) + + log.info( + f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, " + f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s") + + def _receive_byte_data(self, ctx, key_type, partner_index): + log = ctx.components.logger() + start_time = time.time() + partner_id = ctx.participant_id_list[partner_index] + + byte_data = self._stub.pull(PullRequest( + sender=partner_id, + task_id=ctx.task_id, + key=key_type + )) + + log.info( + f"task {ctx.task_id}: Received {key_type} from {partner_id} finished, " + f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s") + return byte_data + + @staticmethod + def _split_test_data(ctx, test_X, X_split): + feat_bin = FeatureBinning(ctx) + return feat_bin.data_binning(test_X, X_split)[0] + + def save_model(self, file_path=None): + log = self.ctx.components.logger() + if file_path is not None: + self.ctx.feature_bin_file = os.path.join(file_path, self.ctx.FEATURE_BIN_FILE) + self.ctx.model_data_file = os.path.join(file_path, self.ctx.MODEL_DATA_FILE) + + if self._X_split is not None and not os.path.exists(self.ctx.feature_bin_file): + X_split_dict = {k: v for k, v in zip(self.dataset.feature_name, self._X_split)} + with open(self.ctx.feature_bin_file, 'w') as f: + json.dump(X_split_dict, f) + ResultFileHandling._upload_file(self.ctx.components.storage_client, + self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file) + log.info(f"task {self.ctx.task_id}: Saved x_split to {self.ctx.feature_bin_file} finished.") + + if not os.path.exists(self.ctx.model_data_file): + serial_trees = [self._serial_tree(tree) for tree in self._trees] + with open(self.ctx.model_data_file, 'w') as f: + json.dump(serial_trees, f) 
+ ResultFileHandling._upload_file(self.ctx.components.storage_client, + self.ctx.model_data_file, self.ctx.remote_model_data_file) + log.info(f"task {self.ctx.task_id}: Saved serial_trees to {self.ctx.model_data_file} finished.") + + def load_model(self, file_path=None): + log = self.ctx.components.logger() + if file_path is not None: + self.ctx.feature_bin_file = os.path.join(file_path, self.ctx.FEATURE_BIN_FILE) + self.ctx.model_data_file = os.path.join(file_path, self.ctx.MODEL_DATA_FILE) + if self.ctx.algorithm_type == AlgorithmType.Predict.name: + self.ctx.remote_feature_bin_file = os.path.join( + self.ctx.lgbm_params.training_job_id, self.ctx.FEATURE_BIN_FILE) + self.ctx.remote_model_data_file = os.path.join( + self.ctx.lgbm_params.training_job_id, self.ctx.MODEL_DATA_FILE) + + ResultFileHandling._download_file(self.ctx.components.storage_client, + self.ctx.feature_bin_file, self.ctx.remote_feature_bin_file) + ResultFileHandling._download_file(self.ctx.components.storage_client, + self.ctx.model_data_file, self.ctx.remote_model_data_file) + + with open(self.ctx.feature_bin_file, 'r') as f: + X_split_dict = json.load(f) + feature_name = list(X_split_dict.keys()) + x_split = list(X_split_dict.values()) + log.info(f"task {self.ctx.task_id}: Load x_split from {self.ctx.feature_bin_file} finished.") + assert len(feature_name) == len(self.dataset.feature_name) + + with open(self.ctx.model_data_file, 'r') as f: + serial_trees = json.load(f) + log.info(f"task {self.ctx.task_id}: Load serial_trees from {self.ctx.model_data_file} finished.") + + trees = [self._deserial_tree(tree) for tree in serial_trees] + self._X_split = x_split + # self.my_feature_name = feature_name + self._trees = trees + + @staticmethod + def _serial_tree(tree): + if isinstance(tree, list): + best_split_info, left_tree, right_tree = tree[0] + best_split_info_list = [] + for field in best_split_info.DESCRIPTOR.fields: + best_split_info_list.append(getattr(best_split_info, field.name)) + left_tree 
= VerticalBooster._serial_tree(left_tree) + right_tree = VerticalBooster._serial_tree(right_tree) + best_split_info_list.extend([left_tree, right_tree]) + return best_split_info_list + else: + return tree + + @staticmethod + def _deserial_tree(tree_list): + if isinstance(tree_list, list): + best_split_info_list = tree_list[:-2] + left_tree, right_tree = tree_list[-2:] + best_split_info = BestSplitInfo() + for i, field in enumerate(best_split_info.DESCRIPTOR.fields): + setattr(best_split_info, field.name, best_split_info_list[i]) + left_tree = VerticalBooster._deserial_tree(left_tree) + right_tree = VerticalBooster._deserial_tree(right_tree) + return [(best_split_info, left_tree, right_tree)] + else: + return tree_list + + def get_trees(self): + return self._trees + + def get_x_split(self): + return self._X_split + + def get_train_praba(self): + return self._train_praba + + def get_test_praba(self): + return self._test_praba + + @staticmethod + def packing_gh(g_list: np.ndarray, h_list: np.ndarray, expand=1000, mod_length=32, pack_length=20): + ''' + 1. 转正整数 + g和h的梯度值为浮点数, 取值范围: [-1 ~ 1] + 浮点数转整数默乘以 1000(取3位小数) + 按照最高数据量100w样本, g/h求和值上限为 1000 * 10**6 = 10**9 + 基于g/h上限, 负数模运算转正数需要加上 2**32 (4.29*10**9) + + 2. packing + g/h负数模运算转为正数后最大值为 2**32-1, 100w样本求和需要预留10**6位 + packing g和h时, 对g乘以10**20, 为h预留总计20位长度。 + ''' + mod_n = 2 ** mod_length + pos_int_glist = ((g_list * expand).astype('int64') + mod_n) % mod_n + pos_int_hlist = ((h_list * expand).astype('int64') + mod_n) % mod_n + + gh_list = pos_int_glist.astype('object') * 10**pack_length + pos_int_hlist.astype('object') + + return gh_list + + @staticmethod + def unpacking_gh(gh_sum_list: np.ndarray, expand=1000, mod_length=32, pack_length=20): + ''' + 1. 拆解g_pos_int_sum和h_pos_int_sum + 2. 
还原g_sum和h_sum + ''' + + mod_n = 2 ** mod_length + g_pos_int_sum = (gh_sum_list // 10**pack_length) % mod_n + h_pos_int_sum = (gh_sum_list % 10**pack_length) % mod_n + + g_pos_int_sum[g_pos_int_sum > 2**(mod_length-1)] -= mod_n + h_pos_int_sum[h_pos_int_sum > 2**(mod_length-1)] -= mod_n + + g_hist = (g_pos_int_sum / expand).astype('float') + h_hist = (h_pos_int_sum / expand).astype('float') + + return g_hist, h_hist diff --git a/python/ppc_model/secure_lgbm/vertical/passive_party.py b/python/ppc_model/secure_lgbm/vertical/passive_party.py new file mode 100644 index 00000000..2f09f6d1 --- /dev/null +++ b/python/ppc_model/secure_lgbm/vertical/passive_party.py @@ -0,0 +1,271 @@ +import multiprocessing +import time +import numpy as np +from pandas import DataFrame + +from ppc_common.ppc_utils import utils +from ppc_common.ppc_protos.generated.ppc_model_pb2 import BestSplitInfo, IterationRequest +from ppc_model.datasets.dataset import SecureDataset +from ppc_model.datasets.feature_binning.feature_binning import FeatureBinning +from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext, LGBMMessage +from ppc_model.secure_lgbm.vertical.booster import VerticalBooster + + +class VerticalLGBMPassiveParty(VerticalBooster): + + def __init__(self, ctx: SecureLGBMContext, dataset: SecureDataset) -> None: + super().__init__(ctx, dataset) + self.params = ctx.lgbm_params + self.log = ctx.components.logger() + self.log.info(f'task {self.ctx.task_id}: print all params: {self.params.get_all_params()}') + + def fit( + self, + *args, + **kwargs, + ) -> None: + self.log.info( + f'task {self.ctx.task_id}: Starting the lgbm on the passive party.') + self._init_passive_data() + self._test_X_bin = self._split_test_data(self.ctx, self.dataset.test_X, self._X_split) + + for _ in range(self.params.n_estimators): + self._tree_id += 1 + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} in passive party.') + + # 初始化 + instance, 
used_ghlist, public_key = self._receive_gh_instance_list() + self.ctx.phe.public_key = public_key + self.log.info(f'task {self.ctx.task_id}: Sampling number: {len(instance)}.') + + # 构建 + tree = self._build_tree(instance, used_ghlist) + self._trees.append(tree) + + # 预测 + self._predict_tree(tree, self._X_bin, LGBMMessage.PREDICT_LEAF_MASK.value) + self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id}, ' + f'time_costs: {time.time() - start_time}s.') + + # 预测验证集 + self._predict_tree(tree, self._test_X_bin, LGBMMessage.TEST_LEAF_MASK.value) + if self._iteration_early_stop(): + self.log.info(f"task {self.ctx.task_id}: lgbm early stop after {self._tree_id} iterations.") + break + + self._end_passive_data() + + def transform(self, transform_data: DataFrame) -> DataFrame: + ... + + def predict(self, dataset: SecureDataset = None) -> np.ndarray: + start_time = time.time() + if dataset is None: + dataset = self.dataset + + self.params.my_categorical_idx = self._get_categorical_idx( + dataset.feature_name, self.params.categorical_feature) + + test_X_bin = self._split_test_data(self.ctx, dataset.test_X, self._X_split) + + [self._predict_tree( + tree, test_X_bin, LGBMMessage.VALID_LEAF_MASK.value) for tree in self._trees] + self.log.info(f'task {self.ctx.task_id}: Ending predict, time_costs: {time.time() - start_time}s.') + + self._end_passive_data(is_train=False) + + def _init_passive_data(self): + + # 初始化tree id + self._tree_id = 0 + + # 初始化参与方特征 + self._send_byte_data(self.ctx, LGBMMessage.FEATURE_NAME.value, + b''.join(s.encode('utf-8') + b' ' for s in self.dataset.feature_name), 0) + self.params.my_categorical_idx = self._get_categorical_idx( + self.dataset.feature_name, self.params.categorical_feature) + + # 初始化分桶数据集 + feat_bin = FeatureBinning(self.ctx) + self._X_bin, self._X_split = feat_bin.data_binning( + self.dataset.train_X) + + def _receive_gh_instance_list(self): + + self._leaf_id = 0 + + instance = np.frombuffer( + 
self._receive_byte_data( + self.ctx, f'{LGBMMessage.INSTANCE.value}_{self._tree_id}', 0), dtype=np.int64) + public_key, gh = self._receive_enc_data( + self.ctx, f'{LGBMMessage.ENC_GH_LIST.value}_{self._tree_id}', 0) + + return instance, np.array(gh), public_key + + def _build_tree(self, instance, ghlist, depth=0, weight=0): + + if depth == self.params.max_depth: + return weight + if self.params.max_depth < 0 and self._leaf_id >= self.params.num_leaves: + return weight + + self._leaf_id += 1 + print('tree', self._tree_id, 'leaf', self._leaf_id, 'instance', len(instance), + 'ghlist', len(ghlist)) + best_split_info = self._find_best_split(instance, ghlist) + + if best_split_info.best_gain > 0 and best_split_info.best_gain > self.params.min_split_gain: + left_mask, right_mask = self._get_leaf_mask(best_split_info, instance) + + if (abs(best_split_info.w_left) * sum(left_mask) / self.params.lr) < self.params.min_child_weight or \ + (abs(best_split_info.w_right) * sum(right_mask) / self.params.lr) < self.params.min_child_weight: + return weight + if sum(left_mask) < self.params.min_child_samples or sum(right_mask) < self.params.min_child_samples: + return weight + + left_tree = self._build_tree( + instance[left_mask], ghlist[left_mask], + depth + 1, best_split_info.w_left) + right_tree = self._build_tree( + instance[right_mask], ghlist[right_mask], + depth + 1, best_split_info.w_right) + + return [(best_split_info, left_tree, right_tree)] + else: + return weight + + def _predict_tree(self, tree, X_bin, key_type): + if not isinstance(tree, list): + return None + else: + best_split_info, left_subtree, right_subtree = tree[0] + if self.ctx.participant_id_list[best_split_info.agency_idx] == \ + self.ctx.components.config_data['AGENCY_ID']: + if best_split_info.agency_feature in self.params.my_categorical_idx: + left_mask = X_bin[:, best_split_info.agency_feature] == best_split_info.value + else: + left_mask = X_bin[:, best_split_info.agency_feature] <= best_split_info.value 
+ self._send_byte_data( + self.ctx, + f'{key_type}_{best_split_info.tree_id}_{best_split_info.leaf_id}', + left_mask.astype('bool').tobytes(), 0) + else: + pass + left_weight = self._predict_tree(left_subtree, X_bin, key_type) + right_weight = self._predict_tree(right_subtree, X_bin, key_type) + return [left_weight, right_weight] + + def _find_best_split(self, instance, ghlist): + + self.log.info(f'task {self.ctx.task_id}: Starting n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} in passive party.') + if len(instance) > 200000: + self._get_gh_hist_parallel(instance, ghlist) + else: + self._get_gh_hist(instance, ghlist) + + best_split_info_byte = self._receive_byte_data( + self.ctx, f'{LGBMMessage.SPLIT_INFO.value}_{self._tree_id}_{self._leaf_id}', 0) + best_split_info = BestSplitInfo() + utils.bytes_to_pb(best_split_info, best_split_info_byte) + + self.log.info(f'task {self.ctx.task_id}: Ending n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} in passive party.') + return best_split_info + + def _get_gh_hist_parallel(self, instance, ghlist): + + params = [] + for i in range(len(self.dataset.feature_name)): + params.append({ + 'bins': self._X_bin[:, i], + 'xk_bin': self._X_bin[:, i][instance], + 'enc_gh_list': ghlist, + 'phe': self.ctx.phe + }) + + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Start n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} calculate hist in passive party.') + + # gh_hist = [] + # with ProcessPoolExecutor() as executor: + # futures = [executor.submit(self._calculate_hist, context) for context in params] + # for future in as_completed(futures): + # gh_hist.append(future.result()) + + pool = multiprocessing.Pool() + gh_hist = pool.map(self._calculate_hist, params) + pool.close() + pool.join() + + self.log.info(f'task {self.ctx.task_id}: End n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} calculate hist time_costs: {time.time() - start_time}s.') + self._send_enc_data(self.ctx, + 
f'{LGBMMessage.ENC_GH_HIST.value}_{self._tree_id}_{self._leaf_id}', + gh_hist, 0, matrix_data=True) + + def _get_gh_hist(self, instance, ghlist): + + gh_hist = [] + start_time = time.time() + self.log.info(f'task {self.ctx.task_id}: Start n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} calculate hist in passive party.') + + for i in range(len(self.dataset.feature_name)): + param = { + 'bins': self._X_bin[:, i], + 'xk_bin': self._X_bin[:, i][instance], + 'enc_gh_list': ghlist, + 'phe': self.ctx.phe + } + gh_hist.append(self._calculate_hist(param)) + + self.log.info(f'task {self.ctx.task_id}: End n_estimators-{self._tree_id} ' + f'leaf-{self._leaf_id} calculate hist time_costs: {time.time() - start_time}s.') + self._send_enc_data(self.ctx, + f'{LGBMMessage.ENC_GH_HIST.value}_{self._tree_id}_{self._leaf_id}', + gh_hist, 0, matrix_data=True) + + @staticmethod + def _calculate_hist(param): + bins = param['bins'] + gh_list = param['enc_gh_list'] + phe = param['phe'] + xk_bin = param['xk_bin'] + ghk_hist = [] + sorted_bins = sorted(set(bins)) + for v in sorted_bins: + # 处理gk_hist中部分分桶没有样本,直接结算值为明文0的情况 + if len(gh_list[xk_bin == v]) == 0: + ghk_hist.append(phe.encrypt(0)) + else: + ghk_hist.append(gh_list[xk_bin == v].sum()) + return ghk_hist + + def _iteration_early_stop(self): + + iteration_request_byte = self._receive_byte_data( + self.ctx, f'{LGBMMessage.STOP_ITERATION.value}_{self._tree_id}', 0) + iteration_request = IterationRequest() + utils.bytes_to_pb(iteration_request, iteration_request_byte) + + return iteration_request.stop + + def _end_passive_data(self, is_train=True): + + if self.ctx.components.config_data['AGENCY_ID'] in self.ctx.result_receiver_id_list: + if is_train: + self._train_praba = np.frombuffer( + self._receive_byte_data( + self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_train', 0), dtype=np.float64) + + self._test_praba = np.frombuffer( + self._receive_byte_data( + self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_test', 0), dtype=np.float64) + + 
else: + self._test_praba = np.frombuffer( + self._receive_byte_data( + self.ctx, f'{LGBMMessage.PREDICT_PRABA.value}_predict', 0), dtype=np.float64) diff --git a/python/ppc_model/task/__init__.py b/python/ppc_model/task/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/task/task_manager.py b/python/ppc_model/task/task_manager.py new file mode 100644 index 00000000..d6a308d0 --- /dev/null +++ b/python/ppc_model/task/task_manager.py @@ -0,0 +1,183 @@ +import datetime +import logging +import os +import threading +import time +from typing import Callable, Union + +from readerwriterlock import rwlock + +from ppc_common.ppc_async_executor.async_thread_executor import AsyncThreadExecutor +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager +from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode +from ppc_model.common.protocol import ModelTask, TaskStatus, LOG_START_FLAG_FORMATTER, LOG_END_FLAG_FORMATTER +from ppc_model.network.stub import ModelStub + + +class TaskManager: + def __init__(self, logger, + thread_event_manager: ThreadEventManager, + stub: ModelStub, + task_timeout_h: Union[int, float]): + self.logger = logger + self._thread_event_manager = thread_event_manager + self._stub = stub + self._task_timeout_s = task_timeout_h * 3600 + self._rw_lock = rwlock.RWLockWrite() + self._tasks: dict[str, list] = {} + self._jobs: dict[str, set] = {} + self._handlers = {} + self._async_executor = AsyncThreadExecutor( + event_manager=self._thread_event_manager, logger=logger) + self._cleanup_thread = threading.Thread(target=self._loop_cleanup) + self._cleanup_thread.daemon = True + self._cleanup_thread.start() + + def register_task_handler(self, task_type: ModelTask, task_handler: Callable): + """ + 注册任务的执行入口 + param task_type: 任务类型 + param task_handler: 任务执行入口 + """ + self._handlers[task_type.value] = task_handler + + def run_task(self, task_id: str, task_type: ModelTask, args=()): + """ + 发起任务 + 
param task_id: 任务ID + param task_type: 任务类型 + param args: 各任务参数 + """ + job_id = args[0]['job_id'] + with self._rw_lock.gen_wlock(): + if task_id in self._tasks: + self.logger.info(f"Task already exists, task_id: {task_id}, status: {self._tasks[task_id][0]}") + return + self._tasks[task_id] = [TaskStatus.RUNNING.value, datetime.datetime.now(), 0, args[0]['job_id']] + if job_id in self._jobs: + self._jobs[job_id].add(task_id) + else: + self._jobs[job_id] = {task_id} + self.logger.info(LOG_START_FLAG_FORMATTER.format(job_id=job_id)) + self.logger.info(f"Run task, job_id: {job_id}, task_id: {task_id}") + self._async_executor.execute(task_id, self._handlers[task_type.value], self._on_task_finish, args) + + def kill_task(self, job_id: str): + """ + 终止任务 + """ + task_ids = [] + with self._rw_lock.gen_rlock(): + if job_id not in self._jobs: + return + for task_id in self._jobs[job_id]: + task_ids.append(task_id) + + for task_id in task_ids: + self.kill_one_task(task_id) + + def kill_one_task(self, task_id: str): + with self._rw_lock.gen_rlock(): + if task_id not in self._tasks or self._tasks[task_id][0] != TaskStatus.RUNNING.value: + return + + self.logger.info(f"Kill task, task_id: {task_id}") + self._async_executor.kill(task_id) + + with self._rw_lock.gen_wlock(): + self._tasks[task_id][0] = TaskStatus.FAILED.value + + def status(self, task_id: str) -> [str, float, float]: + """ + 返回: 任务状态, 通讯量(MB), 执行耗时(s) + """ + with self._rw_lock.gen_rlock(): + if task_id not in self._tasks: + raise PpcException( + PpcErrorCode.TASK_NOT_FOUND.get_code(), + PpcErrorCode.TASK_NOT_FOUND.get_msg()) + status = self._tasks[task_id][0] + traffic_volume = self._stub.traffic_volume(task_id) + time_costs = self._tasks[task_id][2] + return status, traffic_volume, time_costs + + def _on_task_finish(self, task_id: str, is_succeeded: bool, e: Exception = None): + with self._rw_lock.gen_wlock(): + time_costs = (datetime.datetime.now() - + self._tasks[task_id][1]).total_seconds() + 
self._tasks[task_id][2] = time_costs + if is_succeeded: + self._tasks[task_id][0] = TaskStatus.COMPLETED.value + self.logger.info(f"Task {task_id} completed, job_id: {self._tasks[task_id][3]}, " + f"time_costs: {time_costs}s") + else: + self._tasks[task_id][0] = TaskStatus.FAILED.value + self.logger.warn(f"Task {task_id} failed, job_id: {self._tasks[task_id][3]}, " + f"time_costs: {time_costs}s, error: {e}") + self.logger.info(LOG_END_FLAG_FORMATTER.format( + job_id=self._tasks[task_id][3])) + + def _loop_cleanup(self): + while True: + self._terminate_timeout_tasks() + self._cleanup_finished_tasks() + time.sleep(5) + + def _terminate_timeout_tasks(self): + tasks_to_kill = [] + with self._rw_lock.gen_rlock(): + for task_id, value in self._tasks.items(): + alive_time = (datetime.datetime.now() - + value[1]).total_seconds() + if alive_time >= self._task_timeout_s and value[0] == TaskStatus.RUNNING.value: + tasks_to_kill.append(task_id) + + for task_id in tasks_to_kill: + self.logger.warn(f"Task is timeout, task_id: {task_id}") + self.kill_one_task(task_id) + + def _cleanup_finished_tasks(self): + tasks_to_cleanup = [] + with self._rw_lock.gen_rlock(): + for task_id, value in self._tasks.items(): + alive_time = (datetime.datetime.now() - + value[1]).total_seconds() + if alive_time >= self._task_timeout_s + 3600: + tasks_to_cleanup.append((task_id, value[3])) + with self._rw_lock.gen_wlock(): + for task_id, job_id in tasks_to_cleanup: + if task_id in self._tasks: + del self._tasks[task_id] + if job_id in self._jobs: + del self._jobs[job_id] + self._thread_event_manager.remove_event(task_id) + self._stub.cleanup_cache(task_id) + self.logger.info(f"Cleanup task cache, task_id: {task_id}, job_id: {job_id}") + + def record_model_job_log(self, job_id): + log_file = self._get_log_file_path() + if log_file is None or log_file == "": + current_working_dir = os.getcwd() + relative_log_path = "logs/ppcs-model4ef-node.log" + log_file = os.path.join(current_working_dir, 
relative_log_path) + + start_keyword = LOG_START_FLAG_FORMATTER.format(job_id=job_id) + end_keyword = LOG_END_FLAG_FORMATTER.format(job_id=job_id) + with open(log_file, 'r') as file: + log_data = file.read() + start_index = log_data.find(start_keyword) + end_index = log_data.rfind(end_keyword) + + if start_index == -1 or end_index == -1: + return f"{job_id} not found in log data" + + end_index += len(end_keyword) + return log_data[start_index:end_index] + + def _get_log_file_path(self): + log_file_path = None + for handler in self.logger.handlers: + if isinstance(handler, logging.FileHandler): + log_file_path = handler.baseFilename + break + return log_file_path diff --git a/python/ppc_model/task/test/__init__.py b/python/ppc_model/task/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model/task/test/task_manager_unittest.py b/python/ppc_model/task/test/task_manager_unittest.py new file mode 100644 index 00000000..64f2a86f --- /dev/null +++ b/python/ppc_model/task/test/task_manager_unittest.py @@ -0,0 +1,155 @@ +import time +import unittest + +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager +from ppc_common.ppc_mock.mock_objects import MockLogger +from ppc_model.common.protocol import ModelTask +from ppc_model.common.mock.rpc_client_mock import RpcClientMock +from ppc_model.task.task_manager import TaskManager +from ppc_model.network.stub import ModelStub, PushRequest, PullRequest + +rpc_client = RpcClientMock() +thread_event_manager = ThreadEventManager() +stub = ModelStub( + agency_id='TEST_AGENCY', + thread_event_manager=thread_event_manager, + rpc_client=rpc_client, + send_retry_times=3, + retry_interval_s=0.1 +) +rpc_client.set_message_handler(stub.on_message_received) + + +def my_send_task(args): + print("start my_send_task") + time.sleep(1) + byte_array = bytearray(31 * 1024 * 1024) + bytes_data = bytes(byte_array) + stub.push(PushRequest( + receiver=args['receiver'], + 
task_id=args['task_id'], + key=args['key'], + data=bytes_data + )) + time.sleep(1) + + +def my_receive_task(args): + print("start my_receive_task") + stub.pull(PullRequest( + sender=args['sender'], + task_id=args['task_id'], + key=args['key'], + )) + time.sleep(1) + print("finish my_receive_task") + + +def my_failed_task(args): + print("start my_failed_task") + time.sleep(1) + raise Exception('For Test') + + +def my_long_task(args): + print("start my_long_task") + stub.pull(PullRequest( + sender=args['sender'], + task_id=args['task_id'], + key='not_ready', + )) + print("finish my_receive_task") + + +def my_timeout_task(args): + print("start my_timeout_task") + stub.pull(PullRequest( + sender=args['sender'], + task_id=args['task_id'], + key='not_ready', + )) + print("finish my_timeout_task") + + +class TestTaskManager(unittest.TestCase): + + def setUp(self): + self._task_manager = TaskManager( + logger=MockLogger(), + thread_event_manager=thread_event_manager, + stub=stub, + task_timeout_h=0.0005 + ) + self._task_manager.register_task_handler( + ModelTask.FEATURE_ENGINEERING, my_send_task) + self._task_manager.register_task_handler( + ModelTask.PREPROCESSING, my_receive_task) + self._task_manager.register_task_handler( + ModelTask.XGB_TRAINING, my_failed_task) + self._task_manager.register_task_handler( + ModelTask.XGB_PREDICTING, my_long_task) + + def test_run_task(self): + args = { + 'receiver': 'TEST_AGENCY', + 'sender': 'TEST_AGENCY', + 'task_id': '0x12345678', + 'job_id': '0x123456789', + 'key': 'TEST_MESSAGE', + } + self._task_manager.run_task( + "my_send_task", ModelTask.FEATURE_ENGINEERING, (args,)) + self.assertEqual(self._task_manager.status( + "my_send_task")[0], 'RUNNING') + self._task_manager.run_task( + "my_receive_task", ModelTask.PREPROCESSING, (args,)) + self.assertEqual(self._task_manager.status( + "my_receive_task")[0], 'RUNNING') + self._task_manager.run_task( + "my_failed_task", ModelTask.XGB_TRAINING, (args,)) + 
self.assertEqual(self._task_manager.status( + "my_failed_task")[0], 'RUNNING') + time.sleep(3) + self.assertEqual(self._task_manager.status( + "my_send_task")[0], 'COMPLETED') + self.assertEqual(self._task_manager.status( + "my_receive_task")[0], 'COMPLETED') + self.assertEqual(self._task_manager.status( + "my_failed_task")[0], 'FAILED') + time.sleep(1) + + def test_kill_task(self): + args = { + 'receiver': 'TEST_AGENCY', + 'sender': 'TEST_AGENCY', + 'task_id': 'my_long_task', + 'job_id': '0x123456789', + 'key': 'TEST_MESSAGE', + } + self._task_manager.run_task("my_long_task", ModelTask.XGB_PREDICTING, (args,)) + self.assertEqual(self._task_manager.status("my_long_task")[0], 'RUNNING') + self._task_manager.kill_task("0x123456789") + time.sleep(1) + self.assertEqual(self._task_manager.status( + "my_long_task")[0], 'FAILED') + + self._task_manager.register_task_handler( + ModelTask.XGB_PREDICTING, my_timeout_task) + args = { + 'receiver': 'TEST_AGENCY', + 'sender': 'TEST_AGENCY', + 'task_id': 'my_timeout_task', + 'job_id': '0x123456789', + 'key': 'TEST_MESSAGE', + } + self._task_manager.run_task( + "my_timeout_task", ModelTask.XGB_PREDICTING, (args,)) + self.assertEqual(self._task_manager.status( + "my_timeout_task")[0], 'RUNNING') + time.sleep(6) + self.assertEqual(self._task_manager.status( + "my_timeout_task")[0], 'FAILED') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/ppc_model/tools/start.sh b/python/ppc_model/tools/start.sh new file mode 100644 index 00000000..3c1c818b --- /dev/null +++ b/python/ppc_model/tools/start.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +dirpath="$(cd "$(dirname "$0")" && pwd)" +cd $dirpath +LOG_DIR=/data/app/logs/ppcs-model4ef/ + +# kill crypto process +crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l` +for i in $( seq 1 $crypto_pro_num ) +do + crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'` + kill -9 $crypto_pid 
+done + +sleep 1 + +nohup python ppc_model_app.py > start.out 2>&1 & + +check_service() { + try_times=5 + i=0 + while [ -z `ps -ef | grep ${1} | grep python | grep -v grep | awk '{print $2}'` ]; do + sleep 1 + ((i = i + 1)) + if [ $i -lt ${try_times} ]; then + echo -e "\033[32m.\033[0m\c" + else + echo -e "\033[31m\nServer ${1} isn't running. \033[0m" + return + fi + done + + echo -e "\033[32mServer ${1} started \033[0m" +} + +sleep 5 +check_service ppc_model_app.py +rm -rf logs +ln -s ${LOG_DIR} logs \ No newline at end of file diff --git a/python/ppc_model/tools/stop.sh b/python/ppc_model/tools/stop.sh new file mode 100644 index 00000000..3b290668 --- /dev/null +++ b/python/ppc_model/tools/stop.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +dirpath="$(cd "$(dirname "$0")" && pwd)" +cd $dirpath + +# kill crypto process +crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l` +for i in $( seq 1 $crypto_pro_num ) +do + crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'` + kill -9 $crypto_pid +done + +sleep 1 + +ppc_model_app_pid=`ps aux |grep ppc_model_app.py |grep -v grep |awk '{print $2}'` +kill -9 $ppc_model_app_pid + +echo -e "\033[32mServer ppc_model_app.py killed. 
\033[0m" diff --git a/python/ppc_model_gateway/__init__.py b/python/ppc_model_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model_gateway/clients/__init__.py b/python/ppc_model_gateway/clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model_gateway/clients/client_manager.py b/python/ppc_model_gateway/clients/client_manager.py new file mode 100644 index 00000000..e05f89fb --- /dev/null +++ b/python/ppc_model_gateway/clients/client_manager.py @@ -0,0 +1,57 @@ +import os + +import grpc + +from ppc_common.ppc_protos.generated.ppc_model_pb2_grpc import ModelServiceStub +from ppc_common.ppc_utils import utils +from ppc_model_gateway import config + + +class ClientManager: + def __init__(self, config_data, grpc_options): + self._config_data = config_data + self._grpc_options = grpc_options + channel = grpc.insecure_channel( + self._config_data['NODE_ENDPOINT'], options=self._grpc_options) + self.node_stub = ModelServiceStub(channel) + self.agency_stub_dict = {} + self._create_partner_stubs() + + def _create_partner_stubs(self): + for agency_id, endpoint in self._get_agency_dict().items(): + if self._config_data['SSL_SWITCH'] == 0: + channel = grpc.insecure_channel( + endpoint, options=self._grpc_options) + self.agency_stub_dict[agency_id] = ModelServiceStub(channel) + else: + channel = self._create_secure_channel(endpoint) + self.agency_stub_dict[agency_id] = ModelServiceStub(channel) + + def _get_agency_dict(self) -> dict: + agency_dict = {} + for entry in self._config_data.get('AGENCY_LIST', []): + if ':' in entry: + key, value = entry.split(":", 1) + key = key.strip() + value = value.strip() + agency_dict[key] = value + return agency_dict + + def _create_secure_channel(self, target): + grpc_root_crt = utils.load_credential_from_file( + os.path.abspath(self._config_data['SSL_CA'])) + grpc_ssl_key = utils.load_credential_from_file( + os.path.abspath(self._config_data['SSL_KEY'])) + 
grpc_ssl_crt = utils.load_credential_from_file( + os.path.abspath(self._config_data['SSL_CRT'])) + + credentials = grpc.ssl_channel_credentials( + root_certificates=grpc_root_crt, + private_key=grpc_ssl_key, + certificate_chain=grpc_ssl_crt + ) + + return grpc.secure_channel(target, credentials, options=self._grpc_options) + + +client_manager = ClientManager(config.CONFIG_DATA, config.grpc_options) diff --git a/python/ppc_model_gateway/conf/application-sample.yml b/python/ppc_model_gateway/conf/application-sample.yml new file mode 100644 index 00000000..3c7abf8c --- /dev/null +++ b/python/ppc_model_gateway/conf/application-sample.yml @@ -0,0 +1,19 @@ +HOST: "0.0.0.0" + +# 0: off, 1: on +SSL_SWITCH: 1 +SSL_CA: "cert/ca.crt" +SSL_KEY: "cert/node.key" +SSL_CRT: "cert/node.crt" + +# 0801配置 +NODE_TO_PARTNER_RPC_PORT: 43454 +PARTNER_TO_NODE_RPC_PORT: 43451 + +NODE_ENDPOINT: "127.0.0.1:43472" + +AGENCY_LIST: + #- "SG: [@IDC_PPCS_SG_MODEL_GATEWAY_ENDPOINT]" + - "WeBank: 127.0.0.1:43451" + +MAX_MESSAGE_LENGTH_MB: 100 diff --git a/python/ppc_model_gateway/conf/logging.conf b/python/ppc_model_gateway/conf/logging.conf new file mode 100644 index 00000000..bb34233c --- /dev/null +++ b/python/ppc_model_gateway/conf/logging.conf @@ -0,0 +1,40 @@ +[loggers] +keys=root,grpc + +[logger_root] +level=INFO +handlers=consoleHandler,fileHandler + +[logger_grpc] +level = DEBUG +handlers = accessHandler +qualname = grpc +propagate = 0 + +[handlers] +keys=fileHandler,consoleHandler,accessHandler + +[handler_accessHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/logs/ppcs-modelgateway/appmonitor.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_fileHandler] +class=handlers.TimedRotatingFileHandler +args=('/data/app/logs/ppcs-modelgateway/ppcs-modelgateway-gateway.log', 'D', 1, 30, 'utf-8') +level=INFO +formatter=simpleFormatter + +[handler_consoleHandler] +class=StreamHandler +args=(sys.stdout,) +level=ERROR +formatter=simpleFormatter + 
+[formatters] +keys=simpleFormatter + +[formatter_simpleFormatter] +format=[%(levelname)s][%(asctime)s %(msecs)03d][%(process)d][%(filename)s:%(lineno)d] %(message)s +datefmt=%Y-%m-%d %H:%M:%S diff --git a/python/ppc_model_gateway/config.py b/python/ppc_model_gateway/config.py new file mode 100644 index 00000000..1b9edb23 --- /dev/null +++ b/python/ppc_model_gateway/config.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import logging +import logging.config +import os + +import yaml + +from ppc_common.ppc_initialize.dataset_handler_initialize import DataSetHandlerInitialize + +path = os.getcwd() +log_dir = os.sep.join([path, 'logs']) +chain_log_dir = os.sep.join([path, 'bin', 'logs']) +print(f"log_dir: {log_dir}") +print(f"chain_log_dir: {chain_log_dir}") +if not os.path.exists(log_dir): + os.makedirs(log_dir) +if not os.path.exists(chain_log_dir): + os.makedirs(chain_log_dir) +logging_conf_path = os.path.normpath('logging.conf') +logging.config.fileConfig(logging_conf_path) + + +def get_logger(name=None): + log = logging.getLogger(name) + return log + + +config_path = "application.yml" + +CONFIG_DATA = {} +agency_dict = {} + + +def read_config(): + with open(config_path, 'rb') as f: + global CONFIG_DATA + CONFIG_DATA = yaml.safe_load(f.read()) + + +read_config() + +grpc_options = [ + ('grpc.ssl_target_name_override', 'PPCS MODEL GATEWAY'), + ('grpc.max_send_message_length', + CONFIG_DATA['MAX_MESSAGE_LENGTH_MB'] * 1024 * 1024), + ('grpc.max_receive_message_length', + CONFIG_DATA['MAX_MESSAGE_LENGTH_MB'] * 1024 * 1024), + ('grpc.keepalive_time_ms', 15000), # 每 15 秒发送一次心跳 + ('grpc.keepalive_timeout_ms', 5000), # 等待心跳回应的超时时间为 5 秒 + ('grpc.keepalive_permit_without_calls', True), # 即使没有调用也允许发送心跳 + ('grpc.http2.min_time_between_pings_ms', 15000), # 心跳之间最小时间间隔为 15 秒 + ('grpc.http2.max_pings_without_data', 0), # 在发送数据前不限制心跳次数 + # 在没有数据传输的情况下,确保心跳包之间至少有 20 秒的间隔 + ('grpc.http2.min_ping_interval_without_data_ms', 20000), + ("grpc.so_reuseport", 1), + 
("grpc.use_local_subchannel_pool", 1), + ('grpc.enable_retries', 1), + ('grpc.service_config', + '{ "retryPolicy":{ "maxAttempts": 4, "initialBackoff": "0.1s", "maxBackoff": "1s", "backoffMutiplier": 2, "retryableStatusCodes": [ "UNAVAILABLE" ] } }') +] diff --git a/python/ppc_model_gateway/endpoints/__init__.py b/python/ppc_model_gateway/endpoints/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/ppc_model_gateway/endpoints/node_to_partner.py b/python/ppc_model_gateway/endpoints/node_to_partner.py new file mode 100644 index 00000000..71f6aa07 --- /dev/null +++ b/python/ppc_model_gateway/endpoints/node_to_partner.py @@ -0,0 +1,36 @@ +import time +import traceback + +from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelRequest +from ppc_common.ppc_protos.generated.ppc_model_pb2_grpc import ModelServiceServicer +from ppc_model_gateway import config +from ppc_model_gateway.clients.client_manager import client_manager +from ppc_model_gateway.endpoints.response_builder import build_error_model_response + +log = config.get_logger() + + +class NodeToPartnerService(ModelServiceServicer): + + def MessageInteraction(self, request: ModelRequest, context): + start_time = time.time() + try: + log.debug( + f"start sending data to {request.receiver}, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}") + # 根据接收方的机构ID路由消息 + client = client_manager.agency_stub_dict[request.receiver] + response = client.MessageInteraction(request) + end_time = time.time() + log.info( + f"finish sending data to {request.receiver}, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}, slice_num: {request.slice_num}, " + f"ret_code: {response.base_response.error_code}, time_costs: {str(end_time - start_time)}s") + except Exception: + end_time = time.time() + message = f"[OnWarn]Send data to {request.receiver} failed, task_id: {request.task_id}, " \ + f"key: {request.key}, seq: {request.seq}, slice_num: {request.slice_num}, 
" \ + f"exception:{str(traceback.format_exc())}, time_costs: {str(end_time - start_time)}s" + log.warn(message) + response = build_error_model_response(message) + return response diff --git a/python/ppc_model_gateway/endpoints/partner_to_node.py b/python/ppc_model_gateway/endpoints/partner_to_node.py new file mode 100644 index 00000000..e4aebfce --- /dev/null +++ b/python/ppc_model_gateway/endpoints/partner_to_node.py @@ -0,0 +1,34 @@ +import time +import traceback + +from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelRequest +from ppc_common.ppc_protos.generated.ppc_model_pb2_grpc import ModelServiceServicer +from ppc_model_gateway import config +from ppc_model_gateway.clients.client_manager import client_manager +from ppc_model_gateway.endpoints.response_builder import build_error_model_response + +log = config.get_logger() + + +class PartnerToNodeService(ModelServiceServicer): + + def MessageInteraction(self, request: ModelRequest, context): + start_time = time.time() + try: + log.debug( + f"start sending data to {request.receiver}, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}") + response = client_manager.node_stub.MessageInteraction(request) + end_time = time.time() + log.info( + f"finish sending data to {request.receiver}, task_id: {request.task_id}, " + f"key: {request.key}, seq: {request.seq}, slice_num: {request.slice_num}, " + f"ret_code: {response.base_response.error_code}, time_costs: {str(end_time - start_time)}s") + except Exception: + end_time = time.time() + message = f"[OnWarn]Send data to {request.receiver} failed, task_id: {request.task_id}, " \ + f"key: {request.key}, seq: {request.seq}, slice_num: {request.slice_num}, " \ + f"exception:{str(traceback.format_exc())}, time_costs: {str(end_time - start_time)}s" + log.warn(message) + response = build_error_model_response(message) + return response diff --git a/python/ppc_model_gateway/endpoints/response_builder.py 
b/python/ppc_model_gateway/endpoints/response_builder.py new file mode 100644 index 00000000..f97fa21a --- /dev/null +++ b/python/ppc_model_gateway/endpoints/response_builder.py @@ -0,0 +1,8 @@ +from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelResponse + + +def build_error_model_response(message: str): + model_response = ModelResponse() + model_response.base_response.error_code = -1 + model_response.base_response.message = message + return model_response diff --git a/python/ppc_model_gateway/ppc_model_gateway_app.py b/python/ppc_model_gateway/ppc_model_gateway_app.py new file mode 100644 index 00000000..f64a96c4 --- /dev/null +++ b/python/ppc_model_gateway/ppc_model_gateway_app.py @@ -0,0 +1,90 @@ +import os +# Note: here can't be refactored by autopep +import sys +sys.path.append("../") + +from concurrent import futures +from threading import Thread + +import grpc + +from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc +from ppc_common.ppc_utils import utils +from ppc_model_gateway import config +from ppc_model_gateway.endpoints.node_to_partner import NodeToPartnerService +from ppc_model_gateway.endpoints.partner_to_node import PartnerToNodeService + +log = config.get_logger() + + +def node_to_partner_serve(): + rpc_port = config.CONFIG_DATA['NODE_TO_PARTNER_RPC_PORT'] + + ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), + options=config.grpc_options) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(NodeToPartnerService(), ppc_serve) + address = "[::]:{}".format(rpc_port) + ppc_serve.add_insecure_port(address) + + ppc_serve.start() + + start_message = f'Start ppc model gateway internal rpc server at {rpc_port}' + print(start_message) + log.info(start_message) + ppc_serve.wait_for_termination() + + +def partner_to_node_serve(): + rpc_port = config.CONFIG_DATA['PARTNER_TO_NODE_RPC_PORT'] + + if config.CONFIG_DATA['SSL_SWITCH'] == 0: + ppc_serve = 
grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), + options=config.grpc_options) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(PartnerToNodeService(), ppc_serve) + address = "[::]:{}".format(rpc_port) + ppc_serve.add_insecure_port(address) + else: + grpc_root_crt = utils.load_credential_from_file( + os.path.abspath(config.CONFIG_DATA['SSL_CA'])) + grpc_ssl_key = utils.load_credential_from_file( + os.path.abspath(config.CONFIG_DATA['SSL_KEY'])) + grpc_ssl_crt = utils.load_credential_from_file( + os.path.abspath(config.CONFIG_DATA['SSL_CRT'])) + server_credentials = grpc.ssl_server_credentials((( + grpc_ssl_key, + grpc_ssl_crt, + ),), grpc_root_crt, True) + + ppc_serve = grpc.server(futures.ThreadPoolExecutor(max_workers=max(1, os.cpu_count() - 1)), + options=config.grpc_options) + ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(PartnerToNodeService(), ppc_serve) + address = "[::]:{}".format(rpc_port) + ppc_serve.add_secure_port(address, server_credentials) + + ppc_serve.start() + + start_message = f'Start ppc model gateway external rpc server at {rpc_port}' + print(start_message) + log.info(start_message) + ppc_serve.wait_for_termination() + + +if __name__ == '__main__': + log = config.get_logger() + + # 设置守护线程 + node_to_partner_serve_thread = Thread(target=node_to_partner_serve) + partner_to_node_serve_thread = Thread(target=partner_to_node_serve) + + node_to_partner_serve_thread.daemon = True + partner_to_node_serve_thread.daemon = True + + node_to_partner_serve_thread.start() + partner_to_node_serve_thread.start() + + node_to_partner_serve_thread.join() + partner_to_node_serve_thread.join() + + message = f'Start ppc model gateway successfully.' 
+ print(message)
+    log.info(message)
diff --git a/python/ppc_model_gateway/test/__init__.py b/python/ppc_model_gateway/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/ppc_model_gateway/test/client.py b/python/ppc_model_gateway/test/client.py
new file mode 100644
index 00000000..a941eb1b
--- /dev/null
+++ b/python/ppc_model_gateway/test/client.py
@@ -0,0 +1,41 @@
+
+import grpc
+import sys
+import os
+
+from ppc_model_gateway import config
+from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
+from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelRequest
+
+
+def generate_bytes(size_in_mb):
+    size_in_bytes = size_in_mb * 1024 * 1023
+    return os.urandom(size_in_bytes)
+
+
+def send_data():
+    channel = grpc.insecure_channel(
+        f'localhost:{port}', options=config.grpc_options)
+    stub = ppc_model_pb2_grpc.ModelServiceStub(channel)
+
+    request = ModelRequest()
+
+    request.task_id = "task_id"
+    request.receiver = receiver
+    request.key = 'key'
+    request.seq = 0
+    request.slice_num = 1
+    request.data = bytes(generate_bytes(
+        config.CONFIG_DATA['MAX_MESSAGE_LENGTH_MB']))
+
+    response = stub.MessageInteraction(request)
+    print("Received response:", response.base_response.message)
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 3:
+        print("Usage: python client.py <port> <receiver>")
+        sys.exit(1)
+    port = int(sys.argv[1])
+    receiver = sys.argv[2]
+    send_data()
diff --git a/python/ppc_model_gateway/test/server.py b/python/ppc_model_gateway/test/server.py
new file mode 100644
index 00000000..1df95ecf
--- /dev/null
+++ b/python/ppc_model_gateway/test/server.py
@@ -0,0 +1,36 @@
+import os
+from concurrent import futures
+import grpc
+import sys
+
+from ppc_model_gateway import config
+from ppc_common.ppc_protos.generated import ppc_model_pb2_grpc
+from ppc_common.ppc_protos.generated.ppc_model_pb2 import ModelResponse
+
+
+class ModelService(ppc_model_pb2_grpc.ModelServiceServicer):
+    def MessageInteraction(self, request, context):
+        
response = ModelResponse()
+        response.base_response.error_code = 0
+        response.base_response.message = "Data received successfully."
+        response.data = request.data
+        return response
+
+
+def serve():
+    server = grpc.server(futures.ThreadPoolExecutor(
+        max_workers=max(1, os.cpu_count() - 1)),
+        options=config.grpc_options)
+    ppc_model_pb2_grpc.add_ModelServiceServicer_to_server(ModelService(), server)
+    server.add_insecure_port(f'[::]:{port}')
+    server.start()
+    print(f'Start serve successfully at {port}.')
+    server.wait_for_termination()
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 2:
+        print("Usage: python server.py <port>")
+        sys.exit(1)
+    port = sys.argv[1]
+    serve()
diff --git a/python/ppc_model_gateway/tools/gen_cert.sh b/python/ppc_model_gateway/tools/gen_cert.sh
new file mode 100644
index 00000000..199b7844
--- /dev/null
+++ b/python/ppc_model_gateway/tools/gen_cert.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+print_usage() {
+    echo "Usage: $0 [--ca] [--node <number>] [-h|--help]"
+    echo "  --ca              Generate CA key and certificate"
+    echo "  --node <number>   Generate <number> sets of certificates and place them in nodeX directories"
+    echo "  -h, --help        Display this help message"
+}
+
+generate_ca() {
+    echo "Generating CA key and certificate..."
+    # 生成CA私钥
+    openssl genrsa -out ca.key 2048
+    # 生成自签名的CA证书
+    openssl req -x509 -new -nodes -key ca.key -sha256 -days 36500 -out ca.crt -subj "/CN=PPCS CA"
+}
+
+generate_node_certificates() {
+    local number=$1
+    echo "Generating $number sets of node certificates..."
+ for ((i=1; i<=$number; i++)); do + # 生成节点私钥 + openssl genrsa -out node.key 2048 + # 生成证书签署请求(CSR) + openssl req -new -key node.key -out node.csr -subj "/CN=PPCS MODEL GATEWAY" + # 使用CA证书签署CSR以生成节点证书 + openssl x509 -req -in node.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out node.crt -days 36500 -sha256 + # 清理CSR文件 + rm node.csr + # 创建节点目录并将证书移入 + mkdir -p "node$i" + mv node.key node.crt "node$i/" + echo "Generated certificate set $i and placed in node$i directory" + done +} + +# 检查参数 +if [[ $# -eq 0 ]]; then + echo "Error: No arguments provided." + print_usage + exit 1 +fi + +while [[ "$1" != "" ]]; do + case $1 in + --ca) + generate_ca + ;; + --node) + shift + if [[ $1 =~ ^[0-9]+$ ]]; then + generate_node_certificates $1 + else + echo "Error: --node argument expects a number." + print_usage + exit 1 + fi + ;; + -h | --help) + print_usage + exit 0 + ;; + *) + echo "Error: Invalid argument: $1" + print_usage + exit 1 + ;; + esac + shift +done diff --git a/python/ppc_model_gateway/tools/start.sh b/python/ppc_model_gateway/tools/start.sh new file mode 100644 index 00000000..2a3fb1d9 --- /dev/null +++ b/python/ppc_model_gateway/tools/start.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +dirpath="$(cd "$(dirname "$0")" && pwd)" +cd $dirpath +LOG_DIR=/data/app/logs/ppcs-modelgateway/ + +export PYTHONPATH=$dirpath/../ +source /data/app/ppcs-modelgateway/gateway_env/bin/deactivate +source /data/app/ppcs-modelgateway/gateway_env/bin/activate +sleep 1 + +rm -rf $dirpath/../success +nohup python $dirpath/ppc_model_gateway_app.py > start.out 2>&1 & + +check_service() { + try_times=5 + i=0 + while [ -z `ps -ef | grep ${1} | grep python | grep -v grep | awk '{print $2}'` ]; do + sleep 1 + ((i = i + 1)) + if [ $i -lt ${try_times} ]; then + echo -e "\033[32m.\033[0m\c" + else + echo -e "\033[31m\nServer ${1} isn't running. 
\033[0m" + return + fi + done + echo -e "\033[32mServer ${1} started \033[0m" + echo "success" > $dirpath/../success +} + +sleep 5 +check_service ppc_model_gateway_app.py +rm -rf logs +ln -s ${LOG_DIR} logs diff --git a/python/ppc_model_gateway/tools/stop.sh b/python/ppc_model_gateway/tools/stop.sh new file mode 100644 index 00000000..fd17dcea --- /dev/null +++ b/python/ppc_model_gateway/tools/stop.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +dirpath="$(cd "$(dirname "$0")" && pwd)" +cd $dirpath + +sleep 1 + +ppc_model_gateway_app_pid=`ps aux |grep ppc_model_gateway_app.py |grep -v grep |awk '{print $2}'` +kill -9 $ppc_model_gateway_app_pid + +echo -e "\033[32mServer ppc_model_gateway_app.py killed. \033[0m" diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 00000000..14e6b826 --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1,65 @@ +cx-Oracle==8.3.0 +click>=8.0 +pytest +cheroot==8.5.2 +flask_restx==1.3.0 +configobj~=5.0.6 +Flask_SQLAlchemy==3.1.0 +cryptography~=41.0.5 +pandas +hypothesis~=5.48.0 +parsimonious~=0.8.1 +SQLAlchemy==2.0.16 +argcomplete~=1.12.2 +cytoolz~=0.10.1 +six~=1.15.0 +attrdict~=2.0.1 +Flask~=2.2.5 +pymitter~=0.3.0 +requests~=2.31.0 +requests_toolbelt==0.9.1 +lru_dict==1.1.6 +promise~=2.3 +#protobuf==3.19.0 +# protobuf>=4.21.6,<5.0dev +protobuf>=5.27.1 +pycryptodome==3.9.9 +PyJWT==2.4.0 +PyYAML==5.4.1 +# sha3==0.2.1 +mysqlclient==2.1.0 +waitress==2.1.2 +sqlparse~=0.4.1 +ecdsa==0.19.0 +toolz~=0.11.1 +tenacity==7.0.0 +coincurve~=13.0.0 +google~=3.0.0 +paste~=3.5.0 +func_timeout==4.3.0 +cheroot==8.5.2 +prefect==0.14.15 +gmssl~=3.2.1 +readerwriterlock~=1.0.4 +jsoncomment~=0.2.3 +matplotlib~=3.2.2 +seaborn~=0.10.1 +sqlvalidator==0.0.17 +requests-toolbelt==0.9.1 +hdfs +scikit-learn~=0.24.2 +gmpy2 +openpyxl +networkx +pydot +snowland-smx +numpy==1.23.1 +graphviz~=0.20.1 +grpcio==1.62.1 +grpcio-tools==1.62.1 +xlrd~=1.0.0 +MarkupSafe>=2.1.1 +Werkzeug==2.3.8 +urllib3==1.26.18 +phe +chardet diff --git 
a/python/tools/fake_id_data.py b/python/tools/fake_id_data.py new file mode 100644 index 00000000..d78cc9ac --- /dev/null +++ b/python/tools/fake_id_data.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +import argparse +from enum import Enum +import datetime +import random +import string +import sys + + +class IDType(Enum): + CreditID = "credit_id", + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ + + +class WithTimeType(Enum): + Empty = "none", + Random = "random", + ALL = "all", + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ + + @classmethod + def value_of(cls, label): + if label in cls.Empty.value: + return cls.Empty + elif label in cls.Random.value: + return cls.Random + elif label in cls.ALL.value: + return cls.ALL + else: + raise NotImplementedError + + +def parse_args(): + parser = argparse.ArgumentParser(prog=sys.argv[0]) + parser.add_argument("-r", '--dataset_file', + help='the file to store the faked data', required=True) + parser.add_argument("-p", '--peer_dataset_size', + help='the peer dataset size', default=0, required=False) + parser.add_argument("-j", '--join_id_count', + help='the joined id count', default=0, required=False) + parser.add_argument("-c", '--id_count', + help='the id count', required=False) + parser.add_argument("-f", '--id_file', + help='the id file', required=False) + + parser.add_argument("-t", '--with_time', + help=f'generate id information with time({WithTimeType.Empty.value} means not with time, random means only generate a random time for one id, all means generate (end_date-start_date) times for one id)', default=WithTimeType.Empty.value[0], required=False) + + parser.add_argument("-s", '--start_date', + help='the start time(only useful when --with_time is setted)', required=False) + parser.add_argument("-e", '--end_date', + help='the end time(only useful when --with_time is setted)', required=False) + parser.add_argument("-I", '--id_type', + 
help=f'the id type, currently only support {IDType.CreditID.value}', + default=IDType.CreditID.value, required=False) + args = parser.parse_args() + return args + + +def generate_credit_id(): + code = random.choice(string.digits + string.ascii_uppercase) + code += random.choice(string.digits + string.ascii_uppercase) + code += ''.join(random.choices(string.digits, k=6)) + code += ''.join(random.choices(string.digits, k=9)) + code += random.choice(string.digits + string.ascii_uppercase) + return code + + +def generate_id(id_type): + if IDType.has_value(id_type) is False: + error_msg = f"Unsupported id type: {id_type}" + raise Exception(error_msg) + id_type = IDType(id_type) + if id_type == IDType.CreditID: + return generate_credit_id() + + +def generate_random_time(start_date, end_date): + start_time = datetime.time(8, 0, 0) + random_date = start_date + \ + (end_date - start_date) * random.random() + random_time = datetime.datetime.combine( + random_date, start_time) + datetime.timedelta(seconds=random.randint(0, 86399)) + return random_time.strftime("%Y-%m-%d") + + +def write_line_data(fp, with_time, id_data, start_date, end_date): + if with_time is WithTimeType.Random: + line_data = id_data + "," + \ + generate_random_time(start_date, end_date) + "\n" + fp.write(line_data) + elif with_time is WithTimeType.ALL: + for i in range((end_date - start_date).days + 1): + day = start_date + datetime.timedelta(days=i) + line_data = id_data + "," + day.strftime("%Y-%m-%d") + "\n" + fp.write(line_data) + else: + fp.write(id_data + "\n") + + +def generate_header(dataset_file, with_time): + with open(dataset_file, "a+") as fp: + # with the header + fp.write("id") + if with_time is not WithTimeType.Empty: + fp.write(",time") + fp.write("\n") + + +def generate_dataset(id_count, id_type, dataset_file, with_time, start_date, end_date, joined_fp, joined_count, id_file): + with open(dataset_file, "a+") as f: + # generate with id_count + if id_file is None: + epoch = int(id_count) * 
0.1 + for i in range(int(id_count)): + if i % epoch == 0: + print( + f"#### generate {epoch}/{id_count} records for: {dataset_file}") + id_data = generate_id(id_type) + write_line_data(f, with_time, id_data, start_date, end_date) + if joined_fp is not None and joined_count > 0: + joined_count -= 1 + write_line_data(joined_fp, with_time, + id_data, start_date, end_date) + # generate with id_file + else: + with open(id_file, "r") as id_fp: + # skip the header + id_data = id_fp.readline().strip() + if id_data is not None: + id_data = id_fp.readline().strip() + while id_data is not None and id_data != '': + write_line_data(f, with_time, id_data, + start_date, end_date) + if joined_fp is not None and joined_count > 0: + joined_count -= 1 + write_line_data(joined_fp, with_time, + id_data, start_date, end_date) + id_data = id_fp.readline().strip() + + print(f"##### fake_data for {dataset_file} success") + + +def check_args(args): + if args.id_count is None and args.id_file is None: + raise Exception("Must set id_count or id_file") + + +def fake_data(args): + check_args(args) + end_date = datetime.datetime.now() + if args.end_date is not None: + end_date = datetime.datetime.strptime(args.end_date, "%Y-%m-%d") + start_date = datetime.datetime.now() + datetime.timedelta(days=-2 * 365) + if args.start_date is not None: + start_date = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") + with_time = WithTimeType.Empty + if args.with_time is not None: + with_time = WithTimeType.value_of(args.with_time) + + joined_dataset_path = args.dataset_file + ".peer" + joined_dataset_fp = None + if int(args.join_id_count) > 0: + generate_header(joined_dataset_path, with_time) + if int(args.peer_dataset_size) > int(args.join_id_count): + expected_peer_data_size = int( + args.peer_dataset_size) - int(args.join_id_count) + generate_dataset(expected_peer_data_size, args.id_type, + joined_dataset_path, with_time, start_date, end_date, None, 0, None) + joined_dataset_fp = 
if __name__ == "__main__":
    args = parse_args()
    fake_data(args)


# -*- coding: utf-8 -*-
import numpy as np
import sys
import os
import argparse
from enum import Enum
import random


class DataType(Enum):
    """Which sample files to generate: train, predict, or both.

    NOTE: the values must be plain strings. The original code had trailing
    commas on TRAIN and PREDICT, turning their values into 1-tuples, which
    made has_value("train") return False; value_of also used substring
    membership, so value_of("t") wrongly matched TRAIN.
    """
    TRAIN = "train"
    PREDICT = "predict"
    ALL = "all"

    @classmethod
    def has_value(cls, value):
        # exact value lookup against the enum's value map
        return value in cls._value2member_map_

    @classmethod
    def value_of(cls, label):
        """Map an exact label string to its member; raise for unknown labels."""
        if cls.has_value(label):
            return cls(label)
        raise NotImplementedError


def parse_args():
    """Parse the fake-ML-train-data CLI arguments."""
    parser = argparse.ArgumentParser(prog=sys.argv[0])
    parser.add_argument("-f", '--feature_size',
                        help='the feature size', required=True)
    parser.add_argument("-s", '--sample_capacity',
                        help='the faked data size(in GB, default 1GB)', default=1, required=False)
    parser.add_argument("-d", '--id_file',
                        help='the id file', required=False)
    parser.add_argument("-S", '--sample_file',
                        help='the file to store the faked data', required=True)
    parser.add_argument("-i", "--start_id",
                        help="the start id", required=False, default=0)
    parser.add_argument("-I", "--ignore_header",
                        help="ignore the file header", required=False, default=False)
    parser.add_argument("-m", "--with_missing_value", help="with missing value",
                        required=False, default=False)
    parser.add_argument("-p", "--missing_percent", required=False,
                        default=10, help="the missing value percentage")
    parser.add_argument("-t", "--data_type", required=False,
                        default=DataType.ALL.value, help="the dataType, now support 'all'/'predict'/'train'")
    parser.add_argument("-F", "--feature_prefix", required=False,
                        default="x", help="the feature prefix")
    args = parser.parse_args()
    return args
def generate_id_list(id_fp, granularity, id):
    """Return (row_count, ids) for the next batch.

    With no id file, ids are the consecutive integers [id, id+granularity)
    as strings; otherwise the first CSV column of the next size-hinted
    batch of lines from id_fp (empty batch signals EOF to the caller).
    """
    if id_fp is None:
        return (granularity, np.array(range(id, id + granularity), dtype=str))
    # readlines with a size hint reads roughly granularity KB per batch
    lines = id_fp.readlines(granularity * 1024)
    data = [line.split(",")[0].strip() for line in lines]
    return (len(data), data)


class FileInfo:
    """An open output file plus the numpy.savetxt fmt string for its columns."""

    def __init__(self, fp, fmt_setting):
        self.fp = fp
        self.fmt_setting = fmt_setting


def generate_header_for_given_data(file_name, ignore_header, data_type, feature_size, feature_prefix):
    """Open file_name for append and return its FileInfo, writing the CSV
    header unless ignore_header is set.

    Fix: the original returned None when ignore_header was truthy, which
    made the caller skip writing any data rows at all; ignore_header must
    only suppress the header line, not the output file itself.
    """
    f = open(file_name, "a")
    # build the savetxt row format: id [, y] then one %s per feature
    fmt_setting = "%s"
    if data_type == DataType.TRAIN:
        fmt_setting = "%s,%s"
    for _ in range(feature_size):
        fmt_setting = fmt_setting + ",%s"
    if not ignore_header:
        if data_type == DataType.PREDICT:
            f.write("id")
        if data_type == DataType.TRAIN:
            f.write("id,y")
        for j in range(feature_size):
            f.write(f",{feature_prefix}{str(j)}")
        f.write("\n")
    return FileInfo(f, fmt_setting)


def generate_header(args, data_type, feature_size):
    """Create the output FileInfo pair (train, predict) for the run mode.

    ALL writes train rows to sample_file and predict rows to
    `<sample_file>.predict`; single modes use sample_file only.
    """
    train_file_info = None
    predict_file_info = None
    if data_type == DataType.PREDICT:
        predict_file_info = generate_header_for_given_data(
            args.sample_file, args.ignore_header, data_type, feature_size, args.feature_prefix)
    if data_type == DataType.TRAIN:
        train_file_info = generate_header_for_given_data(
            args.sample_file, args.ignore_header, data_type, feature_size, args.feature_prefix)
    if data_type == DataType.ALL:
        train_file_info = generate_header_for_given_data(
            args.sample_file, args.ignore_header, DataType.TRAIN, feature_size, args.feature_prefix)
        predict_file_info = generate_header_for_given_data(
            args.sample_file + ".predict", args.ignore_header, DataType.PREDICT, feature_size, args.feature_prefix)
    return (train_file_info, predict_file_info)
def fake_data(args):
    """Generate fake train/predict sample files from parsed CLI args.

    Rows are produced in batches of `granularity`; generation stops at EOF
    of the id file, or — when ids are synthetic — once sample_file reaches
    sample_capacity.
    """
    data_type = DataType.value_of(args.data_type)
    sample_capacity_bytes = None
    if args.sample_capacity is not None:
        # NOTE(review): the -s help text says "in GB" but this computes
        # KB (value * 1024) — confirm the intended unit before changing.
        sample_capacity_bytes = int(args.sample_capacity) * 1024
    feature_size = int(args.feature_size)
    granularity = 100  # rows generated per batch
    next_id = 0  # renamed from `id` (shadowed the builtin)
    (train_file_info, predict_file_info) = generate_header(
        args, data_type, feature_size)
    if args.start_id is not None:
        next_id = int(args.start_id)
    id_fp = None
    # try/finally so the id file and the output files are closed even when
    # a batch raises (the original leaked id_fp entirely and only closed
    # the outputs on the success path).
    try:
        if args.id_file is not None:
            id_fp = open(args.id_file, "r")
            id_fp.readline()  # skip the header
        while True:
            (rows, id_list) = generate_id_list(id_fp, granularity, next_id)
            if rows == 0:
                break
            next_id += rows
            feature_sample = np.random.standard_normal((rows, feature_size))
            if args.with_missing_value:
                # inject `missing_percent` NaNs per batch
                # (a count, not a true percentage)
                for _ in range(int(args.missing_percent)):
                    selected_line = random.randrange(rows)
                    selected_feature = random.randrange(feature_size)
                    feature_sample[selected_line, selected_feature] = np.nan
            if train_file_info is not None:
                train_sample = feature_sample.astype("str")
                y_list = np.random.randint(0, 2, rows)
                # prepend the label column, then the id column
                train_sample = np.insert(train_sample, 0, values=y_list, axis=1)
                train_sample = np.insert(train_sample.astype(
                    "str"), 0, values=id_list, axis=1)
                np.set_printoptions(suppress=True, threshold=np.inf)
                np.savetxt(train_file_info.fp, train_sample.astype(
                    "str"), fmt=train_file_info.fmt_setting)
            if predict_file_info is not None:
                predict_sample = feature_sample.astype("str")
                # prepend the id column (no label for predict data)
                predict_sample = np.insert(predict_sample.astype(
                    "str"), 0, values=id_list, axis=1)
                np.set_printoptions(suppress=True, threshold=np.inf)
                np.savetxt(predict_file_info.fp, predict_sample.astype(
                    "str"), fmt=predict_file_info.fmt_setting)
            # without an id file, stop once the target file size is reached
            if args.id_file is None:
                file_size = os.stat(args.sample_file).st_size
                if file_size >= sample_capacity_bytes:
                    print(f"#### final id: {next_id}")
                    break
    finally:
        if id_fp is not None:
            id_fp.close()
        if train_file_info is not None:
            train_file_info.fp.close()
        if predict_file_info is not None:
            predict_file_info.fp.close()
if __name__ == "__main__":
    # script entry point: parse the CLI and run the generator
    cli_args = parse_args()
    fake_data(cli_args)